# CLIPInverter / models/e4e_features.py
# (page metadata: Canberk Baykal · app.py · revision b5ed368 · 2.56 kB)
import matplotlib
matplotlib.use('Agg')
import torch
from torch import nn
from models.encoders import psp_encoders_features
def get_keys(d, name):
    """Extract the sub-state-dict stored under the ``name`` prefix.

    Checkpoints that wrap their weights in a ``'state_dict'`` entry are
    unwrapped first. Only keys of the form ``'<name>.<rest>'`` are kept, and
    the ``'<name>.'`` prefix is stripped from them.

    Args:
        d: A checkpoint mapping: either a state dict itself, or a dict that
            holds one under ``'state_dict'``.
        name: Module prefix to select, e.g. ``'encoder'``.

    Returns:
        A new dict mapping the stripped keys to the original values.
    """
    if 'state_dict' in d:
        d = d['state_dict']
    # Match on "<name>." rather than the bare prefix. Otherwise sibling entries
    # whose names merely start with ``name`` (e.g. "encoder_ema.*", or a bare
    # "encoder" key) would be swept in with mangled keys and break a strict load.
    prefix = name + '.'
    return {k[len(prefix):]: v for k, v in d.items() if k.startswith(prefix)}
class pSp(nn.Module):
    """e4e encoder inside the pSp framework, returning latent codes and features.

    Only the encoder is built and loaded here; no decoder is constructed. The
    average latent from the checkpoint (if present) is kept in
    ``self.latent_avg`` and optionally added to the encoder output.
    """

    def __init__(self, opts):
        """Build the encoder, switch it to eval mode and load pretrained weights.

        Args:
            opts: Options object. This class reads ``pretrained_e4e_path``,
                ``device`` and ``start_from_latent_avg``; it is also passed to
                the encoder constructor.
        """
        super(pSp, self).__init__()
        self.opts = opts
        # Define architecture
        self.encoder = self.set_encoder().eval()
        # Load the pretrained encoder weights and the latent average
        self.load_weights()

    def set_encoder(self):
        """Construct ``Encoder4Editing(50, 'ir_se', self.opts)`` and return it."""
        encoder = psp_encoders_features.Encoder4Editing(50, 'ir_se', self.opts)
        return encoder

    def load_weights(self):
        """Load encoder weights and the latent average from the e4e checkpoint.

        Reads the checkpoint at ``self.opts.pretrained_e4e_path`` onto the CPU,
        strictly loads its ``encoder.*`` entries into ``self.encoder`` and sets
        ``self.latent_avg`` from it.
        """
        # We only load the encoder weights
        print('Loading e4e over the pSp framework from checkpoint: {}'.format(self.opts.pretrained_e4e_path))
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        ckpt = torch.load(self.opts.pretrained_e4e_path, map_location='cpu')
        self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True)
        self.__load_latent_avg(ckpt)

    def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True,
                inject_latent=None, return_latents=False, alpha=None):
        """Encode ``x`` (or treat it as latent codes) and optionally mix latents.

        Args:
            x: Encoder input, or the latent codes themselves when
                ``input_code`` is True.
            resize, randomize_noise, return_latents: Accepted for interface
                compatibility with the full pSp model; unused here.
            latent_mask: Optional iterable of indices along dim 1 of the codes
                to overwrite.
            input_code: If True, skip the encoder and use ``x`` as the codes.
            inject_latent: Codes copied (or blended) into the masked indices.
                When None, masked indices are set to 0.
            alpha: Blend weight for ``inject_latent``. When None, injected codes
                replace the originals outright.

        Returns:
            ``(codes, features)``. ``features`` is the second value returned by
            the encoder, or None when ``input_code`` is True (the encoder is
            not run).
        """
        if input_code:
            codes = x
            # The encoder is skipped, so there are no features to return.
            features = None
        else:
            codes, features = self.encoder(x)
            # normalize with respect to the center of an average face
            # NOTE(review): requires the checkpoint to have provided latent_avg;
            # if it was absent, self.latent_avg is None and this fails.
            if self.opts.start_from_latent_avg:
                if codes.ndim == 2:
                    # 2-D codes: add only the first row of the repeated average.
                    codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)[:, 0, :]
                else:
                    codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)
        if latent_mask is not None:
            # Edit a copy so the caller's tensor (``x`` when input_code=True)
            # is not modified in place.
            codes = codes.clone()
            for i in latent_mask:
                if inject_latent is not None:
                    if alpha is not None:
                        codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
                    else:
                        codes[:, i] = inject_latent[:, i]
                else:
                    codes[:, i] = 0
        return codes, features

    # Forward the modulated feature maps
    def forward_features(self, features):
        """Delegate to ``self.encoder.forward_features`` and return its result."""
        return self.encoder.forward_features(features)

    def __load_latent_avg(self, ckpt, repeat=None):
        """Set ``self.latent_avg`` from ``ckpt['latent_avg']``, or None if absent.

        The tensor is moved to ``self.opts.device``. When ``repeat`` is given,
        it is tiled with ``latent_avg.repeat(repeat, 1)``.
        """
        if 'latent_avg' in ckpt:
            self.latent_avg = ckpt['latent_avg'].to(self.opts.device)
            if repeat is not None:
                self.latent_avg = self.latent_avg.repeat(repeat, 1)
        else:
            self.latent_avg = None