File size: 290 Bytes
55890ea
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
from .HardNegativeNLLLoss import HardNegativeNLLLoss


def load_loss(loss_class, *args, **kwargs):
    """Build a loss object from its registered name.

    Args:
        loss_class: Name of the loss to construct. The only name recognized
            here is ``"HardNegativeNLLLoss"``.
        *args: Positional arguments passed straight to the loss constructor.
        **kwargs: Keyword arguments passed straight to the loss constructor.

    Returns:
        A freshly constructed ``HardNegativeNLLLoss`` instance.

    Raises:
        ValueError: If ``loss_class`` does not name a known loss.
    """
    # Compare with ``==`` (not a dict lookup) so unhashable inputs still
    # fall through to the ValueError below rather than raising TypeError.
    is_known = loss_class == "HardNegativeNLLLoss"
    if not is_known:
        raise ValueError(f"Unknown loss class {loss_class}")
    return HardNegativeNLLLoss(*args, **kwargs)