# Hugging Face Spaces page header (status: Sleeping)
import threading
import time
from io import BytesIO
from typing import Literal

import streamlit as st
import torch
from diffusers import StableDiffusionPipeline
# Sampling configuration read by the generation code below.
NUM_ITERS_TO_RUN: int = 2  # number of pipeline invocations per text2image call
NUM_INFERENCE_STEPS: int = 20  # forwarded to the pipeline as num_inference_steps
NUM_IMAGES_PER_PROMPT: int = 1  # forwarded to the pipeline as num_images_per_prompt

# Fixed seed for repeatable sampling. torch.manual_seed seeds torch's global
# RNG and returns the default torch.Generator object.
seed: int = 42
generator = torch.manual_seed(seed)
@st.cache_resource(max_entries=1, show_spinner=False)
def _load_pipeline(repo_id: str):
    """Load the pipeline for ``repo_id`` once and reuse it across reruns.

    Streamlit re-executes this script on every interaction, so plain module
    globals (or ``functools.lru_cache``) are rebuilt each time; only
    ``st.cache_resource`` survives reruns. ``max_entries=1`` keeps just the
    most recently used model resident, so switching models does not pile up
    large pipelines in memory.

    Returns a ``(pipeline, lock)`` pair. The cached pipeline is shared by all
    sessions, and Streamlit serves each session from its own thread, so
    callers must hold ``lock`` while running the pipeline (its scheduler is
    presumably mutated per run, e.g. timesteps -- treat it as not
    thread-safe).
    """
    if torch.cuda.is_available():
        # Half precision on GPU; float32 on CPU (presumably because fp16
        # inference is poorly supported there).
        pipeline = StableDiffusionPipeline.from_pretrained(
            repo_id,
            torch_dtype=torch.float16,
            use_safetensors=True,
        ).to("cuda")
    else:
        pipeline = StableDiffusionPipeline.from_pretrained(
            repo_id,
            torch_dtype=torch.float32,
            use_safetensors=True,
        )
    return pipeline, threading.Lock()


def text2image(
    prompt: str,
    repo_id: Literal[
        "dreamlike-art/dreamlike-photoreal-2.0",
        "hakurei/waifu-diffusion",
        "prompthero/openjourney",
        "stabilityai/stable-diffusion-2-1",
        "runwayml/stable-diffusion-v1-5",
        "nota-ai/bk-sdm-small",
        "CompVis/stable-diffusion-v1-4",
    ],
):
    """Generate an image for ``prompt`` with the Stable Diffusion model ``repo_id``.

    The pipeline is invoked ``NUM_ITERS_TO_RUN`` times with
    ``NUM_INFERENCE_STEPS`` steps and ``NUM_IMAGES_PER_PROMPT`` images per
    invocation; the first image of the final invocation is returned.
    Sampling starts from a fresh generator seeded with ``seed`` on every call,
    so the starting RNG state does not depend on whether the model had to be
    loaded first or on other users of torch's global RNG.

    Returns:
        ``(image, start, end)``: ``image`` is the first element of the
        pipeline's ``.images``; ``start``/``end`` are ``time.time()``
        readings bracketing the whole call -- model loading (only on a cache
        miss), waiting for another session using the same model, and
        generation.

    Raises:
        ValueError: if ``NUM_ITERS_TO_RUN`` is less than 1, since no image
            would be produced.
    """
    if NUM_ITERS_TO_RUN < 1:
        raise ValueError("NUM_ITERS_TO_RUN must be at least 1")

    start = time.time()
    print("Using GPU" if torch.cuda.is_available() else "Using CPU")
    pipeline, lock = _load_pipeline(repo_id)

    # Dedicated CPU generator (the same device as torch's default generator)
    # re-seeded per call.
    rng = torch.Generator().manual_seed(seed)

    # The cached pipeline is shared across session threads; serialise runs.
    with lock:
        for _ in range(NUM_ITERS_TO_RUN):
            images = pipeline(
                prompt,
                num_inference_steps=NUM_INFERENCE_STEPS,
                generator=rng,
                num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
            ).images
    end = time.time()
    return images[0], start, end
def format_duration(seconds: float) -> str:
    """Format a duration in seconds as ``HH:MM:SS.ss``.

    The value is rounded to hundredths *before* being split into fields, so
    59.999 s renders as ``00:01:00.00`` instead of ``00:00:60.00``. Negative
    inputs (possible if the wall clock steps backwards between the caller's
    two readings) are clamped to zero.
    """
    total = round(max(seconds, 0.0), 2)
    hours, rem = divmod(total, 3600)
    minutes, secs = divmod(rem, 60)
    return "{:0>2}:{:0>2}:{:05.2f}".format(int(hours), int(minutes), secs)


def app():
    """Render the text-to-image page and, on submit, generate and show an image.

    The prompt box and model picker sit outside the form; the form holds only
    the submit button, whose return value is True on the run triggered by a
    click. After generation the page shows the elapsed time, the image, and a
    PNG download button.
    """
    st.header("Text-to-image Web App")
    st.subheader("Powered by Hugging Face")
    user_input = st.text_area(
        "Enter your text prompt below and click the button to submit."
    )
    option = st.selectbox(
        "Select model (in order of processing time)",
        (
            "nota-ai/bk-sdm-small",
            "CompVis/stable-diffusion-v1-4",
            "runwayml/stable-diffusion-v1-5",
            "prompthero/openjourney",
            "hakurei/waifu-diffusion",
            "stabilityai/stable-diffusion-2-1",
            "dreamlike-art/dreamlike-photoreal-2.0",
        ),
    )
    with st.form("my_form"):
        submit = st.form_submit_button(label="Submit text prompt")
    if not submit:
        return
    if not user_input.strip():
        # Don't start a potentially very long generation for an empty prompt.
        st.warning("Please enter a text prompt before submitting.")
        return

    with st.spinner(text="Generating image ... It may take up to 20 minutes."):
        im, start, end = text2image(prompt=user_input, repo_id=option)

    # Encode once for the download button; the buffer is closed afterwards.
    with BytesIO() as buf:
        im.save(buf, format="PNG")
        byte_im = buf.getvalue()

    st.success(f"Processing time: {format_duration(end - start)}.")
    st.image(im)
    st.download_button(
        label="Click here to download",
        data=byte_im,
        file_name="generated_image.png",
        mime="image/png",
    )


if __name__ == "__main__":
    # `streamlit run` executes this file as __main__ on every rerun.
    app()