huguru / app.py
echarlaix's picture
echarlaix HF staff
Statically reshapes the SD model to speed up inference
9659c11
raw
history blame
4.04 kB
import random
import re

import gradio as gr
import torch
from optimum.intel.openvino import OVStableDiffusionPipeline
from transformers import (
    AutoModelForCausalLM,
    AutoModelWithLMHead,
    AutoTokenizer,
    pipeline,
    set_seed,
)

from share_btn import community_icon_html, loading_icon_html, share_js
# --- Model loading -------------------------------------------------------------
# GPT-2 checkpoint fine-tuned on horoscopes; `fn` prompts it with
# "<|category|> {cat} <|horoscope|> {sign}".
horoscope_model_id = "shahp7575/gpt2-horoscopes"
tokenizer = AutoTokenizer.from_pretrained(horoscope_model_id)
# `AutoModelWithLMHead` is deprecated in transformers; `AutoModelForCausalLM`
# is the documented replacement for causal (GPT-2 style) language models.
model = AutoModelForCausalLM.from_pretrained(horoscope_model_id)

# Text-generation pipeline that `fn` uses to expand the horoscope into an
# image prompt for Stable Diffusion.
text_generation_pipe = pipeline(
    "text-generation",
    model="Gustavosta/MagicPrompt-Stable-Diffusion",
    tokenizer="gpt2",
)

# OpenVINO Stable Diffusion pipeline. It is loaded with compile=False so it
# can be statically reshaped first and compiled once afterwards.
stable_diffusion_pipe = OVStableDiffusionPipeline.from_pretrained(
    "echarlaix/stable-diffusion-v1-5-openvino",
    revision="fp16",
    compile=False,
)

# Output image size in pixels. `fn` passes these same values at inference time
# so requests match the static shapes fixed by `reshape` below; keep them in
# sync if you change either.
height = 128
width = 128
stable_diffusion_pipe.reshape(
    batch_size=1, height=height, width=width, num_images_per_prompt=1
)
stable_diffusion_pipe.compile()
def fn(sign, cat):
    """Generate a horoscope for ``sign``/``cat`` and an image illustrating it.

    Args:
        sign: Star sign selected in the UI (one of the dropdown choices).
        cat: Horoscope category selected in the UI (one of the dropdown choices).

    Returns:
        A two-element list ``[image, starting_text]``: the first image produced
        by the Stable Diffusion pipeline, and the generated horoscope text.

    Raises:
        ValueError: if ``sign`` or ``cat`` is empty or ``None``.
    """
    if not sign or not cat:
        # Fail with a clear message rather than a TypeError from the string
        # concatenation below (e.g. a dropdown with no selection yields None).
        raise ValueError("Both a star sign and a category must be selected.")

    # 1) Horoscope text from the fine-tuned GPT-2 model.
    prompt = f"<|category|> {cat} <|horoscope|> {sign}"
    prompt_encoded = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
    sample_outputs = model.generate(
        prompt_encoded,
        do_sample=True,
        top_k=40,
        max_length=300,
        top_p=0.95,
        temperature=0.95,
        num_beams=4,
        num_return_sequences=1,
    )
    final_out = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
    # Drop the first four space-separated words, intended to be the echoed
    # prompt ("<|category|> {cat} <|horoscope|> {sign}").
    # NOTE(review): assumes the decoded text still starts with those four words
    # after skip_special_tokens=True -- confirm against the tokenizer's
    # special-token setup.
    starting_text = " ".join(final_out.split(" ")[4:])

    # 2) Expand the horoscope into an image prompt with MagicPrompt.
    seed = random.randint(100, 1000000)
    set_seed(seed)
    # Budget the continuation in tokens via `max_new_tokens`. `len(starting_text)`
    # counts characters, not tokens, so using it for `max_length` over-allocates
    # and, for long horoscopes, can exceed the model's maximum context length.
    response = text_generation_pipe(
        starting_text + " " + sign + " art",
        max_new_tokens=random.randint(60, 90),
        num_return_sequences=1,
    )

    # 3) Render the image at the size the pipeline was statically reshaped to.
    image = stable_diffusion_pipe(
        response[0]["generated_text"],
        height=height,
        width=width,
        num_inference_steps=30,
    ).images[0]
    return [image, starting_text]
# --- User interface ------------------------------------------------------------
block = gr.Blocks(css="./css.css")

with block:
    with gr.Group():
        with gr.Box():
            with gr.Row(elem_id="prompt-container").style(
                mobile_collapse=False, equal_height=True
            ):
                # The two dropdowns are passed to `fn` as (sign, cat). Each has
                # a default so `fn` never receives None when the user clicks
                # "Generate image" without choosing anything.
                # NOTE(review): both dropdowns share elem_id "prompt-text-input"
                # (a duplicate DOM id); kept because css.css (not visible here)
                # may style it.
                sign_input = gr.Dropdown(
                    label="Star Sign",
                    choices=[
                        "Aries", "Taurus", "Gemini", "Cancer", "Leo", "Virgo",
                        "Libra", "Scorpio", "Sagittarius", "Capricorn",
                        "Aquarius", "Pisces",
                    ],
                    value="Aries",
                    show_label=True,
                    max_lines=1,
                    placeholder="Enter your prompt",
                    elem_id="prompt-text-input",
                ).style(
                    border=(True, False, True, True),
                    rounded=(True, False, False, True),
                    container=False,
                )
                category_input = gr.Dropdown(
                    choices=["Love", "Career", "Wellness"],
                    value="Love",
                    label="Category",
                    show_label=True,
                    max_lines=1,
                    placeholder="Enter your prompt",
                    elem_id="prompt-text-input",
                ).style(
                    border=(True, True, True, True),
                    rounded=(True, False, False, True),
                    container=False,
                )
                btn = gr.Button("Generate image").style(
                    margin=False,
                    rounded=(False, True, True, False),
                    full_width=False,
                )

        gallery = gr.Image(
            interactive=False,
            label="Generated images", show_label=False, elem_id="gallery"
        ).style(grid=[2], height="auto")
        # Output-only box for the horoscope text. It has its own name: reusing
        # `text` here rebound the Star Sign dropdown, so the click handler
        # previously read this box instead of the selected sign.
        horoscope_output = gr.Textbox("Text")

        with gr.Group(elem_id="container-advanced-btns"):
            with gr.Group(elem_id="share-btn-container"):
                community_icon = gr.HTML(community_icon_html)
                loading_icon = gr.HTML(loading_icon_html)
                share_button = gr.Button("Share to community", elem_id="share-btn")

        btn.click(
            fn=fn,
            inputs=[sign_input, category_input],
            outputs=[gallery, horoscope_output],
        )
        # Client-side only: runs the `share_js` snippet, no Python callback.
        share_button.click(
            None,
            [],
            [],
            _js=share_js,
        )

# NOTE(review): concurrency_count=40 may let several `fn` calls run at once;
# `fn` reseeds global RNGs via set_seed and shares a single SD pipeline --
# confirm that concurrent use is acceptable.
block.queue(concurrency_count=40, max_size=20).launch(max_threads=150)