# Authors: Hui Ren (rhfeiyang.github.io)
"""Gradio demo for Art-Free Diffusion.

A single prompt box drives two tabs:

* Generation  -- text-to-image without an adapter ("Art-Free") or with an
  artist-style LoRA adapter ("Artistic").
* Stylization -- the same pipeline seeded from a reference image
  (``from_scratch=False``, with ``start_noise`` taken from the
  "Adapter Timestep" slider), again with or without an adapter.
"""
# NOTE(review): `spaces` is kept as the first import; Hugging Face ZeroGPU
# expects it to be loaded before torch initialises CUDA.
import spaces
import os

import gradio as gr
from diffusers import DiffusionPipeline
import matplotlib.pyplot as plt
import torch
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
# bfloat16 on GPU. On CPU use float32: many CPU kernels (convolutions,
# normalisation layers) have historically had no float16 implementation, so a
# float16 CPU pipeline can fail at inference time.
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
print(f"Using {device} device, dtype={dtype}")

pipe = DiffusionPipeline.from_pretrained("rhfeiyang/art-free-diffusion-v1", torch_dtype=dtype).to(device)

from inference import get_lora_network, inference, get_validation_dataloader

# Dropdown label -> adapter directory name under data/Art_adapters/.
# The special value "None" means "run without an art adapter".
lora_map = {
    "None": "None",
    "Andre Derain (fauvism)": "andre-derain_subset1",
    "Vincent van Gogh (post impressionism)": "van_gogh_subset1",
    "Andy Warhol (pop art)": "andy_subset1",
    "Walter Battiss": "walter-battiss_subset2",
    "Camille Corot (realism)": "camille-corot_subset1",
    "Claude Monet (impressionism)": "monet_subset2",
    "Pablo Picasso (cubism)": "picasso_subset1",
    "Jackson Pollock": "jackson-pollock_subset1",
    "Gerhard Richter (abstract expressionism)": "gerhard-richter_subset1",
    "M.C. Escher": "m.c.-escher_subset1",
    "Albert Gleizes": "albert-gleizes_subset1",
    "Hokusai (ukiyo-e)": "katsushika-hokusai_subset1",
    "Wassily Kandinsky": "kandinsky_subset1",
    "Gustav Klimt (art nouveau)": "klimt_subset3",
    "Roy Lichtenstein": "roy-lichtenstein_subset1",
    "Henri Matisse (abstract expressionism)": "henri-matisse_subset1",
    "Joan Miro": "joan-miro_subset2",
}


def _resolve_adapter(adapter_choice):
    """Map a dropdown label to ``(adapter_path, style_prompt)``.

    For a real adapter the path points at its rank-1 checkpoint and the style
    prompt is the adapter's trigger phrase ``"sks art"``. For the "None" entry
    the raw map value is passed through and no style prompt is used.

    Raises:
        KeyError: if ``adapter_choice`` is not a key of ``lora_map``.
    """
    adapter_name = lora_map[adapter_choice]
    if adapter_name in (None, "None"):
        return adapter_name, None
    adapter_path = f"data/Art_adapters/{adapter_name}/adapter_alpha1.0_rank1_all_up_1000steps.pt"
    return adapter_path, "sks art"


def _run_inference(adapter_path, prompt, adapter_scale, seed, steps, guidance_scale,
                   style_prompt=None, ref_image=None, start_noise=-1):
    """Run one 512x512 sample through ``inference`` and return the image.

    Args:
        adapter_path: adapter checkpoint path, or "None" for no adapter.
        prompt: text prompt.
        adapter_scale: LoRA scale; also the key used to pick the result.
        seed, steps: cast to ``int`` because Gradio sliders may deliver floats.
        guidance_scale: classifier-free guidance scale.
        style_prompt: optional style trigger phrase for the adapter.
        ref_image: ``None`` for generation from scratch; otherwise the
            reference image as the numpy array Gradio's ``gr.Image`` delivers.
        start_noise: forwarded to ``inference``; -1 for generation.

    Returns:
        ``inference(...)[0][adapter_scale][0]`` -- presumably the first image
        of the first batch for this scale (TODO confirm against inference.py).
    """
    network = get_lora_network(pipe.unet, adapter_path, weight_dtype=dtype)["network"]
    prompts = [prompt]
    if ref_image is None:
        infer_loader = get_validation_dataloader(prompts, num_workers=0)
    else:
        # The dataloader is given PIL images; Gradio hands us a numpy array.
        infer_loader = get_validation_dataloader(prompts, [Image.fromarray(ref_image)], num_workers=0)
    results = inference(
        network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler,
        infer_loader,
        height=512,
        width=512,
        scales=[adapter_scale],
        save_dir=None,
        seed=int(seed),
        steps=int(steps),
        guidance_scale=guidance_scale,
        start_noise=start_noise,
        show=False,
        style_prompt=style_prompt,
        no_load=True,
        from_scratch=ref_image is None,
        device=device,
        weight_dtype=dtype,
    )
    return results[0][adapter_scale][0]


def _require_ref_image(ref_image):
    """Raise a user-visible error when the reference image box is empty."""
    if ref_image is None:
        raise gr.Error("Please provide a reference image for stylization.")


@spaces.GPU
def demo_inference_gen_artistic(adapter_choice:str, prompt:str, seed:int=0, steps=50, guidance_scale=7.5, adapter_scale=1.0):
    """Generate an image from ``prompt`` with the selected art adapter."""
    adapter_path, style_prompt = _resolve_adapter(adapter_choice)
    return _run_inference(adapter_path, prompt, adapter_scale, seed, steps, guidance_scale,
                          style_prompt=style_prompt)


@spaces.GPU
def demo_inference_gen_ori(prompt:str, seed:int=0, steps=50, guidance_scale=7.5):
    """Generate an image from ``prompt`` without any art adapter (scale 0.0)."""
    return _run_inference("None", prompt, 0.0, seed, steps, guidance_scale)


@spaces.GPU
def demo_inference_stylization_ori(ref_image, prompt:str, seed:int=0, steps=50, guidance_scale=7.5, start_noise=800):
    """Re-render ``ref_image`` guided by ``prompt`` without an art adapter.

    Raises:
        gr.Error: if no reference image was supplied.
    """
    _require_ref_image(ref_image)
    return _run_inference("None", prompt, 0.0, seed, steps, guidance_scale,
                          ref_image=ref_image, start_noise=start_noise)


@spaces.GPU
def demo_inference_stylization_artistic(ref_image, adapter_choice:str, prompt:str, seed:int=0, steps=50, guidance_scale=7.5, adapter_scale=1.0, start_noise=800):
    """Re-render ``ref_image`` guided by ``prompt`` with the selected art adapter.

    Raises:
        gr.Error: if no reference image was supplied.
    """
    _require_ref_image(ref_image)
    adapter_path, style_prompt = _resolve_adapter(adapter_choice)
    return _run_inference(adapter_path, prompt, adapter_scale, seed, steps, guidance_scale,
                          style_prompt=style_prompt, ref_image=ref_image, start_noise=start_noise)


block = gr.Blocks()

with block:
    with gr.Group():
        gr.Markdown(" # Art-Free Diffusion Demo")
        gr.Markdown("(More features in development...)")
        with gr.Row():
            text = gr.Textbox(
                label="Enter your prompt (long and detailed would be better):",
                max_lines=10,
                placeholder="Enter your prompt (long and detailed would be better)",
                container=True,
                value=(
                    "A blue bench situated in a park, surrounded by trees and leaves. "
                    "The bench is positioned under a tree, providing shade and a peaceful atmosphere. "
                    "There are several benches in the park, with one being closer to the foreground "
                    "and the others further in the background. "
                    "A person can be seen in the distance, possibly enjoying the park or taking a walk. "
                    "The overall scene is serene and inviting, with the bench serving as a focal point "
                    "in the park's landscape."
                ),
            )

    with gr.Tab('Generation'):
        with gr.Row():
            with gr.Column():
                gallery_gen_ori = gr.Image(
                    label="W/O Adapter", show_label=True, elem_id="gallery", height="auto"
                )
            with gr.Column():
                gallery_gen_art = gr.Image(
                    label="W/ Adapter", show_label=True, elem_id="gallery", height="auto"
                )
        with gr.Row():
            btn_gen_ori = gr.Button("Art-Free Generate", scale=1)
            btn_gen_art = gr.Button("Artistic Generate", scale=1)

    with gr.Tab('Stylization'):
        with gr.Row():
            with gr.Column():
                gallery_stylization_ref = gr.Image(
                    label="Ref Image", show_label=True, elem_id="gallery", height="auto",
                    scale=1, value="data/003904765.jpg",
                )
            with gr.Column(scale=2):
                with gr.Row():
                    with gr.Column():
                        gallery_stylization_ori = gr.Image(
                            label="W/O Adapter", show_label=True, elem_id="gallery",
                            height="auto", scale=1,
                        )
                    with gr.Column():
                        gallery_stylization_art = gr.Image(
                            label="W/ Adapter", show_label=True, elem_id="gallery",
                            height="auto", scale=1,
                        )
                # Forwarded as `start_noise` to the stylization handlers.
                start_timestep = gr.Slider(label="Adapter Timestep", minimum=0, maximum=1000, value=800, step=1)
        with gr.Row():
            btn_style_ori = gr.Button("Art-Free Stylize", scale=1)
            btn_style_art = gr.Button("Artistic Stylize", scale=1)

    # Controls shared by both tabs.
    with gr.Row():
        scale = gr.Slider(
            label="Guidance Scale", minimum=0, maximum=20, value=7.5, step=0.1
        )
        # Display order differs from lora_map on purpose; "None" is not offered.
        adapter_choice = gr.Dropdown(
            label="Select Art Adapter",
            choices=[
                "Andre Derain (fauvism)", "Vincent van Gogh (post impressionism)", "Andy Warhol (pop art)",
                "Camille Corot (realism)", "Claude Monet (impressionism)", "Pablo Picasso (cubism)",
                "Gerhard Richter (abstract expressionism)", "Hokusai (ukiyo-e)", "Gustav Klimt (art nouveau)",
                "Henri Matisse (abstract expressionism)", "Walter Battiss", "Jackson Pollock", "M.C. Escher",
                "Albert Gleizes", "Wassily Kandinsky", "Roy Lichtenstein", "Joan Miro",
            ],
            value="Andre Derain (fauvism)",
            scale=1,
        )
    with gr.Row():
        steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=20, step=1)
        adapter_scale = gr.Slider(label="Adapter Scale", minimum=0, maximum=1.5, value=1., step=0.1, scale=1)
    with gr.Row():
        seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, randomize=True, scale=1)

    # Event wiring must happen inside the Blocks context.
    gr.on([btn_gen_ori.click], demo_inference_gen_ori,
          inputs=[text, seed, steps, scale],
          outputs=gallery_gen_ori)
    gr.on([btn_gen_art.click], demo_inference_gen_artistic,
          inputs=[adapter_choice, text, seed, steps, scale, adapter_scale],
          outputs=gallery_gen_art)
    gr.on([btn_style_ori.click], demo_inference_stylization_ori,
          inputs=[gallery_stylization_ref, text, seed, steps, scale, start_timestep],
          outputs=gallery_stylization_ori)
    gr.on([btn_style_art.click], demo_inference_stylization_artistic,
          inputs=[gallery_stylization_ref, adapter_choice, text, seed, steps, scale, adapter_scale, start_timestep],
          outputs=gallery_stylization_art)

block.launch()
# For a public link when running locally: block.launch(share=True)