|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
from torch.nn import functional as F |
|
|
|
|
|
class FocalLoss(torch.nn.Module):
    """Multi-class focal loss.

    The log-softmax of the logits is multiplied element-wise by the
    modulating factor ``(1 - p) ** gamma`` (``p`` = softmax probability),
    then reduced with ``F.nll_loss``, which picks out the target column.
    For ``gamma > 0`` this shrinks the contribution of samples whose target
    probability is already close to 1.

    Args:
        gamma: exponent of the modulating factor ``(1 - p)``.
        weight: optional per-class rescaling weight, forwarded to ``F.nll_loss``.
        ignore_index: target value that is excluded from the loss
            (forwarded to ``F.nll_loss``).
    """

    def __init__(self, gamma=2, weight=None, ignore_index=-100):
        super().__init__()
        self.gamma, self.weight, self.ignore_index = gamma, weight, ignore_index

    def forward(self, input, target):
        """Compute the focal loss.

        Args:
            input: logits, shape [N, C] (class dimension is dim 1).
            target: class indices, shape [N].

        Returns:
            The loss as reduced by ``F.nll_loss`` with its default reduction.
        """
        log_probs = F.log_softmax(input, dim=1)
        # Modulating factor (1 - p)^gamma, computed for every class; nll_loss
        # only reads the entry at each sample's target index.
        modulating = (1 - log_probs.exp()) ** self.gamma
        return F.nll_loss(
            modulating * log_probs,
            target,
            weight=self.weight,
            ignore_index=self.ignore_index,
        )
|
|
|
|
|
|
|
|
|
class LabelSmoothingCorrectionCrossEntropy(torch.nn.Module):
    """Label-smoothed cross entropy plus an argmax-based correction offset.

    The loss is
        ``eps / C * smooth + (1 - eps) * nll + correction``
    where ``smooth`` is the (reduced) sum of ``-log_softmax`` over all
    classes, ``nll`` is ``F.nll_loss`` on the log-softmax, and
    ``correction`` is a constant offset derived from the argmax prediction
    of every non-ignored sample.

    Args:
        eps: smoothing weight.
        reduction: 'mean', 'sum', or anything else for no reduction
            (per-sample losses are returned in that case).
        ignore_index: target value excluded from the NLL term and from the
            correction term.
    """

    # Constants of the correction term. Their derivation is not documented
    # in this file — NOTE(review): confirm their origin with the author.
    _CORRECTION_SUB_FACTOR = 0.5945275813408382
    _CORRECTION_ADD_DIVISOR = 0.32447699714575207

    def __init__(self, eps=0.1, reduction='mean', ignore_index=-100):
        super(LabelSmoothingCorrectionCrossEntropy, self).__init__()
        self.eps = eps
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, output, target):
        """Compute the smoothed loss plus correction.

        Args:
            output: logits; assumed [N, C] (classes on the last axis, which
                must also be dim 1 for ``F.nll_loss``).
            target: class indices, shape [N].
        """
        c = output.size(-1)
        log_preds = F.log_softmax(output, dim=-1)
        if self.reduction == 'sum':
            loss = -log_preds.sum()
        else:
            loss = -log_preds.sum(dim=-1)
            if self.reduction == 'mean':
                loss = loss.mean()
        # NOTE(review): the smoothing term above includes rows whose target
        # equals ignore_index; only the NLL and correction terms skip them.

        correction_loss = self._correction(output, target, c)

        return loss * self.eps / c + (1 - self.eps) * F.nll_loss(
            log_preds, target, reduction=self.reduction, ignore_index=self.ignore_index
        ) + correction_loss

    def _correction(self, output, target, c):
        """Return the argmax-based correction as a Python float.

        For each non-ignored sample, with ``s = argmax + target`` and
        ``d = |argmax - target|``:
          * ``s == 0``            -> no change
          * ``s == 1 and d == 1`` -> no change
          * ``s == 1 and d != 1`` -> subtract ``eps * _CORRECTION_SUB_FACTOR``
          * otherwise             -> add ``eps / _CORRECTION_ADD_DIVISOR``
        The total is divided by the class count ``c``.

        The value comes from ``argmax`` and Python arithmetic, so it carries
        no gradient; it only shifts the reported loss. For non-negative
        integer labels ``s == 1`` implies ``d == 1``, so the subtraction case
        cannot occur for valid targets.
        """
        labels_hat = torch.argmax(output, dim=-1)
        # Evaluate every sample in the batch (not just the first ``c``), and
        # skip ignore_index rows, which would otherwise yield bogus sums.
        valid = target != self.ignore_index
        lt_sum = (labels_hat + target)[valid]
        abs_lt_sub = (labels_hat - target).abs()[valid]

        n_sub = int(((lt_sum == 1) & (abs_lt_sub != 1)).sum())
        n_add = int(((lt_sum != 0) & (lt_sum != 1)).sum())

        correction = (
            n_add * (self.eps * (1 / self._CORRECTION_ADD_DIVISOR))
            - n_sub * (self.eps * self._CORRECTION_SUB_FACTOR)
        )
        return correction / c
|
|