""" | |
Lovasz Loss | |
refer https://arxiv.org/abs/1705.08790 | |
Author: Xiaoyang Wu ([email protected]) | |
Please cite our work if the code is helpful to you. | |
""" | |
from typing import Optional
from itertools import filterfalse

import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss

from .builder import LOSSES

BINARY_MODE: str = "binary"
MULTICLASS_MODE: str = "multiclass"
MULTILABEL_MODE: str = "multilabel"

def _lovasz_grad(gt_sorted):
    """Compute gradient of the Lovasz extension w.r.t sorted errors
    See Alg. 1 in paper
    """
    p = len(gt_sorted)
    gts = gt_sorted.sum()
    intersection = gts - gt_sorted.float().cumsum(0)
    union = gts + (1 - gt_sorted).float().cumsum(0)
    jaccard = 1.0 - intersection / union
    if p > 1:  # cover 1-pixel case
        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
    return jaccard
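
# Worked example (a sketch, not from the original source): for
# gt_sorted = torch.tensor([1.0, 0.0, 1.0]) the cumulative Jaccard values are
# [0.5, 2/3, 1.0], so _lovasz_grad returns their first differences
# [0.5, 1/6, 1/3] -- the marginal IoU cost of each additional sorted error.
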
def _lovasz_hinge(logits, labels, per_image=True, ignore=None):
    """Binary Lovasz hinge loss

    Args:
        logits: [B, H, W] Logits at each pixel (between -infinity and +infinity)
        labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
        per_image: compute the loss per image instead of per batch
        ignore: void class id
    """
    if per_image:
        loss = mean(
            _lovasz_hinge_flat(
                *_flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore)
            )
            for log, lab in zip(logits, labels)
        )
    else:
        loss = _lovasz_hinge_flat(*_flatten_binary_scores(logits, labels, ignore))
    return loss
def _lovasz_hinge_flat(logits, labels):
    """Binary Lovasz hinge loss

    Args:
        logits: [P] Logits at each prediction (between -infinity and +infinity)
        labels: [P] Tensor, binary ground truth labels (0 or 1)
    """
    if len(labels) == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.0
    signs = 2.0 * labels.float() - 1.0
    errors = 1.0 - logits * signs
    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
    gt_sorted = labels[perm]
    grad = _lovasz_grad(gt_sorted)
    loss = torch.dot(F.relu(errors_sorted), grad)
    return loss
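
# Sanity check (a sketch, not from the original source): confidently correct
# logits give errors <= 0, which F.relu zeroes out, so the loss is ~0:
#   _lovasz_hinge_flat(torch.tensor([10.0, -10.0]), torch.tensor([1.0, 0.0]))
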
def _flatten_binary_scores(scores, labels, ignore=None):
    """Flattens predictions in the batch (binary case)
    Remove labels equal to 'ignore'
    """
    scores = scores.view(-1)
    labels = labels.view(-1)
    if ignore is None:
        return scores, labels
    valid = labels != ignore
    vscores = scores[valid]
    vlabels = labels[valid]
    return vscores, vlabels
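
# Example (sketch): with ignore=255, void pixels are dropped from both tensors,
# e.g. labels = torch.tensor([0, 255, 1]) keeps only positions 0 and 2.
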
def _lovasz_softmax(
    probas, labels, classes="present", class_seen=None, per_image=False, ignore=None
):
    """Multi-class Lovasz-Softmax loss

    Args:
        probas: [B, C, H, W] Class probabilities at each prediction (between 0 and 1).
            With C = 1, interpreted as binary (sigmoid) output of size [B, H, W].
        labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
        classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
        class_seen: if given, only classes in this collection contribute to the loss
        per_image: compute the loss per image instead of per batch
        ignore: void class labels
    """
    if per_image:
        loss = mean(
            _lovasz_softmax_flat(
                *_flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore),
                classes=classes,
                class_seen=class_seen
            )
            for prob, lab in zip(probas, labels)
        )
    else:
        loss = _lovasz_softmax_flat(
            *_flatten_probas(probas, labels, ignore),
            classes=classes,
            class_seen=class_seen
        )
    return loss
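
# Note: _lovasz_softmax expects probabilities, not raw logits; LovaszLoss.forward
# below applies softmax(dim=1) to multiclass predictions before delegating here.
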
def _lovasz_softmax_flat(probas, labels, classes="present", class_seen=None):
    """Multi-class Lovasz-Softmax loss

    Args:
        probas: [P, C] Class probabilities at each prediction (between 0 and 1)
        labels: [P] Tensor, ground truth labels (between 0 and C - 1)
        classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
        class_seen: if given, only classes in this collection contribute to the loss
    """
    if probas.numel() == 0:
        # only void pixels, the gradients should be 0
        return probas * 0.0
    C = probas.size(1)
    losses = []
    # iterate only over the classes that actually occur in `labels`
    for c in labels.unique():
        if class_seen is not None and c not in class_seen:
            continue
        fg = (labels == c).type_as(probas)  # foreground for class c
        if classes == "present" and fg.sum() == 0:
            continue
        if C == 1:
            if len(classes) > 1:
                raise ValueError("Sigmoid output possible only with 1 class")
            class_pred = probas[:, 0]
        else:
            class_pred = probas[:, c]
        errors = (fg - class_pred).abs()
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        fg_sorted = fg[perm]
        losses.append(torch.dot(errors_sorted, _lovasz_grad(fg_sorted)))
    return mean(losses)
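
# Behavior note (from the code above): only classes present in `labels` enter the
# average, so an absent class contributes nothing; if `class_seen` filters out
# every present class, `mean` falls back to its `empty=0` default.
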
def _flatten_probas(probas, labels, ignore=None):
    """Flattens predictions in the batch"""
    if probas.dim() == 3:
        # assumes output of a sigmoid layer
        B, H, W = probas.size()
        probas = probas.view(B, 1, H, W)
    C = probas.size(1)
    probas = torch.movedim(probas, 1, -1)  # [B, C, Di, Dj, ...] -> [B, Di, Dj, ..., C]
    probas = probas.contiguous().view(-1, C)  # [P, C]
    labels = labels.view(-1)
    if ignore is None:
        return probas, labels
    valid = labels != ignore
    vprobas = probas[valid]
    vlabels = labels[valid]
    return vprobas, vlabels
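
# Shape example (sketch): probas of shape [2, 4, 8, 8] with labels [2, 8, 8]
# flatten to [128, 4] and [128] respectively, before the optional `ignore` filter.
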
def isnan(x):
    # NaN is the only value that compares unequal to itself
    return x != x
def mean(values, ignore_nan=False, empty=0):
    """Nan-mean compatible with generators."""
    values = iter(values)
    if ignore_nan:
        values = filterfalse(isnan, values)
    try:
        n = 1
        acc = next(values)
    except StopIteration:
        if empty == "raise":
            raise ValueError("Empty mean")
        return empty
    for n, v in enumerate(values, 2):
        acc += v
    if n == 1:
        return acc
    return acc / n
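
# Example (sketch): mean(iter([1.0, 2.0, 3.0])) == 2.0, while mean(iter([]))
# returns the `empty` default (0) instead of raising, unless empty == "raise".
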
@LOSSES.register_module()
class LovaszLoss(_Loss):
    def __init__(
        self,
        mode: str,
        class_seen: Optional[int] = None,
        per_image: bool = False,
        ignore_index: Optional[int] = None,
        loss_weight: float = 1.0,
    ):
        """Lovasz loss for segmentation task.
        It supports binary, multiclass and multilabel cases

        Args:
            mode: Loss mode 'binary', 'multiclass' or 'multilabel'
            class_seen: Optional collection of class ids to restrict the loss to
            per_image: If True, loss is computed per image and then averaged, else computed over the whole batch
            ignore_index: Label that indicates ignored pixels (does not contribute to loss)
            loss_weight: Multiplier applied to the final loss value

        Shape
            - **y_pred** - torch.Tensor of shape (N, C, H, W)
            - **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W)

        Reference
            https://github.com/BloodAxe/pytorch-toolbelt
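
        Example (a minimal sketch; shapes and values are illustrative only):
            >>> criterion = LovaszLoss(mode="multiclass", ignore_index=255)
            >>> y_pred = torch.randn(2, 4, 8, 8)         # raw logits; softmax applied in forward
            >>> y_true = torch.randint(0, 4, (2, 8, 8))  # class indices in [0, C - 1]
            >>> loss = criterion(y_pred, y_true)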
""" | |
        assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
        super().__init__()

        self.mode = mode
        self.ignore_index = ignore_index
        self.per_image = per_image
        self.class_seen = class_seen
        self.loss_weight = loss_weight
    def forward(self, y_pred, y_true):
        if self.mode in {BINARY_MODE, MULTILABEL_MODE}:
            loss = _lovasz_hinge(
                y_pred, y_true, per_image=self.per_image, ignore=self.ignore_index
            )
        elif self.mode == MULTICLASS_MODE:
            y_pred = y_pred.softmax(dim=1)
            loss = _lovasz_softmax(
                y_pred,
                y_true,
                class_seen=self.class_seen,
                per_image=self.per_image,
                ignore=self.ignore_index,
            )
        else:
            raise ValueError("Wrong mode {}.".format(self.mode))
        return loss * self.loss_weight