import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    import torchvision.models as models
except ImportError:
    # torchvision is only needed by the VGG-based losses (VGG16, LPIPSLoss,
    # PerceptualLoss). Keeping the import optional lets the adversarial-loss
    # utilities be imported on machines without torchvision; VGG16 raises a
    # clear ImportError when it is actually constructed.
    models = None


####################################################################################################
# adversarial loss for different gan mode
####################################################################################################


class GANLoss(nn.Module):
    """Define different GAN objectives.

    The GANLoss class abstracts away the need to create the target label tensor
    that has the same size as the input.
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """Initialize the GANLoss class.

        Parameters:
            gan_mode (str) -- the type of GAN objective. Supported values are
                              vanilla, lsgan, hinge, wgangp and nonsaturating.
            target_real_label (float) -- label value for a real image
            target_fake_label (float) -- label value for a fake image

        Raises:
            NotImplementedError -- if gan_mode is not one of the supported values.

        Note: Do not use sigmoid as the last layer of Discriminator.
        LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
        """
        super(GANLoss, self).__init__()
        # Buffers (not plain attributes) so the labels follow the module across devices.
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'hinge':
            self.loss = nn.ReLU()
        elif gan_mode in ('wgangp', 'nonsaturating'):
            # These objectives are computed directly from the prediction.
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def get_target_tensor(self, prediction, target_is_real):
        """Create label tensors with the same size as the input.

        Parameters:
            prediction (tensor) -- typically the prediction from a discriminator
            target_is_real (bool) -- if the ground truth label is for real examples or fake examples

        Returns:
            A label tensor filled with ground truth label, and with the size of the input
        """
        target_tensor = self.real_label if target_is_real else self.fake_label
        return target_tensor.expand_as(prediction)

    def calculate_loss(self, prediction, target_is_real, is_dis=False):
        """Calculate loss given Discriminator's output and ground truth labels.

        Parameters:
            prediction (tensor) -- typically the prediction output from a discriminator
            target_is_real (bool) -- if the ground truth label is for real examples or fake examples
            is_dis (bool) -- True when computing the discriminator's loss, False for the
                             generator's loss. Ignored by the lsgan and vanilla modes.

        Returns:
            the calculated loss.
        """
        if self.gan_mode in ('lsgan', 'vanilla'):
            target_tensor = self.get_target_tensor(prediction, target_is_real)
            loss = self.loss(prediction, target_tensor)
            return loss * 0.5 if self.gan_mode == 'lsgan' else loss

        if not is_dis:
            # Generator side: target_is_real is not consulted here.
            if self.gan_mode == 'nonsaturating':
                return F.softplus(-prediction).mean()
            return -prediction.mean()

        # Discriminator side: flip the sign for real samples so that one
        # formula per mode covers both the real and the fake term.
        if target_is_real:
            prediction = -prediction
        if self.gan_mode == 'wgangp':
            return prediction.mean()
        if self.gan_mode == 'nonsaturating':
            return F.softplus(prediction).mean()
        if self.gan_mode == 'hinge':
            return self.loss(1 + prediction).mean()
        # Only reachable if gan_mode was reassigned after construction.
        raise NotImplementedError('gan mode %s not implemented' % self.gan_mode)

    def forward(self, predictions, target_is_real, is_dis=False):
        """Calculate loss for single- or multi-scale discriminator outputs.

        Parameters:
            predictions (tensor | list | tuple) -- one prediction, or a sequence of
                                                   predictions (multi-scale GAN)
            target_is_real (bool) -- if the ground truth label is for real examples or fake examples
            is_dis (bool) -- True when computing the discriminator's loss

        Returns:
            the loss of the single prediction, or the sum of the per-scale losses.
        """
        if isinstance(predictions, (list, tuple)):
            return sum(
                self.calculate_loss(prediction, target_is_real, is_dis)
                for prediction in predictions
            )
        return self.calculate_loss(predictions, target_is_real, is_dis)


def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
    """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028

    Arguments:
        netD (network) -- discriminator network; may return a tensor or a list of tensors
        real_data (tensor array) -- real examples
        fake_data (tensor array) -- generated examples from the generator
        device (str) -- device on which the random interpolation weights are created
        type (str) -- if we mix real and fake data or not [real | fake | mixed].
                      (The name shadows the builtin; it is kept for caller compatibility.)
        constant (float) -- the constant used in formula ( ||gradient||_2 - constant)^2
        lambda_gp (float) -- weight for this loss

    Returns:
        (gradient_penalty, gradients) where gradients is flattened to one row per
        sample of real_data, or (0.0, None) when lambda_gp is not positive.

    Raises:
        NotImplementedError -- if type is not one of real, fake or mixed.

    Note: for type 'real' / 'fake' the tensor passed in is itself switched to
    requires_grad=True, because it is used directly as the discriminator input.
    """
    if lambda_gp <= 0.0:
        return 0.0, None

    # Either use real examples, fake examples, or a linear interpolation of two.
    if type == 'real':
        interpolatesv = real_data
    elif type == 'fake':
        interpolatesv = fake_data
    elif type == 'mixed':
        batch_size = real_data.shape[0]
        # One interpolation weight per sample, broadcast over the remaining dims.
        alpha = torch.rand(batch_size, 1, device=device)
        alpha = alpha.view(batch_size, *([1] * (real_data.dim() - 1)))
        interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
    else:
        raise NotImplementedError('{} not implemented'.format(type))

    interpolatesv.requires_grad_(True)
    disc_interpolates = netD(interpolatesv)
    if not isinstance(disc_interpolates, list):
        disc_interpolates = [disc_interpolates]

    gradients = 0
    for disc_interpolate in disc_interpolates:
        # ones_like keeps grad_outputs on the output's own device and dtype.
        gradients = gradients + torch.autograd.grad(
            outputs=disc_interpolate,
            inputs=interpolatesv,
            grad_outputs=torch.ones_like(disc_interpolate),
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]

    gradients = gradients.reshape(real_data.size(0), -1)  # flatten per sample
    # The small eps is added to the gradients before taking the norm.
    gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp
    return gradient_penalty, gradients


####################################################################################################
# trained LPIPS loss
####################################################################################################


def normalize_tensor(x, eps=1e-10):
    """Scale x to unit L2 norm along dim 1; eps guards against division by zero."""
    norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
    return x / (norm_factor + eps)


def spatial_average(x, keepdim=True):
    """Average x over dims 2 and 3."""
    return x.mean([2, 3], keepdim=keepdim)


class NetLinLayer(nn.Module):
    """A single linear layer which does a 1x1 conv (optionally preceded by dropout)."""

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        """
        Parameters:
            chn_in (int) -- number of input channels
            chn_out (int) -- number of output channels
            use_dropout (bool) -- if True, put an nn.Dropout in front of the conv
        """
        super(NetLinLayer, self).__init__()
        # The position of the conv inside the Sequential determines its
        # state-dict key (model.0 or model.1), so the layout must not change.
        layers = [nn.Dropout()] if use_dropout else []
        layers.append(nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the (dropout +) 1x1 conv stack to x."""
        return self.model(x)


class LPIPSLoss(nn.Module):
    """Learned perceptual metric

    https://github.com/richzhang/PerceptualSimilarity

    All parameters are frozen and the module is pinned to eval mode, so the
    dropout inside the linear layers never makes the returned value stochastic.
    """

    # VGG16 activations compared by the metric, paired with self.chns.
    _FEATURE_NAMES = ('relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3')

    def __init__(self, use_dropout=True, ckpt_path=None):
        """
        Parameters:
            use_dropout (bool) -- build the linear layers with a dropout in front of
                                  the conv (this affects the state-dict key layout)
            ckpt_path (str) -- path of the checkpoint to load; required

        Raises:
            ValueError -- if ckpt_path is None.
        """
        super(LPIPSLoss, self).__init__()
        if ckpt_path is None:
            # Fail before the VGG network is built instead of deep inside torch.load.
            raise ValueError('LPIPSLoss requires ckpt_path pointing to the pretrained LPIPS weights')
        self.path = ckpt_path
        self.net = VGG16()
        self.chns = [64, 128, 256, 512, 512]  # vgg16 features
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
        self.load_from_pretrained()
        for param in self.parameters():
            param.requires_grad = False
        self.eval()

    def train(self, mode=True):
        """Ignore the requested mode and stay in eval mode.

        A parent module's .train() would otherwise re-enable dropout in the
        linear layers and add noise to the frozen metric.
        """
        return super(LPIPSLoss, self).train(False)

    def load_from_pretrained(self):
        """Load weights from self.path; keys that do not match are ignored (strict=False)."""
        # NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.
        self.load_state_dict(torch.load(self.path, map_location=torch.device("cpu")), strict=False)
        print("loaded pretrained LPIPS loss from {}".format(self.path))

    def _get_features(self, vgg_f):
        """Pick the activations used by the metric out of the VGG16 output dict."""
        return [vgg_f[name] for name in self._FEATURE_NAMES]

    def forward(self, x, y):
        """Return the perceptual distance between x and y.

        The result is the sum over layers of the spatially averaged, linearly
        weighted squared difference of channel-normalized VGG features.
        """
        x_vgg = self._get_features(self.net(x))
        y_vgg = self._get_features(self.net(y))
        lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
        loss = 0
        for lin, x_feat, y_feat in zip(lins, x_vgg, y_vgg):
            diffs = (normalize_tensor(x_feat) - normalize_tensor(y_feat)) ** 2
            loss = loss + spatial_average(lin.model(diffs))
        return loss


class PerceptualLoss(nn.Module):
    r"""
    Perceptual loss, VGG-based
    https://arxiv.org/abs/1603.08155
    https://github.com/dxyang/StyleTransfer/blob/master/utils.py
    """

    # VGG16 activations compared by the loss; weights[i] applies to _LAYERS[i].
    _LAYERS = ('relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3')

    def __init__(self, weights=(1.0, 1.0, 1.0, 1.0, 0.0)):
        """
        Parameters:
            weights (sequence of float) -- one weight per compared VGG layer
                                           (relu1_2, relu2_2, relu3_3, relu4_3, relu5_3);
                                           layers whose weight is not positive are skipped
        """
        super(PerceptualLoss, self).__init__()
        self.add_module('vgg', VGG16())
        self.criterion = nn.L1Loss()
        # Copy so that neither the default nor the caller's sequence is shared.
        self.weights = list(weights)

    def forward(self, x, y):
        """Return the weighted sum of L1 distances between the VGG features of x and y."""
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        content_loss = 0.0
        for index, layer in enumerate(self._LAYERS):
            weight = self.weights[index]
            if weight > 0:
                content_loss += weight * self.criterion(x_vgg[layer], y_vgg[layer])
        return content_loss


class Normalization(nn.Module):
    """Normalize an image batch with fixed per-channel mean and std."""

    def __init__(self, device):
        """
        Parameters:
            device -- device on which the mean / std tensors are initially created
        """
        super(Normalization, self).__init__()
        # .view the mean and std to make them [C x 1 x 1] so that they can
        # directly work with image Tensor of shape [B x C x H x W].
        # B is batch size. C is number of channels. H is height and W is width.
        mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
        std = torch.tensor([0.229, 0.224, 0.225]).to(device)
        # Non-persistent buffers: they follow .to()/.cuda()/.double() with the
        # module but stay out of the state dict, so checkpoints are unaffected.
        self.register_buffer('mean', mean.view(-1, 1, 1), persistent=False)
        self.register_buffer('std', std.view(-1, 1, 1), persistent=False)

    def forward(self, img):
        """Return (img - mean) / std."""
        return (img - self.mean) / self.std


class VGG16(nn.Module):
    """Frozen torchvision VGG16 feature extractor exposing every ReLU activation."""

    # (stage name, first layer index, one-past-last layer index) into
    # torchvision's vgg16().features. Stage names and layer indices become the
    # state-dict keys (e.g. relu1_1.0.weight), so they must not change.
    _STAGES = (
        ('relu1_1', 0, 2),
        ('relu1_2', 2, 4),
        ('relu2_1', 4, 7),
        ('relu2_2', 7, 9),
        ('relu3_1', 9, 12),
        ('relu3_2', 12, 14),
        ('relu3_3', 14, 16),
        ('relu4_1', 16, 18),
        ('relu4_2', 18, 21),
        ('relu4_3', 21, 23),
        ('relu5_1', 23, 26),
        ('relu5_2', 26, 28),
        ('relu5_3', 28, 30),
    )

    def __init__(self):
        """Build the stages from torchvision's pretrained VGG16.

        Raises:
            ImportError -- if torchvision is not installed.
        """
        super(VGG16, self).__init__()
        if models is None:
            raise ImportError('torchvision is required to build the VGG16 feature extractor')
        features = models.vgg16(pretrained=True).features
        for name, start, stop in self._STAGES:
            stage = torch.nn.Sequential()
            for index in range(start, stop):
                stage.add_module(str(index), features[index])
            setattr(self, name, stage)

        # don't need the gradients, just want the features
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Run x through all stages and return a dict mapping stage name to activation."""
        out = {}
        for name, _, _ in self._STAGES:
            x = getattr(self, name)(x)
            out[name] = x
        return out