#!/usr/bin/env python
import torch.nn.functional as F
import torch.nn as nn


def calc_vq_loss(pred, target, quant_loss, quant_loss_weight=1.0, alpha=1.0):
    """Compute the total VQ loss and its individual components.

    The total is the L1 reconstruction loss between ``pred`` and ``target``
    plus the mean of the pre-computed quantization loss, scaled by
    ``quant_loss_weight``.

    Args:
        pred: Reconstructed tensor.
        target: Ground-truth tensor. It must be shape-compatible with
            ``pred`` for ``nn.L1Loss``.
        quant_loss: Pre-computed quantization loss tensor. It is reduced to
            a scalar with ``.mean()``.
        quant_loss_weight: Multiplier applied to the averaged quantization
            loss.
        alpha: Currently unused. It is accepted only so existing call sites
            that pass it keep working.

    Returns:
        A tuple ``(total_loss, [rec_loss, quant_loss])``. The
        ``quant_loss`` in the list is the averaged scalar, not the input
        tensor.
    """
    # Mean-reduced L1 reconstruction error.
    rec_loss = nn.L1Loss()(pred, target)
    # Reduce the (possibly per-element / per-batch) quantization loss to a scalar.
    quant_loss = quant_loss.mean()
    total_loss = quant_loss * quant_loss_weight + rec_loss
    return total_loss, [rec_loss, quant_loss]


def calc_logit_loss(pred, target):
    """Cross-entropy wrapper that flattens all leading dimensions.

    ``pred`` is reshaped to ``(-1, pred.size(-1))``, so its last dimension
    is treated as the class-logit axis. ``target`` is flattened to ``(-1,)``,
    so it supplies one class index per flattened row of ``pred``.
    NOTE(review): this presumably expects ``pred`` shaped
    ``(..., num_classes)`` and ``target`` shaped ``(...)`` with matching
    leading dimensions. Confirm this against the callers.

    Args:
        pred: Logits tensor whose last dimension indexes classes.
        target: Integer class-index tensor.

    Returns:
        The mean cross-entropy loss over all flattened positions, as a scalar
        tensor.
    """
    loss = F.cross_entropy(pred.reshape(-1, pred.size(-1)), target.reshape(-1))
    return loss