import os
from pathlib import Path

from optimization.constants import ASSETS_DIR_NAME, RANKED_RESULTS_DIR
from utils.metrics_accumulator import MetricsAccumulator
from utils.video import save_video
from utils.fft_pytorch import HighFrequencyLoss

from numpy import random
from optimization.augmentations import ImageAugmentations
from PIL import Image
import torch
import torchvision
from torchvision import transforms
import torchvision.transforms.functional as F
from torchvision.transforms import functional as TF
from torch.nn.functional import mse_loss
from optimization.losses import range_loss, d_clip_loss
import lpips
import numpy as np

from CLIP import clip
from guided_diffusion.guided_diffusion.script_util import (
    create_model_and_diffusion,
    model_and_diffusion_defaults,
    create_classifier,
    classifier_defaults,
)
from utils.visualization import show_tensor_image, show_editied_masked_image
from utils.change_place import change_place, find_bbox

import pdb
import cv2
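

# Alternative classifier loader (currently unused; see the commented-out call in
# ImageEditor.__init__). It restores a ResNet-50 from the DRA checkpoint, strips
# DataParallel 'module.' prefixes, renames 'last_linear' to 'fc', and prepends an
# Upsample so inputs are resized to 256x256 before classification.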
def create_classifier_ours():
    model = torchvision.models.resnet50()
    ckpt = torch.load('checkpoints/DRA_resnet50.pth')['model_state_dict']
    model.load_state_dict({k.replace('module.', '').replace('last_linear', 'fc'): v for k, v in ckpt.items()})
    model = torch.nn.Sequential(*[torch.nn.Upsample(size=(256, 256)), model])
    return model
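

# ImageEditor wires together the guided-diffusion model, an ImageNet classifier,
# CLIP, and LPIPS, and exposes edit_image_by_prompt() / reconstruct_image() for
# mask- and CLIP-guided image editing.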
class ImageEditor:
    def __init__(self, args) -> None:
        self.args = args
        os.makedirs(self.args.output_path, exist_ok=True)

        self.ranked_results_path = Path(os.path.join(self.args.output_path, RANKED_RESULTS_DIR))
        os.makedirs(self.ranked_results_path, exist_ok=True)

        if self.args.export_assets:
            self.assets_path = Path(os.path.join(self.args.output_path, ASSETS_DIR_NAME))
            os.makedirs(self.assets_path, exist_ok=True)
        if self.args.seed is not None:
            torch.manual_seed(self.args.seed)
            np.random.seed(self.args.seed)
            random.seed(self.args.seed)
        self.model_config = model_and_diffusion_defaults()
        self.model_config.update(
            {
                "attention_resolutions": "32, 16, 8",
                "class_cond": self.args.model_output_size == 512,
                "diffusion_steps": 1000,
                "rescale_timesteps": True,
                "timestep_respacing": self.args.timestep_respacing,
                "image_size": self.args.model_output_size,
                "learn_sigma": True,
                "noise_schedule": "linear",
                "num_channels": 256,
                "num_head_channels": 64,
                "num_res_blocks": 2,
                "resblock_updown": True,
                "use_fp16": True,
                "use_scale_shift_norm": True,
            }
        )

        self.classifier_config = classifier_defaults()
        self.classifier_config.update(
            {
                "image_size": self.args.model_output_size,
            }
        )
        # Load models
        self.device = torch.device(
            f"cuda:{self.args.gpu_id}" if torch.cuda.is_available() else "cpu"
        )
        print("Using device:", self.device)

        self.model, self.diffusion = create_model_and_diffusion(**self.model_config)
        self.model.load_state_dict(
            torch.load(
                "checkpoints/256x256_diffusion_uncond.pt"
                if self.args.model_output_size == 256
                else "checkpoints/512x512_diffusion.pt",
                map_location="cpu",
            )
        )
        # self.model.requires_grad_(False).eval().to(self.device)
        self.model.eval().to(self.device)
        for name, param in self.model.named_parameters():
            if "qkv" in name or "norm" in name or "proj" in name:
                param.requires_grad_()
        if self.model_config["use_fp16"]:
            self.model.convert_to_fp16()
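
        # ImageNet classifier from guided-diffusion, used for optional classifier
        # guidance; the commented-out line below would swap in the ResNet-50 from
        # create_classifier_ours() instead.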
        self.classifier = create_classifier(**self.classifier_config)
        self.classifier.load_state_dict(
            torch.load("checkpoints/256x256_classifier.pt", map_location="cpu")
        )
        # self.classifier.requires_grad_(False).eval().to(self.device)
        # self.classifier = create_classifier_ours()
        self.classifier.eval().to(self.device)
        if self.classifier_config["classifier_use_fp16"]:
            self.classifier.convert_to_fp16()
        self.clip_model = (
            clip.load("ViT-B/16", device=self.device, jit=False)[0].eval().requires_grad_(False)
        )
        self.clip_size = self.clip_model.visual.input_resolution
        self.clip_normalize = transforms.Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]
        )
        self.to_tensor = transforms.ToTensor()

        self.lpips_model = lpips.LPIPS(net="vgg").to(self.device)

        self.image_augmentations = ImageAugmentations(self.clip_size, self.args.aug_num)
        self.metrics_accumulator = MetricsAccumulator()

        self.hf_loss = HighFrequencyLoss()
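
    # With rescale_timesteps=True the sampler hands cond_fn timesteps scaled to
    # [0, 1000); this maps them back to indices of the respaced schedule so they
    # can index buffers such as sqrt_one_minus_alphas_cumprod.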
    def unscale_timestep(self, t):
        unscaled_timestep = (t * (self.diffusion.num_timesteps / 1000)).long()
        return unscaled_timestep
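
    # CLIP guidance loss: augment the (optionally masked) prediction into several
    # crops, embed them with CLIP, and accumulate the d_clip_loss distance to the
    # target embedding, averaged over augmentations for each batch element.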
    def clip_loss(self, x_in, text_embed):
        clip_loss = torch.tensor(0)

        if self.mask is not None:
            masked_input = x_in * self.mask
        else:
            masked_input = x_in
        augmented_input = self.image_augmentations(masked_input).add(1).div(2)  # shape: [N, C, H, W], range: [0, 1]
        clip_in = self.clip_normalize(augmented_input)
        # pdb.set_trace()
        image_embeds = self.clip_model.encode_image(clip_in).float()
        dists = d_clip_loss(image_embeds, text_embed)

        # We want to sum over the averages
        for i in range(self.args.batch_size):
            # We want to average at the "augmentations level"
            clip_loss = clip_loss + dists[i :: self.args.batch_size].mean()

        return clip_loss
    def unaugmented_clip_distance(self, x, text_embed):
        x = F.resize(x, [self.clip_size, self.clip_size])
        image_embeds = self.clip_model.encode_image(x).float()
        dists = d_clip_loss(image_embeds, text_embed)

        return dists.item()

    def model_fn(self, x, t, y=None):
        return self.model(x, t, y if self.args.class_cond else None)
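
    # Main editing entry point: build the guidance embedding (from a guide image
    # when args.image_guide is set, otherwise from the text prompt), prepare the
    # init images and mask, then run guided diffusion sampling with the cond_fn /
    # postprocess_fn defined inside.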
    def edit_image_by_prompt(self):
        if self.args.image_guide:
            # With image guidance, args.prompt holds the path of the guide image.
            img_guidance = Image.open(self.args.prompt).convert('RGB')
            img_guidance = img_guidance.resize((224, 224), Image.LANCZOS)  # type: ignore
            img_guidance = self.clip_normalize(self.to_tensor(img_guidance).unsqueeze(0)).to(self.device)
            text_embed = self.clip_model.encode_image(img_guidance).float()
        else:
            text_embed = self.clip_model.encode_text(
                clip.tokenize(self.args.prompt).to(self.device)
            ).float()

        self.image_size = (self.model_config["image_size"], self.model_config["image_size"])
        self.init_image_pil = Image.open(self.args.init_image).convert("RGB")
        self.init_image_pil = self.init_image_pil.resize(self.image_size, Image.LANCZOS)  # type: ignore
        self.init_image = (
            TF.to_tensor(self.init_image_pil).to(self.device).unsqueeze(0).mul(2).sub(1)
        )
        self.init_image_pil_2 = Image.open(self.args.init_image_2).convert("RGB")
        if self.args.rotate_obj:
            # angle = random.randint(-45, 45)
            angle = self.args.angle
            self.init_image_pil_2 = self.init_image_pil_2.rotate(angle)
        self.init_image_pil_2 = self.init_image_pil_2.resize(self.image_size, Image.LANCZOS)  # type: ignore
        self.init_image_2 = (
            TF.to_tensor(self.init_image_pil_2).to(self.device).unsqueeze(0).mul(2).sub(1)
        )

        '''
        # Init with the inpainting image
        self.init_image_pil_ = Image.open('output/ImageNet-S_val/bad_case_RN50/ILSVRC2012_val_00013212/ranked/08480_output_i_0_b_0.png').convert("RGB")
        self.init_image_pil_ = self.init_image_pil_.resize(self.image_size, Image.LANCZOS)  # type: ignore
        self.init_image_ = (
            TF.to_tensor(self.init_image_pil_).to(self.device).unsqueeze(0).mul(2).sub(1)
        )
        '''
        if self.args.export_assets:
            img_path = self.assets_path / Path(self.args.output_file)
            self.init_image_pil.save(img_path, quality=100)
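
        # Mask preparation: binarize the mask image, optionally rotate/invert it,
        # and, with args.random_position, relocate the object and mask together
        # via change_place().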
        self.mask = torch.ones_like(self.init_image, device=self.device)
        self.mask_pil = None
        if self.args.mask is not None:
            self.mask_pil = Image.open(self.args.mask).convert("RGB")
            if self.args.rotate_obj:
                self.mask_pil = self.mask_pil.rotate(angle)
            if self.mask_pil.size != self.image_size:
                self.mask_pil = self.mask_pil.resize(self.image_size, Image.NEAREST)  # type: ignore
            if self.args.random_position:
                bbox = find_bbox(np.array(self.mask_pil))
                print(bbox)
            image_mask_pil_binarized = ((np.array(self.mask_pil) > 0.5) * 255).astype(np.uint8)
            # image_mask_pil_binarized = cv2.dilate(image_mask_pil_binarized, np.ones((50, 50), np.uint8), iterations=1)
            if self.args.invert_mask:
                image_mask_pil_binarized = 255 - image_mask_pil_binarized
                self.mask_pil = TF.to_pil_image(image_mask_pil_binarized)
            self.mask = TF.to_tensor(Image.fromarray(image_mask_pil_binarized))
            self.mask = self.mask[0, ...].unsqueeze(0).unsqueeze(0).to(self.device)
            # self.mask[:] = 1
            if self.args.random_position:
                # print(self.init_image_2.shape, self.init_image_2.max(), self.init_image_2.min())
                # print(self.mask.shape, self.mask.max(), self.mask.min())
                # cv2.imwrite('tmp/init_before.jpg', np.transpose(((self.init_image_2 + 1) / 2 * 255).cpu().numpy()[0], (1, 2, 0))[:, :, ::-1])
                # cv2.imwrite('tmp/mask_before.jpg', (self.mask * 255).cpu().numpy()[0][0])
                self.init_image_2, self.mask = change_place(self.init_image_2, self.mask, bbox, self.args.invert_mask)
                # cv2.imwrite('tmp/init_after.jpg', np.transpose(((self.init_image_2 + 1) / 2 * 255).cpu().numpy()[0], (1, 2, 0))[:, :, ::-1])
                # cv2.imwrite('tmp/mask_after.jpg', (self.mask * 255).cpu().numpy()[0][0])
            if self.args.export_assets:
                mask_path = self.assets_path / Path(
                    self.args.output_file.replace(".png", "_mask.png")
                )
                self.mask_pil.save(mask_path, quality=100)
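
        # Classifier-guidance term: gradient of the classifier's log-probability of
        # the target class y with respect to the input, scaled by classifier_scale.
        # Note the leading minus sign; the result is added to the loss gradient
        # returned by cond_fn below.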
        def class_guided(x, y, t):
            assert y is not None
            with torch.enable_grad():
                x_in = x.detach().requires_grad_(True)
                # logits = self.classifier(x_in, t)
                logits = self.classifier(x_in)
                log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
                selected = log_probs[range(len(logits)), y.view(-1)]
                loss = selected.sum()
                return -torch.autograd.grad(loss, x_in)[0] * self.args.classifier_scale
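
        # Guidance gradient for the sampler. The loss combines, as enabled by the
        # args, a high-frequency term on the predicted x0, the CLIP loss, a range
        # loss, and background-preservation terms (LPIPS + L2 against the masked
        # init image); its negative gradient is returned, optionally together with
        # the classifier-guidance term above.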
        def cond_fn(x, t, y=None):
            if self.args.prompt == "":
                return torch.zeros_like(x)
            # pdb.set_trace()
            with torch.enable_grad():
                x = x.detach().requires_grad_()
                t_unscale = self.unscale_timestep(t)
                '''
                out = self.diffusion.p_mean_variance(
                    self.model, x, t, clip_denoised=False, model_kwargs={"y": y}
                )
                '''
                out = self.diffusion.p_mean_variance(
                    self.model, x, t_unscale, clip_denoised=False, model_kwargs={"y": None}
                )

                fac = self.diffusion.sqrt_one_minus_alphas_cumprod[t_unscale[0].item()]
                # x_in = out["pred_xstart"] * fac + x * (1 - fac)
                x_in = out["pred_xstart"]  # Revised by XX, 2022.07.14

                loss = torch.tensor(0)
                if self.args.classifier_scale != 0 and y is not None:
                    # gradient_class_guided = class_guided(x, y, t)
                    gradient_class_guided = class_guided(x_in, y, t)
                if self.args.background_complex != 0:
                    if self.args.hard:
                        loss = loss - self.args.background_complex * self.hf_loss((x_in + 1.) / 2.)
                    else:
                        loss = loss + self.args.background_complex * self.hf_loss((x_in + 1.) / 2.)

                if self.args.clip_guidance_lambda != 0:
                    clip_loss = self.clip_loss(x_in, text_embed) * self.args.clip_guidance_lambda
                    loss = loss + clip_loss
                    self.metrics_accumulator.update_metric("clip_loss", clip_loss.item())

                if self.args.range_lambda != 0:
                    r_loss = range_loss(out["pred_xstart"]).sum() * self.args.range_lambda
                    loss = loss + r_loss
                    self.metrics_accumulator.update_metric("range_loss", r_loss.item())

                if self.args.background_preservation_loss:
                    x_in = out["pred_xstart"] * fac + x * (1 - fac)
                    if self.mask is not None:
                        # masked_background = x_in * (1 - self.mask)
                        masked_background = x_in * self.mask  # 2022.07.19
                    else:
                        masked_background = x_in

                    if self.args.lpips_sim_lambda:
                        '''
                        loss = (
                            loss
                            + self.lpips_model(masked_background, self.init_image).sum()
                            * self.args.lpips_sim_lambda
                        )
                        '''
                        # 2022.07.19
                        loss = (
                            loss
                            + self.lpips_model(masked_background, self.init_image * self.mask).sum()
                            * self.args.lpips_sim_lambda
                        )
                    if self.args.l2_sim_lambda:
                        '''
                        loss = (
                            loss
                            + mse_loss(masked_background, self.init_image) * self.args.l2_sim_lambda
                        )
                        '''
                        # 2022.07.19
                        loss = (
                            loss
                            + mse_loss(masked_background, self.init_image * self.mask) * self.args.l2_sim_lambda
                        )

                if self.args.classifier_scale != 0 and y is not None:
                    return -torch.autograd.grad(loss, x)[0] + gradient_class_guided
                else:
                    return -torch.autograd.grad(loss, x)[0]
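
        # After each sampling step, re-impose the background: outside the mask the
        # sample is replaced with a q_sample-noised version of init_image_2. With
        # coarse_to_fine, the mask is eroded (by max-pooling its complement) more
        # aggressively at large t, so the region taken from the sample grows back
        # to the full mask as sampling proceeds.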
        def postprocess_fn(out, t):
            if self.args.coarse_to_fine:
                if t > 50:
                    kernel = 51
                elif t > 35:
                    kernel = 31
                else:
                    kernel = 0
                if kernel > 0:
                    max_pool = torch.nn.MaxPool2d(kernel_size=kernel, stride=1, padding=int((kernel - 1) / 2))
                    self.mask_d = 1 - self.mask
                    self.mask_d = max_pool(self.mask_d)
                    self.mask_d = 1 - self.mask_d
                else:
                    self.mask_d = self.mask
            else:
                self.mask_d = self.mask

            if self.mask is not None:
                background_stage_t = self.diffusion.q_sample(self.init_image_2, t[0])
                background_stage_t = torch.tile(
                    background_stage_t, dims=(self.args.batch_size, 1, 1, 1)
                )
                out["sample"] = out["sample"] * self.mask_d + background_stage_t * (1 - self.mask_d)

            return out
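
        # Run DDIM or ancestral sampling from a noised init_image (skip_timesteps),
        # guided by cond_fn and, unless local CLIP-guided diffusion is requested,
        # the background-enforcing postprocess_fn. With a non-zero classifier_scale,
        # the target label args.y is passed through model_kwargs so cond_fn sees it.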
        save_image_interval = self.diffusion.num_timesteps // 5
        for iteration_number in range(self.args.iterations_num):
            print(f"Start iteration {iteration_number}")

            sample_func = (
                self.diffusion.ddim_sample_loop_progressive
                if self.args.ddim
                else self.diffusion.p_sample_loop_progressive
            )
            samples = sample_func(
                self.model_fn,
                (
                    self.args.batch_size,
                    3,
                    self.model_config["image_size"],
                    self.model_config["image_size"],
                ),
                clip_denoised=False,
                # model_kwargs={}
                # if self.args.model_output_size == 256
                # else {
                #     "y": torch.zeros([self.args.batch_size], device=self.device, dtype=torch.long)
                # },
                model_kwargs={}
                if self.args.classifier_scale == 0
                else {"y": self.args.y * torch.ones([self.args.batch_size], device=self.device, dtype=torch.long)},
                cond_fn=cond_fn,
                device=self.device,
                progress=True,
                skip_timesteps=self.args.skip_timesteps,
                init_image=self.init_image,
                # init_image=self.init_image_,
                postprocess_fn=None if self.args.local_clip_guided_diffusion else postprocess_fn,
                randomize_class=(self.args.classifier_scale == 0),
            )
            intermediate_samples = [[] for i in range(self.args.batch_size)]
            total_steps = self.diffusion.num_timesteps - self.args.skip_timesteps - 1
            for j, sample in enumerate(samples):
                should_save_image = j % save_image_interval == 0 or j == total_steps
                if should_save_image or self.args.save_video:
                    self.metrics_accumulator.print_average_metric()

                    for b in range(self.args.batch_size):
                        pred_image = sample["pred_xstart"][b]
                        visualization_path = Path(
                            os.path.join(self.args.output_path, self.args.output_file)
                        )
                        visualization_path = visualization_path.with_stem(
                            f"{visualization_path.stem}_i_{iteration_number}_b_{b}"
                        )
                        if (
                            self.mask is not None
                            and self.args.enforce_background
                            and j == total_steps
                            and not self.args.local_clip_guided_diffusion
                        ):
                            pred_image = (
                                self.init_image_2[0] * (1 - self.mask[0]) + pred_image * self.mask[0]
                            )
                        '''
                        if j == total_steps:
                            pdb.set_trace()
                            pred_image = (
                                self.init_image_2[0] * (1 - self.mask[0]) + pred_image * self.mask[0]
                            )
                        '''
                        pred_image = pred_image.add(1).div(2).clamp(0, 1)
                        pred_image_pil = TF.to_pil_image(pred_image)
                        masked_pred_image = self.mask * pred_image.unsqueeze(0)
                        final_distance = self.unaugmented_clip_distance(
                            masked_pred_image, text_embed
                        )
                        formatted_distance = f"{final_distance:.4f}"

                        if self.args.export_assets:
                            pred_path = self.assets_path / visualization_path.name
                            pred_image_pil.save(pred_path, quality=100)

                        if j == total_steps:
                            path_friendly_distance = formatted_distance.replace(".", "")
                            ranked_pred_path = self.ranked_results_path / (
                                path_friendly_distance + "_" + visualization_path.name
                            )
                            pred_image_pil.save(ranked_pred_path, quality=100)

                        intermediate_samples[b].append(pred_image_pil)
                        if should_save_image:
                            show_editied_masked_image(
                                title=self.args.prompt,
                                source_image=self.init_image_pil,
                                edited_image=pred_image_pil,
                                mask=self.mask_pil,
                                path=visualization_path,
                                distance=formatted_distance,
                            )

            if self.args.save_video:
                for b in range(self.args.batch_size):
                    video_name = self.args.output_file.replace(
                        ".png", f"_i_{iteration_number}_b_{b}.avi"
                    )
                    video_path = os.path.join(self.args.output_path, video_name)
                    save_video(intermediate_samples[b], video_path)
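
        # Build a summary strip: read back up to 8 ranked results, resize them to
        # 256x256, horizontally concatenate them, and save the result to
        # final_save_root/edited.png.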
        visualize_size = (256, 256)
        img_ori = cv2.imread(self.args.init_image_2)
        img_ori = cv2.resize(img_ori, visualize_size)
        mask = cv2.imread(self.args.mask)
        mask = cv2.resize(mask, visualize_size)
        imgs = [img_ori, mask]
        for ii, img_name in enumerate(os.listdir(os.path.join(self.args.output_path, 'ranked'))):
            img_path = os.path.join(self.args.output_path, 'ranked', img_name)
            img = cv2.imread(img_path)
            img = cv2.resize(img, visualize_size)
            imgs.append(img)
            if ii >= 7:
                break
        img_whole = cv2.hconcat(imgs[2:])
        '''
        img_name = self.args.output_path.split('/')[-2] + '/'
        if self.args.coarse_to_fine:
            if self.args.clip_guidance_lambda == 0:
                prompt = 'coarse_to_fine_no_clip'
            else:
                prompt = 'coarse_to_fine'
        elif self.args.image_guide:
            prompt = 'image_guide'
        elif self.args.clip_guidance_lambda == 0:
            prompt = 'no_clip_guide'
        else:
            prompt = 'text_guide'
        '''
        cv2.imwrite(os.path.join(self.args.final_save_root, 'edited.png'), img_whole, [int(cv2.IMWRITE_PNG_COMPRESSION), 0])
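
    # Unguided reconstruction: resample the init image from skip_timesteps with no
    # cond_fn and periodically save the current pred_xstart to output_path/output_file.
    # Note: this relies on self.image_size, which is set in edit_image_by_prompt().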
    def reconstruct_image(self):
        init = Image.open(self.args.init_image).convert("RGB")
        init = init.resize(
            self.image_size,  # type: ignore
            Image.LANCZOS,
        )
        init = TF.to_tensor(init).to(self.device).unsqueeze(0).mul(2).sub(1)

        samples = self.diffusion.p_sample_loop_progressive(
            self.model,
            (1, 3, self.model_config["image_size"], self.model_config["image_size"]),
            clip_denoised=False,
            model_kwargs={}
            if self.args.model_output_size == 256
            else {"y": torch.zeros([self.args.batch_size], device=self.device, dtype=torch.long)},
            cond_fn=None,
            progress=True,
            skip_timesteps=self.args.skip_timesteps,
            init_image=init,
            randomize_class=True,
        )

        save_image_interval = self.diffusion.num_timesteps // 5
        max_iterations = self.diffusion.num_timesteps - self.args.skip_timesteps - 1
        for j, sample in enumerate(samples):
            if j % save_image_interval == 0 or j == max_iterations:
                print()
                filename = os.path.join(self.args.output_path, self.args.output_file)
                TF.to_pil_image(sample["pred_xstart"][0].add(1).div(2).clamp(0, 1)).save(filename)