"""Hierarchical texture-aware VQGAN model: a frozen, pretrained top (coarse)
level plus a trainable bottom (fine) level that refines the decoded result."""

import math
import sys
from collections import OrderedDict

sys.path.append('..')
import lpips
import torch
import torch.nn.functional as F
from torchvision.utils import save_image

from models.archs.vqgan_arch import (Decoder, DecoderRes, Discriminator,
                                     Encoder,
                                     VectorQuantizerSpatialTextureAware,
                                     VectorQuantizerTexture)
from models.losses.vqgan_loss import (DiffAugment, adopt_weight,
                                      calculate_adaptive_weight, hinge_d_loss)


class HierarchyVQSpatialTextureAwareModel:

    def __init__(self, opt):
        self.opt = opt
        self.device = torch.device('cuda')

        # Top (coarse) level: loaded from a pretrained texture-aware VQGAN
        # and kept frozen (see load_top_pretrain_models).
        self.top_encoder = Encoder(
            ch=opt['top_ch'],
            num_res_blocks=opt['top_num_res_blocks'],
            attn_resolutions=opt['top_attn_resolutions'],
            ch_mult=opt['top_ch_mult'],
            in_channels=opt['top_in_channels'],
            resolution=opt['top_resolution'],
            z_channels=opt['top_z_channels'],
            double_z=opt['top_double_z'],
            dropout=opt['top_dropout']).to(self.device)
        self.decoder = Decoder(
            in_channels=opt['top_in_channels'],
            resolution=opt['top_resolution'],
            z_channels=opt['top_z_channels'],
            ch=opt['top_ch'],
            out_ch=opt['top_out_ch'],
            num_res_blocks=opt['top_num_res_blocks'],
            attn_resolutions=opt['top_attn_resolutions'],
            ch_mult=opt['top_ch_mult'],
            dropout=opt['top_dropout'],
            resamp_with_conv=True,
            give_pre_end=False).to(self.device)
        self.top_quantize = VectorQuantizerTexture(
            1024, opt['embed_dim'], beta=0.25).to(self.device)
        self.top_quant_conv = torch.nn.Conv2d(opt['top_z_channels'],
                                              opt['embed_dim'],
                                              1).to(self.device)
        self.top_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
                                                   opt['top_z_channels'],
                                                   1).to(self.device)
        self.load_top_pretrain_models()

        # Bottom (fine) level: the part trained by this model.
        self.bot_encoder = Encoder(
            ch=opt['bot_ch'],
            num_res_blocks=opt['bot_num_res_blocks'],
            attn_resolutions=opt['bot_attn_resolutions'],
            ch_mult=opt['bot_ch_mult'],
            in_channels=opt['bot_in_channels'],
            resolution=opt['bot_resolution'],
            z_channels=opt['bot_z_channels'],
            double_z=opt['bot_double_z'],
            dropout=opt['bot_dropout']).to(self.device)
        self.bot_decoder_res = DecoderRes(
            in_channels=opt['bot_in_channels'],
            resolution=opt['bot_resolution'],
            z_channels=opt['bot_z_channels'],
            ch=opt['bot_ch'],
            num_res_blocks=opt['bot_num_res_blocks'],
            ch_mult=opt['bot_ch_mult'],
            dropout=opt['bot_dropout'],
            give_pre_end=False).to(self.device)
        self.bot_quantize = VectorQuantizerSpatialTextureAware(
            opt['bot_n_embed'],
            opt['embed_dim'],
            beta=0.25,
            spatial_size=opt['codebook_spatial_size']).to(self.device)
        self.bot_quant_conv = torch.nn.Conv2d(opt['bot_z_channels'],
                                              opt['embed_dim'],
                                              1).to(self.device)
        self.bot_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
                                                   opt['bot_z_channels'],
                                                   1).to(self.device)

        # Adversarial and perceptual loss components.
        self.disc = Discriminator(
            opt['n_channels'], opt['ndf'],
            n_layers=opt['disc_layers']).to(self.device)
        self.perceptual = lpips.LPIPS(net='vgg').to(self.device)
        self.perceptual_weight = opt['perceptual_weight']
        self.disc_start_step = opt['disc_start_step']
        self.disc_weight_max = opt['disc_weight_max']
        self.diff_aug = opt['diff_aug']
        self.policy = 'color,translation'

        self.load_discriminator_models()
        self.disc.train()

        self.fix_decoder = opt['fix_decoder']

        self.init_training_settings()

    def load_top_pretrain_models(self):
        # Load the pretrained top-level VQGAN and keep it in eval mode.
        top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
        self.top_encoder.load_state_dict(
            top_vae_checkpoint['encoder'], strict=True)
        self.decoder.load_state_dict(
            top_vae_checkpoint['decoder'], strict=True)
        self.top_quantize.load_state_dict(
            top_vae_checkpoint['quantize'], strict=True)
        self.top_quant_conv.load_state_dict(
            top_vae_checkpoint['quant_conv'], strict=True)
        self.top_post_quant_conv.load_state_dict(
            top_vae_checkpoint['post_quant_conv'], strict=True)
        self.top_encoder.eval()
        self.top_quantize.eval()
        self.top_quant_conv.eval()
        self.top_post_quant_conv.eval()

    def init_training_settings(self):
        self.log_dict = OrderedDict()
        self.configure_optimizers()

    def configure_optimizers(self):
        # Optimize only the bottom-level modules and, unless the decoder is
        # fixed, the upsampling blocks ('up.0'..'up.3') of the shared decoder.
        optim_params = []
        for module in [
                self.bot_encoder, self.bot_decoder_res, self.bot_quantize,
                self.bot_quant_conv, self.bot_post_quant_conv
        ]:
            for v in module.parameters():
                if v.requires_grad:
                    optim_params.append(v)
        if not self.fix_decoder:
            for name, v in self.decoder.named_parameters():
                if v.requires_grad and any(
                        f'up.{i}' in name for i in range(4)):
                    optim_params.append(v)

        self.optimizer = torch.optim.Adam(optim_params, lr=self.opt['lr'])

        self.disc_optimizer = torch.optim.Adam(
            self.disc.parameters(), lr=self.opt['lr'])

    def load_discriminator_models(self):
        # Warm-start the discriminator from the pretrained top-level
        # checkpoint.
        top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
        self.disc.load_state_dict(
            top_vae_checkpoint['discriminator'], strict=True)

    def save_network(self, save_path):
        """Save networks."""
        save_dict = {}
        save_dict['bot_encoder'] = self.bot_encoder.state_dict()
        save_dict['bot_decoder_res'] = self.bot_decoder_res.state_dict()
        save_dict['decoder'] = self.decoder.state_dict()
        save_dict['bot_quantize'] = self.bot_quantize.state_dict()
        save_dict['bot_quant_conv'] = self.bot_quant_conv.state_dict()
        save_dict['bot_post_quant_conv'] = self.bot_post_quant_conv.state_dict()
        save_dict['discriminator'] = self.disc.state_dict()
        torch.save(save_dict, save_path)

    def load_network(self):
        checkpoint = torch.load(self.opt['pretrained_models'])
        self.bot_encoder.load_state_dict(
            checkpoint['bot_encoder'], strict=True)
        self.bot_decoder_res.load_state_dict(
            checkpoint['bot_decoder_res'], strict=True)
        self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
        self.bot_quantize.load_state_dict(
            checkpoint['bot_quantize'], strict=True)
        self.bot_quant_conv.load_state_dict(
            checkpoint['bot_quant_conv'], strict=True)
        self.bot_post_quant_conv.load_state_dict(
            checkpoint['bot_post_quant_conv'], strict=True)

    def optimize_parameters(self, data, step):
        self.bot_encoder.train()
        self.bot_decoder_res.train()
        if not self.fix_decoder:
            self.decoder.train()
        self.bot_quantize.train()
        self.bot_quant_conv.train()
        self.bot_post_quant_conv.train()

        loss, d_loss = self.training_step(data, step)

        # Generator / autoencoder update.
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Discriminator update, once adversarial training has started.
        if step > self.disc_start_step:
            self.disc_optimizer.zero_grad()
            d_loss.backward()
            self.disc_optimizer.step()

    def top_encode(self, x, mask):
        h = self.top_encoder(x)
        h = self.top_quant_conv(h)
        quant, _, _ = self.top_quantize(h, mask)
        quant = self.top_post_quant_conv(quant)
        return quant

    def bot_encode(self, x, mask):
        h = self.bot_encoder(x)
        h = self.bot_quant_conv(h)
        quant, emb_loss, info = self.bot_quantize(h, mask)
        quant = self.bot_post_quant_conv(quant)
        bot_dec_res = self.bot_decoder_res(quant)
        return bot_dec_res, emb_loss, info

    def decode(self, quant_top, bot_dec_res):
        dec = self.decoder(quant_top, bot_h=bot_dec_res)
        return dec

    def forward_step(self, x, mask):
        # Top-level encoding runs without gradients (the pretrained top level
        # is frozen); only the bottom branch is trained.
        with torch.no_grad():
            quant_top = self.top_encode(x, mask)
        bot_dec_res, diff, _ = self.bot_encode(x, mask)
        dec = self.decode(quant_top, bot_dec_res)
        return dec, diff

    def feed_data(self, data):
        x = data['image'].float().to(self.device)
        mask = data['texture_mask'].float().to(self.device)
        return x, mask

    def training_step(self, data, step):
        x, mask = self.feed_data(data)
        xrec, codebook_loss = self.forward_step(x, mask)

        # Reconstruction: L1 plus LPIPS perceptual distance.
        recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
        p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
        nll_loss = recon_loss + self.perceptual_weight * p_loss
        nll_loss = torch.mean(nll_loss)

        if self.diff_aug:
            xrec = DiffAugment(xrec, policy=self.policy)

        # Generator loss, weighted adaptively against the reconstruction
        # loss and gated off until disc_start_step.
        logits_fake = self.disc(xrec)
        g_loss = -torch.mean(logits_fake)
        last_layer = self.decoder.conv_out.weight
        d_weight = calculate_adaptive_weight(nll_loss, g_loss, last_layer,
                                             self.disc_weight_max)
        d_weight *= adopt_weight(1, step, self.disc_start_step)
        loss = nll_loss + d_weight * g_loss + codebook_loss

        self.log_dict['loss'] = loss.item()
        self.log_dict['l1'] = recon_loss.mean().item()
        self.log_dict['perceptual'] = p_loss.mean().item()
        self.log_dict['nll_loss'] = nll_loss.item()
        self.log_dict['g_loss'] = g_loss.item()
        self.log_dict['d_weight'] = float(d_weight)
        self.log_dict['codebook_loss'] = codebook_loss.item()

        # Discriminator (hinge) loss, once adversarial training has started.
        if step > self.disc_start_step:
            if self.diff_aug:
                logits_real = self.disc(
                    DiffAugment(x.contiguous().detach(), policy=self.policy))
            else:
                logits_real = self.disc(x.contiguous().detach())
            logits_fake = self.disc(xrec.contiguous().detach())
            d_loss = hinge_d_loss(logits_real, logits_fake)
            self.log_dict['d_loss'] = d_loss.item()
        else:
            d_loss = None

        return loss, d_loss

    @torch.no_grad()
    def inference(self, data_loader, save_dir):
        self.bot_encoder.eval()
        self.bot_decoder_res.eval()
        self.decoder.eval()
        self.bot_quantize.eval()
        self.bot_quant_conv.eval()
        self.bot_post_quant_conv.eval()

        loss_total = 0
        num = 0

        for data in data_loader:
            img_name = data['img_name'][0]
            x, mask = self.feed_data(data)
            xrec, _ = self.forward_step(x, mask)

            recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
            p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
            nll_loss = recon_loss + self.perceptual_weight * p_loss
            nll_loss = torch.mean(nll_loss)
            loss_total += nll_loss

            num += x.size(0)

            if x.shape[1] > 3:
                # Segmentation-style input: collapse logits to a one-hot map,
                # then colorize both input and reconstruction for saving.
                assert xrec.shape[1] > 3
                xrec = torch.argmax(xrec, dim=1, keepdim=True)
                xrec = F.one_hot(xrec, num_classes=x.shape[1])
                xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)

            img_cat = torch.cat([x, xrec], dim=3).detach()
            img_cat = (img_cat + 1) / 2
            img_cat = img_cat.clamp_(0, 1)
            save_image(
                img_cat, f'{save_dir}/{img_name}.png', nrow=1, padding=4)

        return (loss_total / num).item()
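
    def to_rgb(self, x):
        # NOTE: `to_rgb` is used by `inference` for multi-channel
        # (segmentation-style) inputs but was not defined in this class. The
        # body below is a minimal sketch following a common convention (e.g.
        # from taming-transformers): project the one-hot map to 3 channels
        # with a fixed random 1x1 convolution and rescale to [-1, 1]. Treat
        # it as an assumption, not the project's actual colorization scheme.
        if not hasattr(self, 'colorize'):
            self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
        x = F.conv2d(x, weight=self.colorize)
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x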

    def get_current_log(self):
        return self.log_dict

    def update_learning_rate(self, epoch):
        """Update the learning rate of the generator optimizer.

        Args:
            epoch (int): Current epoch.

        Returns:
            float: Learning rate used for this epoch.
        """
        lr = self.optimizer.param_groups[0]['lr']

        if self.opt['lr_decay'] == 'step':
            lr = self.opt['lr'] * (
                self.opt['gamma']**(epoch // self.opt['step']))
        elif self.opt['lr_decay'] == 'cos':
            lr = self.opt['lr'] * (
                1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
        elif self.opt['lr_decay'] == 'linear':
            lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
        elif self.opt['lr_decay'] == 'linear2exp':
            if epoch < self.opt['turning_point'] + 1:
                # Linear decay until the turning point, exponential after.
                lr = self.opt['lr'] * (
                    1 - epoch / int(self.opt['turning_point'] * 1.0526))
            else:
                lr *= self.opt['gamma']
        elif self.opt['lr_decay'] == 'schedule':
            if epoch in self.opt['schedule']:
                lr *= self.opt['gamma']
        else:
            raise ValueError(
                'Unknown lr mode {}'.format(self.opt['lr_decay']))

        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

        return lr
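

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the training pipeline). The
# `opt` dict must supply every key read in `__init__` above (the 'top_*' and
# 'bot_*' architecture settings, 'embed_dim', 'top_vae_path', discriminator
# and loss options, 'lr', 'lr_decay', etc.); `opt` and `train_loader` below
# are hypothetical placeholders.
#
#     model = HierarchyVQSpatialTextureAwareModel(opt)
#     global_step = 0
#     for epoch in range(opt['num_epochs']):
#         model.update_learning_rate(epoch)
#         for batch in train_loader:  # yields {'image', 'texture_mask', ...}
#             model.optimize_parameters(batch, global_step)
#             global_step += 1
#         print(model.get_current_log())
# ---------------------------------------------------------------------------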