"""
Misc Losses

Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
Please cite our work if the code is helpful to you.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from .builder import LOSSES


@LOSSES.register_module()
class CrossEntropyLoss(nn.Module):
    """Thin wrapper around ``torch.nn.CrossEntropyLoss`` with a loss weight.

    Args:
        weight: Optional per-class rescaling weights (anything accepted by
            ``torch.tensor``). ``None`` disables class weighting.
        size_average: Forwarded to ``nn.CrossEntropyLoss``.
        reduce: Forwarded to ``nn.CrossEntropyLoss``.
        reduction: Forwarded to ``nn.CrossEntropyLoss``.
        label_smoothing: Forwarded to ``nn.CrossEntropyLoss``.
        loss_weight: Scalar the final loss is multiplied by.
        ignore_index: Target value that is excluded from the loss.
    """

    def __init__(
        self,
        weight=None,
        size_average=None,
        reduce=None,
        reduction="mean",
        label_smoothing=0.0,
        loss_weight=1.0,
        ignore_index=-1,
    ):
        super().__init__()
        if weight is not None:
            weight = torch.tensor(weight)
            # Keep the historical "weights live on the GPU" behaviour, but do
            # not crash on CPU-only hosts where ``.cuda()`` would raise.
            if torch.cuda.is_available():
                weight = weight.cuda()
        self.loss_weight = loss_weight
        self.loss = nn.CrossEntropyLoss(
            weight=weight,
            size_average=size_average,
            ignore_index=ignore_index,
            reduce=reduce,
            reduction=reduction,
            label_smoothing=label_smoothing,
        )

    def forward(self, pred, target):
        """Return the cross-entropy of ``pred`` vs ``target`` times ``loss_weight``."""
        return self.loss(pred, target) * self.loss_weight


@LOSSES.register_module()
class SmoothCELoss(nn.Module):
    """Cross-entropy with label smoothing computed from a smoothed one-hot target.

    Args:
        smoothing_ratio: Probability mass moved from the true class and
            spread evenly over the remaining ``n_class - 1`` classes.
    """

    def __init__(self, smoothing_ratio=0.1):
        super().__init__()
        self.smoothing_ratio = smoothing_ratio

    def forward(self, pred, target):
        """Compute the smoothed cross-entropy.

        Args:
            pred (torch.Tensor): Logits; classes are read from dimension 1.
            target (torch.Tensor): Class indices, flattened to one index per
                row of ``pred``.

        Returns:
            torch.Tensor: Mean loss over all rows whose loss is finite.
        """
        eps = self.smoothing_ratio
        n_class = pred.size(1)
        one_hot = torch.zeros_like(pred).scatter(1, target.view(-1, 1), 1)
        # True class gets 1 - eps, every other class gets eps / (n_class - 1).
        one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
        log_prb = F.log_softmax(pred, dim=1)
        # ``torch.Tensor`` has no ``total`` method; the per-row reduction is a sum.
        loss = -(one_hot * log_prb).sum(dim=1)
        # Drop rows that produced inf/nan before averaging.
        loss = loss[torch.isfinite(loss)].mean()
        return loss


@LOSSES.register_module()
class BinaryFocalLoss(nn.Module):
    def __init__(self, gamma=2.0, alpha=0.5, logits=True, reduce=True, loss_weight=1.0):
        """Binary Focal Loss.

        Args:
            gamma: Focusing exponent applied to ``(1 - pt)``.
            alpha: Weight of the positive class, strictly between 0 and 1;
                the negative class is weighted by ``1 - alpha``.
            logits: If True, ``pred`` holds raw logits; otherwise it holds
                probabilities.
            reduce: If True, the element-wise loss is averaged to a scalar.
            loss_weight: Scalar the final loss is multiplied by.
        """
        super().__init__()
        assert 0 < alpha < 1
        self.gamma = gamma
        self.alpha = alpha
        self.logits = logits
        self.reduce = reduce
        self.loss_weight = loss_weight

    def forward(self, pred, target, **kwargs):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction with shape (N).
            target (torch.Tensor): The ground truth with values in [0, 1]
                and the same shape as ``pred``.

        Returns:
            torch.Tensor: The calculated loss (a scalar when ``reduce`` is
            True, otherwise element-wise).
        """
        if self.logits:
            bce = F.binary_cross_entropy_with_logits(pred, target, reduction="none")
        else:
            bce = F.binary_cross_entropy(pred, target, reduction="none")
        # exp(-bce) recovers the probability assigned to the true label.
        pt = torch.exp(-bce)
        alpha = self.alpha * target + (1 - self.alpha) * (1 - target)
        focal_loss = alpha * (1 - pt) ** self.gamma * bce

        if self.reduce:
            focal_loss = torch.mean(focal_loss)
        return focal_loss * self.loss_weight


@LOSSES.register_module()
class FocalLoss(nn.Module):
    def __init__(
        self, gamma=2.0, alpha=0.5, reduction="mean", loss_weight=1.0, ignore_index=-1
    ):
        """Multi-class Focal Loss using a per-class sigmoid formulation.

        Args:
            gamma (float): Focusing exponent.
            alpha (float | list): Balancing weight; a list is converted to a
                tensor and broadcast against the class dimension.
            reduction (str): ``"mean"`` or ``"sum"``.
            loss_weight (float): Scalar the final loss is multiplied by.
            ignore_index (int): Target value excluded from the loss.
        """
        super().__init__()
        assert reduction in (
            "mean",
            "sum",
        ), "AssertionError: reduction should be 'mean' or 'sum'"
        assert isinstance(
            alpha, (float, list)
        ), "AssertionError: alpha should be of type float or list"
        assert isinstance(gamma, float), "AssertionError: gamma should be of type float"
        assert isinstance(
            loss_weight, float
        ), "AssertionError: loss_weight should be of type float"
        assert isinstance(ignore_index, int), "ignore_index must be of type int"
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.ignore_index = ignore_index

    def forward(self, pred, target, **kwargs):
        """Forward function.

        Args:
            pred (torch.Tensor): Logits with shape (N, C) or
                (B, C, d_1, ..., d_k), where C is the number of classes.
            target (torch.Tensor): Class indices, one per prediction; entries
                equal to ``ignore_index`` are skipped.

        Returns:
            torch.Tensor: The calculated loss. A zero-valued tensor is
            returned when no valid target remains.
        """
        # [B, C, d_1, d_2, ..., d_k] -> [C, B, d_1, d_2, ..., d_k]
        pred = pred.transpose(0, 1)
        # [C, B, d_1, d_2, ..., d_k] -> [C, N]
        pred = pred.reshape(pred.size(0), -1)
        # [C, N] -> [N, C]
        pred = pred.transpose(0, 1).contiguous()
        # (B, d_1, d_2, ..., d_k) --> (B * d_1 * d_2 * ... * d_k,)
        target = target.view(-1).contiguous()
        assert pred.size(0) == target.size(
            0
        ), "The shape of pred doesn't match the shape of target"
        valid_mask = target != self.ignore_index
        target = target[valid_mask]
        pred = pred[valid_mask]

        if len(target) == 0:
            # Return a tensor (not a Python float) so callers can still call
            # ``.backward()`` / ``.item()`` on the result.
            return pred.sum() * 0.0

        num_classes = pred.size(1)
        target = F.one_hot(target, num_classes=num_classes)

        alpha = self.alpha
        if isinstance(alpha, list):
            alpha = pred.new_tensor(alpha)
        pred_sigmoid = pred.sigmoid()
        target = target.type_as(pred)
        # Probability mass assigned to the wrong outcome for every class.
        one_minus_pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
        focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * one_minus_pt.pow(
            self.gamma
        )

        loss = (
            F.binary_cross_entropy_with_logits(pred, target, reduction="none")
            * focal_weight
        )
        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            # ``torch.Tensor`` has no ``total`` method; use ``sum``.
            loss = loss.sum()
        return self.loss_weight * loss


@LOSSES.register_module()
class DiceLoss(nn.Module):
    def __init__(self, smooth=1, exponent=2, loss_weight=1.0, ignore_index=-1):
        """DiceLoss.

        This loss is proposed in "V-Net: Fully Convolutional Neural Networks
        for Volumetric Medical Image Segmentation".

        Args:
            smooth: Constant added to numerator and denominator.
            exponent: Power applied to predictions and targets in the
                denominator.
            loss_weight: Scalar the final loss is multiplied by.
            ignore_index: Target value excluded from the loss; a class with
                this index is also skipped when accumulating.
        """
        super().__init__()
        self.smooth = smooth
        self.exponent = exponent
        self.loss_weight = loss_weight
        self.ignore_index = ignore_index

    def forward(self, pred, target, **kwargs):
        """Compute the class-averaged soft Dice loss.

        Args:
            pred (torch.Tensor): Logits with shape (N, C) or
                (B, C, d_1, ..., d_k).
            target (torch.Tensor): Class indices, one per prediction.

        Returns:
            torch.Tensor: The calculated loss.
        """
        # [B, C, d_1, d_2, ..., d_k] -> [C, B, d_1, d_2, ..., d_k]
        pred = pred.transpose(0, 1)
        # [C, B, d_1, d_2, ..., d_k] -> [C, N]
        pred = pred.reshape(pred.size(0), -1)
        # [C, N] -> [N, C]
        pred = pred.transpose(0, 1).contiguous()
        # (B, d_1, d_2, ..., d_k) --> (B * d_1 * d_2 * ... * d_k,)
        target = target.view(-1).contiguous()
        assert pred.size(0) == target.size(
            0
        ), "The shape of pred doesn't match the shape of target"
        valid_mask = target != self.ignore_index
        target = target[valid_mask]
        pred = pred[valid_mask]

        pred = F.softmax(pred, dim=1)
        num_classes = pred.shape[1]
        # Clamp so out-of-range labels cannot make ``one_hot`` raise.
        target = F.one_hot(
            torch.clamp(target.long(), 0, num_classes - 1), num_classes=num_classes
        )

        total_loss = 0
        for i in range(num_classes):
            if i != self.ignore_index:
                num = torch.sum(torch.mul(pred[:, i], target[:, i])) * 2 + self.smooth
                den = (
                    torch.sum(
                        pred[:, i].pow(self.exponent) + target[:, i].pow(self.exponent)
                    )
                    + self.smooth
                )
                dice_loss = 1 - num / den
                total_loss += dice_loss
        # The average is taken over all classes, including a skipped one.
        loss = total_loss / num_classes
        return self.loss_weight * loss