File size: 2,253 Bytes
ef187eb 0cffd40 ef187eb 0cffd40 8b1e96d 0cffd40 8b1e96d 0cccf69 8b1e96d ec35e66 8b1e96d ef187eb 8b1e96d f286ae5 8b1e96d 5c6a083 19e461a 8b1e96d 0cffd40 8b1e96d 0cffd40 ef187eb 8b1e96d 0cffd40 8b1e96d 0cffd40 556fb50 8b1e96d 3eaeeea 8b1e96d 0cffd40 8b1e96d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import gradio as gr
import torch
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
import spaces
from PIL import Image
# Constants
# Base SDXL pipeline; its UNet weights get replaced by a DMD2 checkpoint.
base = "stabilityai/stable-diffusion-xl-base-1.0"
# Hugging Face Hub repository that hosts the DMD2 UNet weight files.
repo = "tianweiy/DMD2"

# Dropdown label -> [weight file inside `repo`, number of inference steps].
checkpoints = {}
checkpoints["1-Step"] = ["dmd2_sdxl_1step_unet_fp16.bin", 1]
checkpoints["4-Step"] = ["dmd2_sdxl_4step_unet_fp16.bin", 4]

# Tracks which checkpoint is currently in the pipeline's UNet (updated by
# generate_image); None until the first request triggers a load.
loaded = None

# Narrow the page so the single-column layout stays compact.
CSS = """
.gradio-container {
max-width: 690px !important;
}
"""
# Build the SDXL pipeline once, at import time (this is module-level code,
# not part of the @spaces.GPU function). `pipe` stays None when no CUDA device
# is visible, so request handlers can detect the missing pipeline instead of
# failing with a NameError.
pipe = None
if torch.cuda.is_available():
    pipe = StableDiffusionXLPipeline.from_pretrained(base, torch_dtype=torch.float16, variant="fp16").to("cuda")
# Function
@spaces.GPU()
def generate_image(prompt, ckpt):
    """Generate one image for ``prompt`` using the DMD2 UNet selected by ``ckpt``.

    The UNet weights and scheduler are only (re)loaded when the requested
    checkpoint differs from the one loaded by the previous call.

    Args:
        prompt: Text prompt passed to the SDXL pipeline.
        ckpt: Key into ``checkpoints`` (e.g. "1-Step" or "4-Step").

    Returns:
        The first image of the pipeline result.

    Raises:
        gr.Error: If ``ckpt`` is not a known checkpoint, or if no pipeline was
            built at import time (no CUDA device was available).
    """
    global loaded
    print(prompt, ckpt)
    if ckpt not in checkpoints:
        raise gr.Error(f"Unsupported checkpoint {ckpt!r}; choose one of: {', '.join(checkpoints)}")
    # globals().get() also covers the case where `pipe` was never bound.
    if globals().get("pipe") is None:
        raise gr.Error("The diffusion pipeline is not available (no CUDA device was found at startup).")
    checkpoint, num_inference_steps = checkpoints[ckpt]
    if loaded != checkpoint:
        # The 1-step checkpoint is run with prediction_type="sample";
        # multi-step checkpoints use "epsilon".
        pipe.scheduler = EulerDiscreteScheduler.from_config(
            pipe.scheduler.config,
            timestep_spacing="trailing",
            prediction_type="sample" if num_inference_steps == 1 else "epsilon",
        )
        # NOTE(review): torch.load unpickles the downloaded file; consider
        # weights_only=True since only a tensor state dict is expected here.
        weights_path = hf_hub_download(repo, checkpoint)
        pipe.unet.load_state_dict(torch.load(weights_path, map_location="cuda"))
        # Record the load only after it succeeded so a failure is retried.
        loaded = checkpoint
    results = pipe(prompt, num_inference_steps=num_inference_steps, guidance_scale=0)
    return results.images[0]
# Gradio Interface
with gr.Blocks(css=CSS) as demo:
    gr.HTML("<h1><center>Adobe DMD2🦖</center></h1>")
    gr.HTML("<p><center><a href='https://huggingface.co/tianweiy/DMD2'>DMD2</a> text-to-image generation</center></p>")
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(label='Enter your prompt (English)', scale=8)
            # Choices come from `checkpoints`, so every option has weights
            # that generate_image can actually load.
            ckpt = gr.Dropdown(label='Select inference steps', choices=list(checkpoints), value='4-Step', interactive=True)
            submit = gr.Button(scale=1, variant='primary')
    img = gr.Image(label='DMD2 Generated Image')
    # Pressing Enter in the textbox and clicking the button trigger the same call.
    for trigger in (prompt.submit, submit.click):
        trigger(fn=generate_image, inputs=[prompt, ckpt], outputs=img)

demo.queue().launch()