##import spaces import argparse import os import tempfile from functools import partial import cv2 import gradio as gr import imageio import numpy as np import torch import torchvision from omegaconf import OmegaConf from PIL import Image, ImageDraw from pytorch_lightning import seed_everything import sys import copy from utils.gradio_utils import load_preprocess_model, preprocess_image from ldm.util import instantiate_from_config, img2tensor from customnet.ddim import DDIMSampler from einops import rearrange import math if torch.cuda.is_available(): device = "cuda" elif torch.backends.mps.is_available(): device = "mps" else: device = "cpu" #### Description #### title = r"""

CustomNet: Object Customization with Variable-Viewpoints in Text-to-Image Diffusion Models

""" description = r""" Official Gradio demo for CustomNet: Object Customization with Variable-Viewpoints in Text-to-Image Diffusion Models.
🔥 CustomNet is novel unified customization method that can generate harmonious customized images without test-time optimization. CustomNet supports explicit viewpoint, location, text controls while ensuring object identity preservation.
🤗 Try to customize the object gneration yourself!
""" article = r""" If CustomNet is helpful, please help to ⭐ the Github Repo. Thanks! [![GitHub Stars](https://img.shields.io/github/stars/TencentARC%2FCustomNet)](https://github.com/TencentARC/CustomNet) --- 📝 **Citation**
If our work is useful for your research, please consider citing: ```bibtex @misc{yuan2023customnet, title={CustomNet: Zero-shot Object Customization with Variable-Viewpoints in Text-to-Image Diffusion Models}, author={Ziyang Yuan and Mingdeng Cao and Xintao Wang and Zhongang Qi and Chun Yuan and Ying Shan}, year={2023}, eprint={2310.19784}, archivePrefix={arXiv}, primaryClass={cs.CV} } ``` 📧 **Contact**
If you have any questions, please feel free to reach me out at yuanzy22@mails.tsinghua.edu.cn. """ # input_img = None # concat_img = None # T = None # prompt = None negtive_prompt = "" # load model device = torch.device(device) preprocess_model = load_preprocess_model() config = OmegaConf.load("configs/config_customnet.yaml") model = instantiate_from_config(config.model) model_path='./customnet_v1.pt' #model_path='./customnet_v1.pt?download=true' #if not os.path.exists(model_path): # os.system(f'wget https://huggingface.co/TencentARC/CustomNet/resolve/main/customnet_v1.pt?download=true -P .') ckpt = torch.load(model_path, map_location="cpu") model.load_state_dict(ckpt, strict=False) del ckpt model = model.to(device) sampler = None def send_input_to_concat(input_image): W, H = input_image.size # image_array[:, 0, :] = image_array[:, 0, :] draw = ImageDraw.Draw(input_image) draw.rectangle([(0,0),(H-1, W-1)], outline="red", width=8) return input_image def preprocess_input(preprocess_model, input_image): # global input_img processed_image = preprocess_image(preprocess_model, input_image) # input_img = (processed_image / 255.0).astype(np.float32) return processed_image # return processed_image, processed_image def adjust_location(x0, y0, x1, y1, input_image): x_0 = min(x0, x1) x_1 = max(x0, x1) y_0 = min(y0, y1) y_1 = max(y0, y1) print(x0, y0, x1, y1) print(x_0, y_0, x_1, y_1) new_size = (x_1-x_0, y_1-y_0) input_image = input_image.resize(new_size) img_array = np.array(input_image) white_background = np.zeros((256, 256, 3)) white_background[y0:y1, x0:x1, :] = img_array img_array = white_background.astype(np.uint8) concat_img = Image.fromarray(img_array) draw = ImageDraw.Draw(concat_img) draw.rectangle([(x0,y0),(x1,y1)], outline="red", width=5) return x_0, y_0, x_1, y_1, concat_img #@spaces.GPU def prepare_data(device, input_image, x0, y0, x1, y1, polar, azimuth, text): if input_image.size[0] != 256 or input_image.size[1] != 256: input_image = input_image.resize((256, 256)) input_image = np.array(input_image) img_cond = img2tensor(input_image, bgr2rgb=False, float32=True).unsqueeze(0) / 255. img_cond = img_cond*2-1 img_location = copy.deepcopy(img_cond) input_im_padding = torch.ones_like(img_location) x_0 = min(x0, x1) x_1 = max(x0, x1) y_0 = min(y0, y1) y_1 = max(y0, y1) print(x0, y0, x1, y1) print(x_0, y_0, x_1, y_1) img_location = torch.nn.functional.interpolate(img_location, (y_1-y_0, x_1-x_0), mode="bilinear") input_im_padding[:,:, y_0:y_1, x_0:x_1] = img_location img_location = input_im_padding T = torch.tensor([[math.radians(polar), math.sin(math.radians(azimuth)), math.cos(math.radians(azimuth)), 0.0]]).unsqueeze(1) batch = { "image_cond": img_cond.to(device), "image_location": img_location.to(device), 'T': T.to(device), 'text': [text], } return batch #@spaces.GPU(enable_queue=True, duration=180) # def run_generation(sampler, model, device, input_image, x0, y0, x1, y1, polar, azimuth, text, seed): def run_generation(sampler, input_image, x0, y0, x1, y1, polar, azimuth, text, seed): seed_everything(seed) batch = prepare_data(device, input_image, x0, y0, x1, y1, polar, azimuth, text) # model = model.to(device) sampler = DDIMSampler(model, device=device) c = model.get_learned_conditioning(batch["image_cond"]) c = torch.cat([c, batch["T"]], dim=-1) c = model.cc_projection(c) ## condition cond = {} cond['c_concat'] = [model.encode_first_stage((batch["image_location"])).mode().detach()] cond['c_crossattn'] = [c] text_embedding = model.text_encoder(batch["text"]) cond["c_crossattn"].append(text_embedding) ## null-condition uc = {} neg_prompt = "" uc['c_concat'] = [torch.zeros(1, 4, 32, 32).to(c.device)] uc['c_crossattn'] = [torch.zeros_like(c).to(c.device)] uc_text_embedding = model.text_encoder([neg_prompt]) uc['c_crossattn'].append(uc_text_embedding) ## sample shape = [4, 32, 32] samples_latents, _ = sampler.sample( S=50, batch_size=1, shape=shape, verbose=False, unconditional_guidance_scale=999, # useless conditioning=cond, unconditional_conditioning=uc, cfg_type=0, cfg_scale_dict={"img": 0., "text":0., "all": 3.0 } ) x_samples = model.decode_first_stage(samples_latents) x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0).cpu().numpy() x_samples = rearrange(255.0 *x_samples[0], 'c h w -> h w c').astype(np.uint8) output_image = Image.fromarray(x_samples) x0, y0, x1, y1, image_location = adjust_location(x0, y0, x1, y1, input_image) return x0, y0, x1, y1, image_location, output_image def load_example(input_image, x0, y0, x1, y1, polar, azimuth, prompt): # print("AAAA") # print(type(x0)) # print(type(polar)) return input_image, x0, y0, x1, y1, polar, azimuth, prompt @torch.no_grad() def main(args): # load demo demo = gr.Blocks() with demo: gr.Markdown(title) gr.Markdown(description) with gr.Row(): ## Left column with gr.Column(): ## step 1. gr.Markdown("## Step 1: Upload an object image and process", show_label=False) # with gr.Row(equal_height=True): input_image = gr.Image(type="pil", interactive=True, elem_id="input_image", elem_classes='image', visible=True) preprocess_botton = gr.Button(value="Need preprocess", visible=True) ## step 2. gr.Markdown("## Step 2: Set up different controls ", show_label=False, visible=True) gr.Markdown("### 1: Object Location", show_label=False, visible=True) with gr.Row(): with gr.Column(): with gr.Row(): x0 = gr.Slider(minimum=0, maximum=256, step=1, label="X_0", value=0, interactive=True, visible=True) y0 = gr.Slider(minimum=0, maximum=256, step=1, label="Y_0", value=0, interactive=True, visible=True) with gr.Row(): x1 = gr.Slider(minimum=0, maximum=256, step=1, label="X_1", value=256, interactive=True, visible=True) y1 = gr.Slider(minimum=0, maximum=256, step=1, label="Y_1", value=256, interactive=True, visible=True) location_botton = gr.Button(value="Update Location ", visible=True) location_image = gr.Image(type="pil", interactive=True, elem_id="location", elem_classes='image', visible=True) gr.Markdown("### 2: Object Viewpoint", show_label=False, visible=True) with gr.Row(): polar = gr.Slider(minimum=-30, maximum=30, step=-0.5, label="Polar Angle", value=0.0, visible=True) azimuth = gr.Slider(minimum=-60, maximum=60, step=-0.5, label="Azimuth angle", value=0.0, visible=True) gr.Markdown("### 3: Text", show_label=False, visible=True) prompt = gr.Textbox(value="on the seaside", label="Prompt", interactive=True, visible=True) ## step 3. gr.Markdown("## Step 3: Run Generation", show_label=False, visible=True) seed = gr.Number(value=1234, precision=0, interactive=True, label="Seed", visible=True) start = gr.Button(value="Run generation !", visible=True) examples_full = [ ["examples/0.jpg", 50, 50, 256, 256, 0, -30, "a backpack in the office"], ["examples/1.jpg", 20, 20, 256, 256, -25, -35, "a pair of shoes on dirt road"], ["examples/2.jpg", 0, 0, 256, 256, -15, -20, "a car on the beach"], ["examples/3.jpg", 0, 0, 256, 256, 0, 30, "in the jungle"], ["examples/4.jpg", 0, 0, 256, 256, 0, -30, "in the snow"], ["examples/5.jpg", 20, 20, 240, 240, 10, 20, "with mountain behind"], ] ## Right column with gr.Column(): gr.Markdown("## Generation Results", show_label=False, visible=True) output_image = gr.Image(type="pil", interactive=True, elem_id="output_image", elem_classes='image', visible=True) gr.Examples( examples=examples_full, # NOTE: elements must match inputs list! fn=load_example, inputs=[input_image, x0, y0, x1, y1, polar, azimuth, prompt], outputs=[input_image, x0, y0, x1, y1, polar, azimuth, prompt], cache_examples=False, run_on_click=True, ) gr.Markdown(article) ## function input_image.change(send_input_to_concat, inputs=input_image, outputs=location_image) preprocess_botton.click(partial(preprocess_input, preprocess_model), inputs=input_image, outputs=input_image) location_botton.click(adjust_location, inputs=[x0, y0, x1, y1, input_image], outputs=[x0, y0, x1, y1, location_image]) # start.click(partial(run_generation, sampler, model, device), start.click(partial(run_generation, sampler), inputs=[input_image, x0, y0, x1, y1, polar, azimuth, prompt, seed], outputs=[x0, y0, x1, y1, location_image, output_image]) # demo.launch(server_name='0.0.0.0', share=False, server_port=args.port) # demo.queue(concurrency_count=1, max_size=10) demo.queue().launch() if __name__=="__main__": parser = argparse.ArgumentParser() parser.add_argument("--port", type=int, default=12345) args = parser.parse_args() main(args)