import torch |
import torch.nn as nn |
from torch.autograd import Function |
from torch.cuda.amp import custom_fwd, custom_bwd |
from onmt.modules.sparse_activations import _threshold_and_support |
class SparsemaxLossFunction(Function):
    """Custom autograd function computing the per-example sparsemax loss.

    ``forward`` evaluates the loss from the raw scores and the sparsemax
    threshold returned by ``_threshold_and_support``; ``backward`` returns
    the gradient with respect to the scores, scaled by the incoming
    gradient so that downstream reductions and masking are honoured.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, input, target):
        """
        Compute the sparsemax loss of every row of ``input``.

        Args:
            ctx: autograd context; ``input``, ``target`` and the threshold
                are saved on it for :meth:`backward`.
            input (FloatTensor): ``(n, num_classes)``.
            target (LongTensor): ``(n,)``, the indices of the target classes

        Returns:
            FloatTensor: ``(n,)`` per-example losses, clamped at zero.

        Raises:
            ValueError: if ``input`` is not 2-dimensional.
        """
        if input.dim() != 2:
            raise ValueError(
                "input must be 2-dimensional (n, num_classes), got %d dimension(s)"
                % input.dim()
            )

        # Score of the gold class for each row. squeeze(1) drops only the
        # gathered column, so a batch of one keeps its batch dimension.
        z_k = input.gather(1, target.unsqueeze(1)).squeeze(1)

        # NOTE(review): tau_z is assumed to broadcast against ``input``
        # (e.g. shape (n, 1)) -- confirm against _threshold_and_support.
        tau_z, _ = _threshold_and_support(input, dim=1)
        support = input > tau_z

        # Sum of (z_j^2 - tau^2) over the entries above the threshold.
        x = torch.where(
            support, input**2 - tau_z**2, torch.tensor(0.0, device=input.device)
        ).sum(dim=1)

        ctx.save_for_backward(input, target, tau_z)

        # The clamp keeps the returned loss non-negative; presumably it guards
        # against small negative values produced by floating-point rounding.
        return torch.clamp(x / 2 - z_k + 0.5, min=0.0)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        """
        Back-propagate through the sparsemax loss.

        Args:
            ctx: autograd context populated by :meth:`forward`.
            grad_output (FloatTensor): ``(n,)`` gradient of the final
                objective with respect to each per-example loss.

        Returns:
            tuple: ``(grad_input, None)``. ``grad_input`` has the shape of
            ``input``; ``target`` is an index tensor and gets no gradient.
        """
        input, target, tau_z = ctx.saved_tensors

        # d(loss)/d(input) = sparsemax(input) - one_hot(target)
        sparsemax_out = torch.clamp(input - tau_z, min=0)
        delta = torch.zeros_like(sparsemax_out)
        delta.scatter_(1, target.unsqueeze(1), 1)

        # Chain rule: scale each row by the upstream gradient. Without this,
        # the 1/size factor of a mean reduction and the zeroing of ignored
        # positions would be silently dropped from the input gradient.
        grad_input = grad_output.unsqueeze(1) * (sparsemax_out - delta)
        return grad_input, None
# Functional entry point: ``sparsemax_loss(input, target)`` invokes
# SparsemaxLossFunction through autograd's ``apply``.
sparsemax_loss = SparsemaxLossFunction.apply
class SparsemaxLoss(nn.Module):
    """
    An implementation of sparsemax loss, first proposed in
    :cite:`DBLP:journals/corr/MartinsA16`. If using
    a sparse output layer, it is not possible to use negative log likelihood
    because the loss is infinite in the case the target is assigned zero
    probability. Inputs to SparsemaxLoss are arbitrary dense real-valued
    vectors (like in nn.CrossEntropyLoss), not probability vectors (like in
    nn.NLLLoss).

    Args:
        weight: stored as ``self.weight``; it is not applied by
            :meth:`forward`.
        ignore_index (int): target value whose positions contribute zero
            loss and are excluded from the mean's denominator. Only
            honoured when non-negative; with a negative value (the
            default) no position is ignored.
        reduction (str): ``"elementwise_mean"``, ``"sum"`` or ``"none"``.

    Raises:
        ValueError: if ``reduction`` is not one of the supported values.
    """

    def __init__(self, weight=None, ignore_index=-100, reduction="elementwise_mean"):
        # Initialise nn.Module before assigning attributes: assigning an
        # nn.Parameter (e.g. as ``weight``) before this call raises.
        super(SparsemaxLoss, self).__init__()
        # Explicit check instead of ``assert`` so validation survives ``-O``.
        if reduction not in ("elementwise_mean", "sum", "none"):
            raise ValueError(
                "reduction must be one of 'elementwise_mean', 'sum' or 'none', "
                "got %r" % (reduction,)
            )
        self.reduction = reduction
        self.weight = weight
        self.ignore_index = ignore_index

    def forward(self, input, target):
        """
        Compute the (optionally reduced) sparsemax loss.

        Args:
            input: scores passed through to ``sparsemax_loss``.
            target: target class indices passed through to
                ``sparsemax_loss``.

        Returns:
            The per-example losses when ``reduction`` is ``"none"``, their
            sum for ``"sum"``, or their sum divided by the number of
            non-ignored targets for ``"elementwise_mean"``.
        """
        loss = sparsemax_loss(input, target)
        if self.ignore_index >= 0:
            ignored_positions = target == self.ignore_index
            # Number of targets that actually count towards the mean.
            # If every target is ignored this is 0 and the mean is 0/0.
            size = float((target.size(0) - ignored_positions.sum()).item())
            # Out-of-place fill: leaves the autograd function's output intact.
            loss = loss.masked_fill(ignored_positions, 0.0)
        else:
            size = float(target.size(0))

        if self.reduction == "sum":
            return loss.sum()
        if self.reduction == "elementwise_mean":
            return loss.sum() / size
        return loss