|
|
|
import os

import gradio as gr

MODEL_DIR = 'models/pretrain'
os.makedirs(MODEL_DIR, exist_ok=True)

# Download the pre-trained StyleGAN2 (FFHQ, 1024x1024) checkpoint. The URL is
# quoted so the shell does not interpret `?` and `&`, and `MODEL_DIR` is
# interpolated from Python: a bare `$MODEL_DIR` would be an unset shell
# variable, so the file would have been written to `/stylegan2-...`.
os.system('wget --quiet '
          '"https://hkustconnect-my.sharepoint.com/:u:/g/personal/jzhubt_connect_ust_hk/ETYVen9KXGlAia2gH6pcZswB9Lw-21vWrE75OACvG2SBow?e=SCGqg0&download=1" '
          f'-O {MODEL_DIR}/stylegan2-ffhq-config-f-1024x1024.pth')
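
# On a Space restart the checkpoint is re-downloaded unconditionally. A
# minimal sketch for skipping the download when the file is already cached
# (same path as above; `ckpt` is a hypothetical local name):
#
#   ckpt = f'{MODEL_DIR}/stylegan2-ffhq-config-f-1024x1024.pth'
#   if not os.path.exists(ckpt):
#       os.system(f'wget --quiet "<checkpoint-url>" -O {ckpt}')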
|
|
|
|
|
|
|
"""Demo.""" |
|
import io |
|
import cv2 |
|
import warnings |
|
import numpy as np |
|
import torch |
|
from PIL import Image |
|
from models import build_model |
|
|
|
warnings.filterwarnings(action='ignore', category=UserWarning) |
|
|
|
def postprocess_image(image, min_val=-1.0, max_val=1.0):
    """Post-processes an image to pixel range [0, 255] with dtype `uint8`.

    This function is particularly used to handle the results produced by deep
    models.

    NOTE: The input image is assumed to be in `NCHW` format, and the returned
    image will always be in `NHWC` format.

    Args:
        image: The input image for post-processing.
        min_val: Expected minimum value of the input image.
        max_val: Expected maximum value of the input image.

    Returns:
        The post-processed image.
    """
    assert isinstance(image, np.ndarray)

    image = image.astype(np.float64)
    image = (image - min_val) / (max_val - min_val) * 255
    # Adding 0.5 before the cast rounds to the nearest integer rather than
    # truncating.
    image = np.clip(image + 0.5, 0, 255).astype(np.uint8)

    assert image.ndim == 4 and image.shape[1] in [1, 3, 4]
    return image.transpose(0, 2, 3, 1)
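
# Example of the contract above (illustrative shapes, not tied to the
# generator): a batch of two RGB outputs in [-1, 1] comes back NHWC as uint8.
#
#   fake = np.random.uniform(-1, 1, size=(2, 3, 256, 256))
#   out = postprocess_image(fake)   # -> shape (2, 256, 256, 3), dtype uint8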
|
|
|
|
|
def to_numpy(data):
    """Converts the input data to `numpy.ndarray`."""
    if isinstance(data, (int, float)):
        return np.array(data)
    if isinstance(data, np.ndarray):
        return data
    if isinstance(data, torch.Tensor):
        return data.detach().cpu().numpy()
    raise TypeError(f'Unsupported data type `{type(data)}` for '
                    f'converting to `numpy.ndarray`!')
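
# Example: `to_numpy(torch.ones(2, 3))` returns a detached (2, 3) float32
# ndarray, while plain Python scalars come back as 0-d arrays.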
|
|
|
|
|
def linear_interpolate(latent_code,
                       boundary,
                       layer_index=None,
                       start_distance=-10.0,
                       end_distance=10.0,
                       steps=7):
    """Moves the latent code along the boundary direction.

    Produces `steps` manipulated codes, evenly spaced between
    `start_distance` and `end_distance`, editing only the W+ layers selected
    by `layer_index`.
    """
    assert (len(latent_code.shape) == 3 and len(boundary.shape) == 3 and
            latent_code.shape[0] == 1 and boundary.shape[0] == 1 and
            latent_code.shape[1] == boundary.shape[1])
    linspace = np.linspace(start_distance, end_distance, steps)
    linspace = linspace.reshape([-1, 1, 1]).astype(np.float32)
    inter_code = linspace * boundary
    is_manipulatable = np.zeros(inter_code.shape, dtype=bool)
    is_manipulatable[:, layer_index, :] = True
    mani_code = np.where(is_manipulatable, latent_code + inter_code,
                         latent_code)
    return mani_code
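
# `linear_interpolate` is not called by the Gradio app below; it is kept for
# offline experiments. A hypothetical call, assuming a (1, num_layers, 512)
# W+ code and a matching boundary:
#
#   codes = linear_interpolate(latent_wp, boundaries['mouth'],
#                              layer_index=[4, 5, 6, 7], steps=7)
#   # -> (7, num_layers, 512): one manipulated code per interpolation step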
|
|
|
|
|
def imshow(images, col, viz_size=256):
    """Fuses a batch of images into a single grid image."""
    num, height, width, channels = images.shape
    assert num % col == 0
    row = num // col

    fused_image = np.zeros((viz_size * row, viz_size * col, channels),
                           dtype=np.uint8)

    for idx, image in enumerate(images):
        i, j = divmod(idx, col)
        y = i * viz_size
        x = j * viz_size
        if height != viz_size or width != viz_size:
            image = cv2.resize(image, (viz_size, viz_size))
        fused_image[y:y + viz_size, x:x + viz_size] = image

    # Round-trip through an encoded buffer (PNG for RGBA, JPEG for RGB) so the
    # returned object is a `PIL.Image` backed by compressed data.
    data = io.BytesIO()
    if channels == 4:
        Image.fromarray(fused_image).save(data, 'png')
    elif channels == 3:
        Image.fromarray(fused_image).save(data, 'jpeg')
    else:
        raise ValueError(f'Unsupported number of channels: {channels}!')
    return Image.open(io.BytesIO(data.getvalue()))
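
# Example: three 256x256 RGB frames laid out in one row become a 768x256
# (width x height) grid:
#
#   grid = imshow(images, col=3)   # images: (3, 256, 256, 3) uint8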
|
|
|
print('Building generator ...')

checkpoint_path = f'{MODEL_DIR}/stylegan2-ffhq-config-f-1024x1024.pth'
config = dict(model_type='StyleGAN2Generator',
              resolution=1024,
              w_dim=512,
              fmaps_base=int(1 * (32 << 10)),
              fmaps_max=512)
generator = build_model(**config)
print(f'Loading checkpoint from `{checkpoint_path}` ...')
checkpoint = torch.load(checkpoint_path, map_location='cpu')['models']
# Prefer the exponential-moving-average ("smoothed") generator weights when
# the checkpoint provides them.
if 'generator_smooth' in checkpoint:
    generator.load_state_dict(checkpoint['generator_smooth'])
else:
    generator.load_state_dict(checkpoint['generator'])
generator = generator.eval().cpu()
print('Finished loading checkpoint.')

print('Loading boundaries ...')
ATTRS = ['eyebrows', 'eyesize', 'gaze_direction', 'nose_length', 'mouth',
         'lipstick']
boundaries = {}
for attr in ATTRS:
    boundary_path = os.path.join('directions/ffhq/stylegan2', f'{attr}.npy')
    boundaries[attr] = np.load(boundary_path)
print('Generator and boundaries are ready.')
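
# Each boundary file is assumed to hold a single editing direction that
# broadcasts against the W+ code, e.g. shape (1, 1, 512); `inference` tiles
# it across all `generator.num_layers` layers before applying it to the
# selected layers only.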
|
|
|
|
|
def inference(num_of_image, seed, trunc_psi, eyebrows, eyesize,
              gaze_direction, nose_length, mouth, lipstick):
    """Samples latent codes with the given seed and applies attribute edits."""
    print('Sampling latent codes with the given seed.')
    trunc_layers = 8
    # Gradio sliders deliver floats; seeding and shape arguments need ints.
    np.random.seed(int(seed))
    latent_z = np.random.randn(int(num_of_image), generator.z_dim)
    latent_z = torch.from_numpy(latent_z.astype(np.float32)).cpu()
    with torch.no_grad():
        wp = generator.mapping(latent_z, None)['wp']
        # Truncation trick: pull the first `trunc_layers` W+ layers toward
        # the average latent code.
        if trunc_psi < 1.0:
            w_avg = generator.w_avg.reshape(1, -1, generator.w_dim)
            wp[:, :trunc_layers] = w_avg[:, :trunc_layers].lerp(
                wp[:, :trunc_layers], trunc_psi)
    # The original demo also synthesized the unedited images at this point,
    # but the result was never returned or displayed, so that extra synthesis
    # pass is omitted.
    latent_wp = to_numpy(wp)
|
    # Apply each attribute edit only to the W+ layers it affects: the upper
    # layers (8-11) for `eyebrows`/`lipstick`, the middle layers (4-7) for
    # the rest. A dict lookup replaces the original `eval(attr_name)`, which
    # was fragile and unsafe.
    attr_steps = dict(eyebrows=eyebrows, eyesize=eyesize,
                      gaze_direction=gaze_direction, nose_length=nose_length,
                      mouth=mouth, lipstick=lipstick)
    new_codes = latent_wp.copy()
    for attr_name in ATTRS:
        if attr_name in ['eyebrows', 'lipstick']:
            layers_idx = [8, 9, 10, 11]
        else:
            layers_idx = [4, 5, 6, 7]
        step = attr_steps[attr_name]
        direction = np.tile(boundaries[attr_name],
                            [1, generator.num_layers, 1])
        new_codes[:, layers_idx, :] += direction[:, layers_idx, :] * step
    new_codes = torch.from_numpy(new_codes.astype(np.float32)).cpu()
    with torch.no_grad():
        images_mani = generator.synthesis(new_codes)['image']
    images_mani = postprocess_image(to_numpy(images_mani))
    return imshow(images_mani, col=images_mani.shape[0])
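
# Input components are passed to `inference` positionally, so the slider
# order below must match the function signature. `step=1` on the first two
# sliders keeps `num_of_image` and `seed` integer-valued.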
|
|
|
gr.Interface(
    inference,
    [
        gr.Slider(1, 3, value=1, step=1, label='num_of_image'),
        gr.Slider(0, 10000, value=210, step=1, label='seed'),
        gr.Slider(0, 1, value=0.7, step=0.1, label='truncation psi'),
        gr.Slider(-12, 12, value=0, label='eyebrows'),
        gr.Slider(-12, 12, value=0, label='eyesize'),
        gr.Slider(-12, 12, value=0, label='gaze_direction'),
        gr.Slider(-12, 12, value=0, label='nose_length'),
        gr.Slider(-12, 12, value=0, label='mouth'),
        gr.Slider(-12, 12, value=0, label='lipstick'),
    ],
    gr.Image(type='pil'),
).launch()