# Distribution Focal Loss

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Source: https://github.com/Yuxiang1995/ICDAR2021_MFD
# Define the distribution_focal_loss function
def distribution_focal_loss(pred, label):
    r"""Distribution Focal Loss (DFL) is from `Generalized Focal Loss: Learning
    Qualified and Distributed Bounding Boxes for Dense Object Detection
    <https://arxiv.org/abs/2006.04388>`_.

    Each continuous target is bracketed by its two neighbouring integer bins
    and the loss is the cross entropy against both bins, linearly weighted by
    how close the target lies to each of them.

    Args:
        pred (torch.Tensor): Predicted general distribution of bounding boxes
            (before softmax) with shape (N, n+1), n is the max value of the
            integral set `{0, ..., n}` in paper.
        label (torch.Tensor): Target distance label for bounding boxes with
            shape (N,). Values are expected to lie in ``[0, n]``; a target
            exactly equal to ``n`` is handled (see the clamp below).

    Returns:
        torch.Tensor: Loss tensor with shape (N,).
    """
    # Integer bin on the left of the target (truncation of the float label).
    dis_left = label.long()
    # Interpolation weights use the *unclamped* right neighbour so they always
    # sum to one and are unchanged for every in-range target.
    weight_left = (dis_left + 1).float() - label
    weight_right = label - dis_left.float()
    # A target sitting exactly on the last bin (label == n) would produce the
    # class index n + 1, which cross_entropy rejects as out of range although
    # its weight is 0. Clamp the index to the last valid class instead.
    dis_right = (dis_left + 1).clamp(max=pred.size(-1) - 1)
    loss = F.cross_entropy(pred, dis_left, reduction='none') * weight_left \
        + F.cross_entropy(pred, dis_right, reduction='none') * weight_right
    return loss

class DistributionFocalLoss(nn.Module):
    r"""Module wrapper around ``distribution_focal_loss``.

    Args:
        reduction (str): Default reduction applied to the per-sample loss,
            one of ``'none'``, ``'mean'`` or ``'sum'``. Defaults to 'mean'.
        loss_weight (float): Scalar multiplier applied to the loss.
            Defaults to 1.0.
    """

    def __init__(self,
                 reduction='mean',
                 loss_weight=1.0):
        super(DistributionFocalLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Compute the (optionally weighted and reduced) DFL.

        Args:
            pred (torch.Tensor): Unnormalised bin scores, passed unchanged to
                ``distribution_focal_loss``.
            target (torch.Tensor): Continuous target labels, passed unchanged
                to ``distribution_focal_loss``.
            weight (torch.Tensor, optional): Element-wise weight multiplied
                into the per-sample loss before reduction.
            avg_factor (int | float, optional): If given, replaces the
                element count as the divisor for ``'mean'`` reduction.
            reduction_override (str, optional): One of ``'none'``, ``'mean'``
                or ``'sum'``; takes precedence over ``self.reduction``.

        Returns:
            torch.Tensor: Scalar loss for ``'mean'``/``'sum'``, or the
            per-sample loss tensor for ``'none'``.

        Raises:
            ValueError: If the effective reduction is not supported, or if
                ``avg_factor`` is combined with ``'sum'`` reduction.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if reduction not in ('none', 'mean', 'sum'):
            raise ValueError(f'unsupported reduction: {reduction!r}')

        loss = self.loss_weight * distribution_focal_loss(pred, target)
        if weight is not None:
            loss = loss * weight

        if reduction == 'none':
            return loss
        if avg_factor is None:
            return loss.mean() if reduction == 'mean' else loss.sum()
        if reduction == 'mean':
            # Normalise by the caller-supplied factor instead of the element
            # count (e.g. the number of positive samples).
            return loss.sum() / avg_factor
        raise ValueError('avg_factor can not be used with reduction="sum"')

# Example inputs
N = 5   # number of samples in the toy batch
n = 10  # largest value of the integral set {0, ..., n}
# Random unnormalised scores over the n + 1 bins, tracked by autograd
pred = torch.randn(N, n + 1).requires_grad_()
# Random continuous targets; torch.rand samples [0, 1), so these lie in [0, n)
label = n * torch.rand(N)

# Build the loss module with its default settings and evaluate it
distribution_focal_loss_instance = DistributionFocalLoss()
loss_output = distribution_focal_loss_instance(pred, label)

loss_output


tensor(3.5175, grad_fn=<MeanBackward0>)

# Quality Focal Loss

In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Source: https://github.com/gau-nernst/centernet-lightning
# Define the QualityFocalLoss class
class QualityFocalLoss(nn.Module):
    '''Quality Focal Loss computed from logits (for numerical stability).

    Reference: Generalized Focal Loss, https://arxiv.org/abs/2006.04388
    '''

    def __init__(self, beta: float = 2, reduction: str = 'sum'):
        '''Store the loss hyper-parameters.

        Args:
            beta: exponent of the modulating factor; larger values shrink the
                contribution of examples whose prediction is already close
                to the target
            reduction: how the element-wise loss is aggregated, one of
                'none', 'sum' or 'mean'
        '''
        super().__init__()
        assert reduction in ('none', 'sum', 'mean')
        self.beta = beta
        self.reduction = reduction

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor):
        '''Return the quality focal loss of `inputs` (logits) against `targets`.

        The element-wise binary cross entropy is scaled by
        |targets - sigmoid(inputs)| ** beta and then aggregated according to
        `self.reduction`.
        '''
        # Distance between the target and the predicted probability, raised
        # to beta, down-weights well-predicted elements.
        scale = (targets - torch.sigmoid(inputs)).abs().pow(self.beta)
        bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        focal = scale * bce

        if self.reduction == 'mean':
            return focal.mean()
        if self.reduction == 'sum':
            return focal.sum()
        if self.reduction == 'none':
            return focal
        return None

# Example inputs
# Five random logits, tracked by autograd
inputs = torch.randn(5).requires_grad_()
# Matching targets, filled in place with random values from {0, 1}
targets = torch.empty(5)
targets.random_(2)

# Build the loss with 'mean' reduction (for illustration) and evaluate it
quality_focal_loss = QualityFocalLoss(reduction='mean')
loss_output = quality_focal_loss(inputs, targets)

# Show the resulting scalar loss
print(loss_output)


tensor(0.3341, grad_fn=<MeanBackward0>)
