# Hugging Face Space source. Hosting page status when captured: "Runtime error".
import gradio as gr
import spaces
import torch
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
from PIL import Image
# Device both pipelines are moved to right after loading.
device = "cuda"
# Passed to the prior stage on every call to gen().
num_images_per_prompt = 1

# Stage 1 (prior pipeline), loaded in bfloat16.
prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior",
    torch_dtype=torch.bfloat16,
).to(device)

# Stage 2 (decoder pipeline), loaded in float16; gen() casts the prior's
# embeddings to half precision before handing them to this stage.
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade",
    torch_dtype=torch.float16,
).to(device)
# Custom stylesheet handed to gr.Blocks(css=...).
# NOTE(review): the visible layout assigns no elem_id, so the three `#...`
# rules below match nothing there; only the `footer` rule applies as-is.
# Confirm whether the Generate button was meant to carry
# elem_id="generate_button".
css = """
footer {
    visibility: hidden
}
#generate_button {
    color: white;
    border-color: #007bff;
    background: #2563eb;
}
#save_button {
    color: white;
    border-color: #028b40;
    background: #01b97c;
    width: 200px;
}
#settings_header {
    background: rgb(245, 105, 105);
}
"""
# ZeroGPU Spaces only attach a GPU while a function wrapped in spaces.GPU is
# running; the decorator is documented as effect-free on other hardware.
@spaces.GPU
def gen(prompt, negative, width, height):
    """Generate images for a text prompt with the two Stable Cascade stages.

    The module-level ``prior`` pipeline is run first; its image embeddings
    are then passed to the module-level ``decoder`` pipeline.

    Args:
        prompt: Text prompt, passed to both stages.
        negative: Negative prompt, passed to both stages.
        width: Requested width, passed to the prior stage.
        height: Requested height, passed to the prior stage.

    Returns:
        The ``images`` attribute of the decoder output, requested with
        ``output_type="pil"``.
    """
    prior_output = prior(
        prompt=prompt,
        height=height,
        width=width,
        negative_prompt=negative,
        guidance_scale=4.0,
        num_images_per_prompt=num_images_per_prompt,
        num_inference_steps=20,
    )
    decoder_output = decoder(
        # The prior is loaded in bfloat16 and the decoder in float16, so the
        # embeddings are cast to half precision before decoding.
        image_embeddings=prior_output.image_embeddings.half(),
        prompt=prompt,
        negative_prompt=negative,
        guidance_scale=0.0,
        output_type="pil",
        num_inference_steps=10,
    ).images
    return decoder_output
# Page layout: a prompt row with the Generate button, a collapsed
# "Advanced options" accordion (negative prompt and size sliders), and a
# gallery that receives whatever gen() returns.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Stable Cascade ```DEMO```")

    with gr.Row():
        prompt = gr.Textbox(
            show_label=False,
            placeholder="Enter your prompt",
            max_lines=3,
            lines=1,
            interactive=True,
            scale=20,
        )
        button = gr.Button(value="Generate", scale=1)

    with gr.Accordion("Advanced options", open=False):
        with gr.Row():
            negative = gr.Textbox(
                show_label=False,
                placeholder="Enter a negative",
                max_lines=2,
                lines=1,
                interactive=True,
            )
        with gr.Row():
            width = gr.Slider(
                label="Width",
                minimum=1024,
                maximum=2048,
                step=8,
                value=1024,
                interactive=True,
            )
            height = gr.Slider(
                label="Height",
                minimum=1024,
                maximum=2048,
                step=8,
                value=1024,
                interactive=True,
            )

    # The original `with gr.Row()` had no trailing colon, which is a
    # SyntaxError that stops the whole script from starting.
    with gr.Row():
        gallery = gr.Gallery(show_label=False, rows=1, columns=1)

    # The inputs list follows gen()'s positional parameter order.
    button.click(gen, inputs=[prompt, negative, width, height], outputs=gallery)

demo.launch(show_api=False)