# CLIPInverter / app.py
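# Gradio demo for CLIPInverter: invert a face photo into StyleGAN latents
# with an e4e-based encoder, steer the latents with a CLIP embedding of the
# target caption, and decode the edited image.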
import torch
from argparse import Namespace
import torchvision.transforms as transforms
import clip
import numpy as np
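# Make the repo's local modules (models/, adapter/) importable.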
import sys
sys.path.append(".")
sys.path.append("..")
from models.e4e_features import pSp
from adapter.adapter_decoder import CLIPAdapterWithDecoder
import gradio as gr

def tensor2im(var):
    # Convert a CHW tensor in [-1, 1] to an HWC uint8 image in [0, 255].
    var = var.cpu().detach().transpose(0, 2).transpose(0, 1).numpy()
    var = (var + 1) / 2
    var[var < 0] = 0
    var[var > 1] = 1
    var = var * 255
    return var.astype('uint8')

def run_alignment(image_path):
    # Align and crop the face with dlib's 68-point landmark predictor,
    # as expected by the FFHQ-trained encoder.
    import dlib
    from align_faces_parallel import align_face
    predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
    aligned_image = align_face(image_path, predictor=predictor)
    return aligned_image
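
# Preprocessing: resize to 256x256 and normalize RGB channels to [-1, 1].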
input_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
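
# Load the adapter checkpoint and the pretrained e4e encoder weights; both
# .pt files are expected in the working directory.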
model_path = 'pretrained_faces.pt'
e4e_path = 'e4e_ffhq_encode.pt'
ckpt = torch.load(model_path, map_location='cpu')
opts = ckpt['opts']
opts['checkpoint_path'] = model_path
opts['pretrained_e4e_path'] = e4e_path
device = 'cuda' if torch.cuda.is_available() else 'cpu'
opts['device'] = device
opts = Namespace(**opts)
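
# e4e-style encoder: maps an image to W+ latents and intermediate features.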
encoder = pSp(opts)
encoder.eval()
encoder.to(device)
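
# Text-conditioned adapter paired with a StyleGAN-based decoder.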
adapter = CLIPAdapterWithDecoder(opts)
adapter.eval()
adapter.to(device)
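
# Frozen CLIP model used to embed the target caption.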
clip_model, _ = clip.load("ViT-B/32", device=device)

def manipulate(input_image, caption):
    aligned_image = run_alignment(input_image)
    input_image = input_transforms(aligned_image)
    input_image = input_image.unsqueeze(0)  # add batch dimension
    text_input = clip.tokenize(caption)
    text_input = text_input.to(device)
    input_image = input_image.to(device).float()
    with torch.no_grad():
        text_features = clip_model.encode_text(text_input).float()
        # Invert the image into W+ latents and intermediate encoder features.
        w, features = encoder.forward(input_image, return_latents=True)
        # Condition the features on the caption and add the resulting
        # residual to the latents; 0.1 scales the edit strength.
        features = adapter.adapter(features, text_features)
        w_hat = w + 0.1 * encoder.forward_features(features)
        # Decode the edited latents back into an image.
        result_tensor, _ = adapter.decoder([w_hat], input_is_latent=True, return_latents=False, randomize_noise=False, truncation=1, txt_embed=text_features)
    result_tensor = result_tensor.squeeze(0)
    result_image = tensor2im(result_tensor)
    return result_image
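
# Gradio UI: upload a face photo, type a caption, get the edited image back.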
gr.Interface(fn=manipulate,
             inputs=[gr.Image(type="pil"), "text"],
             outputs="image",
             examples=[['example.jpg', "He has mustache"]],
             title="CLIPInverter").launch()