# File: ReactSeq/onmt/modules/sparse_losses.py
# Uploaded by Oopstom ("Upload 313 files"), commit c668e80 (verified), 2.75 kB.
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.cuda.amp import custom_fwd, custom_bwd
from onmt.modules.sparse_activations import _threshold_and_support
class SparsemaxLossFunction(Function):
    """Autograd function computing the per-example sparsemax loss.

    For a row of scores ``z`` with gold class ``k`` the forward pass returns
    ``max(0, -z_k + 0.5 * sum_{j in S} (z_j**2 - tau**2) + 0.5)``, where
    ``tau`` and the support ``S = {j : z_j > tau}`` come from
    ``_threshold_and_support``. The backward pass returns
    ``grad_output * (clamp(z - tau, min=0) - onehot(k))`` for each row.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, input, target):
        """Compute one sparsemax loss value per row of ``input``.

        Args:
            input (FloatTensor): ``(n, num_classes)`` unnormalized scores.
            target (LongTensor): ``(n,)``, the indices of the target classes.

        Returns:
            FloatTensor: ``(n,)`` per-example losses, clamped to be >= 0.

        Raises:
            ValueError: if ``input`` is not two-dimensional.
        """
        if input.dim() != 2:
            raise ValueError(
                "input must have shape (n, num_classes), got %s"
                % (tuple(input.size()),)
            )
        # Score of the gold class in each row. squeeze(1) (rather than a bare
        # squeeze()) only removes the gathered column, never the batch dim.
        z_k = input.gather(1, target.unsqueeze(1)).squeeze(1)
        # tau_z must broadcast against input (presumably shape (n, 1) -- TODO
        # confirm against onmt.modules.sparse_activations).
        tau_z, _ = _threshold_and_support(input, dim=1)
        support = input > tau_z
        # Zero fill value built from input so dtype/device always match.
        x = torch.where(
            support, input**2 - tau_z**2, input.new_zeros(())
        ).sum(dim=1)
        ctx.save_for_backward(input, target, tau_z)
        # clamping necessary because of numerical errors: loss should be lower
        # bounded by zero, but negative values near zero are possible without
        # the clamp
        return torch.clamp(x / 2 - z_k + 0.5, min=0.0)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        """Backpropagate ``grad_output`` (shape ``(n,)``) to ``input``.

        Returns:
            tuple: the ``(n, num_classes)`` gradient w.r.t. ``input`` and
            ``None`` for the integer ``target``.
        """
        input, target, tau_z = ctx.saved_tensors
        sparsemax_out = torch.clamp(input - tau_z, min=0)
        delta = torch.zeros_like(sparsemax_out)
        delta.scatter_(1, target.unsqueeze(1), 1)
        # Chain rule: scale each row's local gradient by the upstream gradient
        # of that row's loss. Without this, mean reduction and masking of
        # ignored positions downstream would not affect the gradient at all.
        grad_input = grad_output.unsqueeze(1) * (sparsemax_out - delta)
        return grad_input, None


# Functional entry point: sparsemax_loss(input, target) -> (n,) losses.
sparsemax_loss = SparsemaxLossFunction.apply
class SparsemaxLoss(nn.Module):
    """
    An implementation of sparsemax loss, first proposed in
    :cite:`DBLP:journals/corr/MartinsA16`. If using
    a sparse output layer, it is not possible to use negative log likelihood
    because the loss is infinite in the case the target is assigned zero
    probability. Inputs to SparsemaxLoss are arbitrary dense real-valued
    vectors (like in nn.CrossEntropyLoss), not probability vectors (like in
    nn.NLLLoss).

    Args:
        weight: accepted for signature compatibility with
            ``nn.CrossEntropyLoss``; stored on the module but not applied.
        ignore_index (int): positions whose target equals this value get zero
            loss and are excluded from the mean. Only honoured when it is
            ``>= 0``; the default ``-100`` disables ignoring.
        reduction (str): ``"elementwise_mean"`` (or its alias ``"mean"``),
            ``"sum"`` or ``"none"``.

    Raises:
        ValueError: if ``reduction`` is not one of the supported values.
    """

    _REDUCTIONS = ("elementwise_mean", "mean", "sum", "none")

    def __init__(self, weight=None, ignore_index=-100, reduction="elementwise_mean"):
        # Initialise nn.Module before assigning any attributes on it.
        super(SparsemaxLoss, self).__init__()
        # Explicit raise instead of assert, which is stripped under -O.
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                "reduction must be one of %s, got %r" % (self._REDUCTIONS, reduction)
            )
        self.reduction = reduction
        self.weight = weight  # NOTE: not used by forward().
        self.ignore_index = ignore_index

    def forward(self, input, target):
        """Compute the sparsemax loss.

        Args:
            input (FloatTensor): ``(n, num_classes)`` unnormalized scores.
            target (LongTensor): ``(n,)`` target class indices.

        Returns:
            FloatTensor: ``(n,)`` per-example losses for ``reduction="none"``,
            otherwise a scalar (sum, or sum divided by the number of
            non-ignored positions).
        """
        loss = sparsemax_loss(input, target)
        if self.ignore_index >= 0:
            ignored_positions = target == self.ignore_index
            size = float((target.size(0) - ignored_positions.sum()).item())
            # Out-of-place to avoid mutating the autograd Function's output.
            # NOTE(review): if every position is ignored, size is 0 and the
            # mean below is 0/0 = NaN.
            loss = loss.masked_fill(ignored_positions, 0.0)
        else:
            size = float(target.size(0))
        if self.reduction == "sum":
            loss = loss.sum()
        elif self.reduction in ("elementwise_mean", "mean"):
            loss = loss.sum() / size
        return loss