# Art-Free-Diffusion / hf_demo.py
# (Hugging Face file view: author rhfeiyang, commit fb8d464 "update", 11.2 kB)
# Authors: Hui Ren (rhfeiyang.github.io)
import spaces
import os
import gradio as gr
from diffusers import DiffusionPipeline
import matplotlib.pyplot as plt
import torch
from PIL import Image
device = "cuda" if torch.cuda.is_available() else "cpu"
# bfloat16 on GPU, float32 on CPU. float16 was previously selected for CPU, but
# diffusers warns that pipelines loaded in float16 cannot run on the `cpu`
# device, so fall back to full precision there.
dtype = torch.bfloat16 if device == "cuda" else torch.float32
print(f"Using {device} device, dtype={dtype}")
# Shared pipeline; its unet/vae/text_encoder/tokenizer/scheduler are reused by
# every inference handler below.
pipe = DiffusionPipeline.from_pretrained(
    "rhfeiyang/art-free-diffusion-v1",
    torch_dtype=dtype,
).to(device)
from inference import get_lora_network, inference, get_validation_dataloader
# (dropdown label, adapter folder under data/Art_adapters/) pairs, in display
# order. The special "None" entry means "no adapter".
_LORA_ENTRIES = (
    ("None", "None"),
    ("Andre Derain (fauvism)", "andre-derain_subset1"),
    ("Vincent van Gogh (post impressionism)", "van_gogh_subset1"),
    ("Andy Warhol (pop art)", "andy_subset1"),
    ("Walter Battiss", "walter-battiss_subset2"),
    ("Camille Corot (realism)", "camille-corot_subset1"),
    ("Claude Monet (impressionism)", "monet_subset2"),
    ("Pablo Picasso (cubism)", "picasso_subset1"),
    ("Jackson Pollock", "jackson-pollock_subset1"),
    ("Gerhard Richter (abstract expressionism)", "gerhard-richter_subset1"),
    ("M.C. Escher", "m.c.-escher_subset1"),
    ("Albert Gleizes", "albert-gleizes_subset1"),
    ("Hokusai (ukiyo-e)", "katsushika-hokusai_subset1"),
    ("Wassily Kandinsky", "kandinsky_subset1"),
    ("Gustav Klimt (art nouveau)", "klimt_subset3"),
    ("Roy Lichtenstein", "roy-lichtenstein_subset1"),
    ("Henri Matisse (abstract expressionism)", "henri-matisse_subset1"),
    ("Joan Miro", "joan-miro_subset2"),
)
# Map from UI label to adapter folder name; insertion order follows the tuple.
lora_map = dict(_LORA_ENTRIES)
@spaces.GPU
def demo_inference_gen_artistic(adapter_choice: str, prompt: str, seed: int = 0, steps=50, guidance_scale=7.5, adapter_scale=1.0):
    """Generate a 512x512 image from ``prompt`` with the selected art adapter.

    Args:
        adapter_choice: UI label; must be a key of ``lora_map``. The "None"
            entry runs without an adapter checkpoint and without a style prompt.
        prompt: Text prompt.
        seed: Seed forwarded to ``inference``.
        steps: Number of sampling steps forwarded to ``inference``.
        guidance_scale: Guidance scale forwarded to ``inference``.
        adapter_scale: Adapter strength, passed as the single entry of ``scales``.

    Returns:
        The first generated image for ``adapter_scale`` (presumably a PIL image;
        it is displayed by a ``gr.Image`` output).
    """
    adapter_path = lora_map[adapter_choice]
    if adapter_path not in [None, "None"]:
        adapter_path = f"data/Art_adapters/{adapter_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt"
        # Trigger phrase the adapters were presumably trained with.
        style_prompt = "sks art"
    else:
        style_prompt = None
    prompts = [prompt]
    infer_loader = get_validation_dataloader(prompts, num_workers=0)
    network = get_lora_network(pipe.unet, adapter_path, weight_dtype=dtype)["network"]
    results = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
                        height=512, width=512, scales=[adapter_scale],
                        save_dir=None, seed=seed, steps=steps, guidance_scale=guidance_scale,
                        start_noise=-1, show=False, style_prompt=style_prompt, no_load=True,
                        from_scratch=True, device=device, weight_dtype=dtype)
    # Look the images up by the same scale value that was passed in ``scales``
    # (the art-free handlers do the same with 0.0). The previous hard-coded
    # ``[1.0]`` key would presumably raise KeyError for any other slider value.
    pred_images = results[0][adapter_scale][0]
    return pred_images
@spaces.GPU
def demo_inference_gen_ori( prompt:str, seed:int=0, steps=50, guidance_scale=7.5):
    """Generate a 512x512 image from ``prompt`` with no art adapter.

    Loads an empty ("None") LoRA network and runs ``inference`` at scale 0.0
    from scratch, with no style prompt.

    Returns:
        The first image produced for scale 0.0.
    """
    zero_scale = 0.0
    loader = get_validation_dataloader([prompt], num_workers=0)
    lora = get_lora_network(pipe.unet, "None", weight_dtype=dtype)["network"]
    outputs = inference(
        lora, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, loader,
        height=512,
        width=512,
        scales=[zero_scale],
        save_dir=None,
        seed=seed,
        steps=steps,
        guidance_scale=guidance_scale,
        start_noise=-1,
        show=False,
        style_prompt=None,
        no_load=True,
        from_scratch=True,
        device=device,
        weight_dtype=dtype,
    )
    # outputs[0] is indexed by the scale value passed in ``scales``.
    per_scale = outputs[0]
    return per_scale[zero_scale][0]
@spaces.GPU
def demo_inference_stylization_ori(ref_image, prompt:str, seed:int=0, steps=50, guidance_scale=7.5, start_noise=800):
    """Re-render ``ref_image`` guided by ``prompt`` with no art adapter.

    Args:
        ref_image: Reference image as delivered by ``gr.Image`` (converted with
            ``Image.fromarray``, so presumably a numpy array), or None when
            nothing was uploaded.
        prompt: Text prompt.
        seed: Seed forwarded to ``inference``.
        steps: Number of sampling steps forwarded to ``inference``.
        guidance_scale: Guidance scale forwarded to ``inference``.
        start_noise: Forwarded as ``start_noise`` (the UI's "Adapter Timestep").

    Returns:
        The first image produced for scale 0.0.

    Raises:
        gr.Error: If no reference image was provided.
    """
    if ref_image is None:
        # gr.Image passes None when nothing is uploaded; Image.fromarray(None)
        # would otherwise fail with an opaque error inside the GPU worker.
        raise gr.Error("Please upload a reference image first.")
    style_prompt = None
    prompts = [prompt]
    # convert np to pil
    ref_image = [Image.fromarray(ref_image)]
    network = get_lora_network(pipe.unet, "None", weight_dtype=dtype)["network"]
    infer_loader = get_validation_dataloader(prompts, ref_image, num_workers=0)
    pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
                            height=512, width=512, scales=[0.0],
                            save_dir=None, seed=seed, steps=steps, guidance_scale=guidance_scale,
                            start_noise=start_noise, show=False, style_prompt=style_prompt, no_load=True,
                            from_scratch=False, device=device, weight_dtype=dtype)[0][0.0][0]
    return pred_images
@spaces.GPU
def demo_inference_stylization_artistic(ref_image, adapter_choice:str, prompt:str, seed:int=0, steps=50, guidance_scale=7.5, adapter_scale=1.0,start_noise=800):
    """Re-render ``ref_image`` guided by ``prompt`` with the selected art adapter.

    Args:
        ref_image: Reference image as delivered by ``gr.Image`` (converted with
            ``Image.fromarray``, so presumably a numpy array), or None when
            nothing was uploaded.
        adapter_choice: UI label; must be a key of ``lora_map``.
        prompt: Text prompt.
        seed: Seed forwarded to ``inference``.
        steps: Number of sampling steps forwarded to ``inference``.
        guidance_scale: Guidance scale forwarded to ``inference``.
        adapter_scale: Adapter strength, passed as the single entry of ``scales``.
        start_noise: Forwarded as ``start_noise`` (the UI's "Adapter Timestep").

    Returns:
        The first image produced for ``adapter_scale``.

    Raises:
        gr.Error: If no reference image was provided.
    """
    if ref_image is None:
        # gr.Image passes None when nothing is uploaded; Image.fromarray(None)
        # would otherwise fail with an opaque error inside the GPU worker.
        raise gr.Error("Please upload a reference image first.")
    adapter_path = lora_map[adapter_choice]
    if adapter_path not in [None, "None"]:
        adapter_path = f"data/Art_adapters/{adapter_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt"
        # Trigger phrase the adapters were presumably trained with.
        style_prompt = "sks art"
    else:
        style_prompt = None
    prompts = [prompt]
    # convert np to pil
    ref_image = [Image.fromarray(ref_image)]
    network = get_lora_network(pipe.unet, adapter_path, weight_dtype=dtype)["network"]
    infer_loader = get_validation_dataloader(prompts, ref_image, num_workers=0)
    results = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
                        height=512, width=512, scales=[adapter_scale],
                        save_dir=None, seed=seed, steps=steps, guidance_scale=guidance_scale,
                        start_noise=start_noise, show=False, style_prompt=style_prompt, no_load=True,
                        from_scratch=False, device=device, weight_dtype=dtype)
    # Look the images up by the same scale value that was passed in ``scales``
    # (the art-free handlers do the same with 0.0). The previous hard-coded
    # ``[1.0]`` key would presumably raise KeyError for any other slider value.
    pred_images = results[0][adapter_scale][0]
    return pred_images
block = gr.Blocks()
# Direct infer
# Layout: a shared prompt box; a "Generation" tab (text-to-image) and a
# "Stylization" tab (reference image + prompt), each showing the result
# without and with the selected art adapter side by side; then the shared
# sampling controls. Event listeners are registered inside the Blocks context,
# as Gradio requires.
with block:
    with gr.Group():
        gr.Markdown(" # Art-Free Diffusion Demo")
        gr.Markdown("(More features in development...)")
        with gr.Row():
            text = gr.Textbox(
                label="Enter your prompt(long and detailed would be better):",
                max_lines=2,
                placeholder="Enter your prompt(long and detailed would be better)",
                container=True,
                value="Park with cherry blossom trees, picnicker’s and a clear blue pond.",
            )

    with gr.Tab('Generation'):
        with gr.Row():
            with gr.Column():
                gallery_gen_ori = gr.Image(
                    label="W/O Adapter",
                    show_label=True,
                    elem_id="gallery",
                    height="auto"
                )
            with gr.Column():
                gallery_gen_art = gr.Image(
                    label="W/ Adapter",
                    show_label=True,
                    elem_id="gallery",
                    height="auto"
                )
        with gr.Row():
            btn_gen_ori = gr.Button("Art-Free Generate", scale=1)
            btn_gen_art = gr.Button("Artistic Generate", scale=1)

    with gr.Tab('Stylization'):
        with gr.Row():
            with gr.Column():
                gallery_stylization_ref = gr.Image(
                    label="Ref Image",
                    show_label=True,
                    elem_id="gallery",
                    height="auto",
                    scale=1,
                )
            with gr.Column(scale=2):
                with gr.Row():
                    with gr.Column():
                        gallery_stylization_ori = gr.Image(
                            label="W/O Adapter",
                            show_label=True,
                            elem_id="gallery",
                            height="auto",
                            scale=1,
                        )
                    with gr.Column():
                        gallery_stylization_art = gr.Image(
                            label="W/ Adapter",
                            show_label=True,
                            elem_id="gallery",
                            height="auto",
                            scale=1,
                        )
                # Passed to the stylization handlers as their `start_noise` argument.
                start_timestep = gr.Slider(label="Adapter Timestep", minimum=0, maximum=1000, value=800, step=1)
        with gr.Row():
            btn_style_ori = gr.Button("Art-Free Stylization", scale=1)
            btn_style_art = gr.Button("Artistic Stylization", scale=1)

    # Sampling controls shared by both tabs.
    with gr.Row():
        scale = gr.Slider(
            label="Guidance Scale", minimum=0, maximum=20, value=7.5, step=0.1
        )
        # Every choice must be a key of `lora_map`; the handlers look it up there.
        adapter_choice = gr.Dropdown(
            label="Select Art Adapter",
            choices=[ "Andre Derain (fauvism)","Vincent van Gogh (post impressionism)","Andy Warhol (pop art)",
                      "Camille Corot (realism)", "Claude Monet (impressionism)", "Pablo Picasso (cubism)", "Gerhard Richter (abstract expressionism)",
                      "Hokusai (ukiyo-e)", "Gustav Klimt (art nouveau)", "Henri Matisse (abstract expressionism)",
                      "Walter Battiss", "Jackson Pollock", "M.C. Escher", "Albert Gleizes", "Wassily Kandinsky",
                      "Roy Lichtenstein", "Joan Miro"
                      ],
            value="Andre Derain (fauvism)",
            scale=1
        )
    with gr.Row():
        steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=20, step=1)
        adapter_scale = gr.Slider(label="Stylization Scale", minimum=0, maximum=1.5, value=1., step=0.1, scale=1)
    with gr.Row():
        seed = gr.Slider(label="Seed",minimum=0,maximum=2147483647,step=1,randomize=True,scale=1)

    # The order of each `inputs` list must match the handler's positional parameters.
    gr.on([btn_gen_ori.click], demo_inference_gen_ori, inputs=[text, seed, steps, scale], outputs=gallery_gen_ori)
    gr.on([btn_gen_art.click], demo_inference_gen_artistic, inputs=[adapter_choice, text, seed, steps, scale, adapter_scale], outputs=gallery_gen_art)
    gr.on([btn_style_ori.click], demo_inference_stylization_ori, inputs=[gallery_stylization_ref, text, seed, steps, scale, start_timestep], outputs=gallery_stylization_ori)
    gr.on([btn_style_art.click], demo_inference_stylization_artistic, inputs=[gallery_stylization_ref, adapter_choice, text, seed, steps, scale, adapter_scale, start_timestep], outputs=gallery_stylization_art)

# `Blocks.launch` takes `share`; the previous `sharing=True` is not an accepted
# keyword and raised TypeError before the server started.
block.launch(share=True)