hysts's picture
hysts HF staff
Update
c2f8f5a
raw
history blame
8.51 kB
#!/usr/bin/env python
import random
import gradio as gr
import numpy as np
import PIL.Image
import spaces
import torch
from diffusers import StableDiffusionAttendAndExcitePipeline, StableDiffusionPipeline
# Markdown rendered at the top of the demo page.
DESCRIPTION = """\
# Attend-and-Excite
This is a demo for [Attend-and-Excite](https://arxiv.org/abs/2301.13826).
Attend-and-Excite performs attention-based generative semantic guidance to mitigate subject neglect in Stable Diffusion.
Select a prompt and a set of indices matching the subjects you wish to strengthen (the `Check token indices` cell can help map between a word and its index).
"""

# Upper bounds enforced by the UI sliders and by `run`.
MAX_INFERENCE_STEPS = 100
MAX_SEED = np.iinfo(np.int32).max

if torch.cuda.is_available():
    # Both pipelines are built from the same checkpoint and moved to the GPU,
    # so a prompt can be rendered with or without Attend-and-Excite.
    device = torch.device("cuda:0")
    model_id = "CompVis/stable-diffusion-v1-4"
    ax_pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained(model_id)
    ax_pipe.to(device)
    sd_pipe = StableDiffusionPipeline.from_pretrained(model_id)
    sd_pipe.to(device)
else:
    # No pipelines are created without CUDA; `device`, `ax_pipe` and `sd_pipe`
    # stay undefined, and the page only shows this notice.
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Pick the seed to use for the next generation.

    Args:
        seed: Seed currently selected by the caller.
        randomize_seed: When true, ignore ``seed`` and draw a new one.

    Returns:
        ``seed`` unchanged when ``randomize_seed`` is false, otherwise a
        random integer between 0 and ``MAX_SEED`` (both inclusive).
    """
    if not randomize_seed:
        return seed
    return random.randint(0, MAX_SEED)  # noqa: S311
def get_token_table(prompt: str) -> list[tuple[int, str]]:
    """Build the index-to-token table for ``prompt``.

    The prompt is tokenized with the Attend-and-Excite pipeline's tokenizer.
    The first and last ids are dropped (presumably the tokenizer's start and
    end markers -- confirm against the tokenizer in use), and the remaining
    tokens are numbered from 1.

    Args:
        prompt: Text prompt to tokenize.

    Returns:
        A list of ``(index, decoded_token)`` pairs, with ``index`` starting at 1.
    """
    tokenizer = ax_pipe.tokenizer
    token_ids = tokenizer(prompt)["input_ids"]
    # Skip the first and last ids; numbering starts at 1 for the rest.
    decoded = [tokenizer.decode(token_id) for token_id in token_ids[1:-1]]
    return list(enumerate(decoded, start=1))
@spaces.GPU
def run(
    prompt: str,
    indices_to_alter_str: str,
    seed: int = 0,
    apply_attend_and_excite: bool = True,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    scale_factor: int = 20,
    thresholds: dict[int, float] | None = None,
    max_iter_to_alter: int = 25,
    progress: gr.Progress = gr.Progress(track_tqdm=True),  # noqa: ARG001, B008
) -> PIL.Image.Image:
    """Generate one image for ``prompt``, with or without Attend-and-Excite.

    Args:
        prompt: Text prompt passed to the pipeline.
        indices_to_alter_str: Comma-separated token indices (e.g. ``"2,6"``).
            Only parsed when ``apply_attend_and_excite`` is true.
        seed: Seed for the torch generator handed to the pipeline.
        apply_attend_and_excite: Use the Attend-and-Excite pipeline when true,
            the plain Stable Diffusion pipeline otherwise.
        num_inference_steps: Number of inference steps; must not exceed
            ``MAX_INFERENCE_STEPS``.
        guidance_scale: Guidance scale forwarded to the pipeline.
        scale_factor: Forwarded to the Attend-and-Excite pipeline.
        thresholds: Forwarded to the Attend-and-Excite pipeline; defaults to
            ``{10: 0.5, 20: 0.8}`` when ``None``.
        max_iter_to_alter: Forwarded to the Attend-and-Excite pipeline.
        progress: Not referenced in the body; its default requests tqdm
            progress tracking from Gradio.

    Returns:
        The first image produced by the selected pipeline.

    Raises:
        gr.Error: If ``num_inference_steps`` exceeds ``MAX_INFERENCE_STEPS``,
            or if ``indices_to_alter_str`` cannot be parsed as a
            comma-separated list of integers while Attend-and-Excite is on.
    """
    if thresholds is None:
        thresholds = {10: 0.5, 20: 0.8}

    if num_inference_steps > MAX_INFERENCE_STEPS:
        error_message = f"Number of steps cannot exceed {MAX_INFERENCE_STEPS}."
        raise gr.Error(error_message)

    generator = torch.Generator(device=device).manual_seed(seed)

    if not apply_attend_and_excite:
        out = sd_pipe(
            prompt=prompt,
            guidance_scale=guidance_scale,
            generator=generator,
            num_inference_steps=num_inference_steps,
        )
        return out.images[0]

    try:
        token_indices = [int(part) for part in indices_to_alter_str.split(",")]
    except (AttributeError, ValueError) as e:
        # Raise gr.Error (like the step-limit check above) so the message is
        # shown to the user instead of a generic failure. AttributeError
        # covers a non-string value such as None; ValueError covers
        # non-integer entries.
        error_message = "Invalid token indices."
        raise gr.Error(error_message) from e

    out = ax_pipe(
        prompt=prompt,
        token_indices=token_indices,
        guidance_scale=guidance_scale,
        generator=generator,
        num_inference_steps=num_inference_steps,
        max_iter_to_alter=max_iter_to_alter,
        thresholds=thresholds,
        scale_factor=scale_factor,
    )
    return out.images[0]
def process_example(
    prompt: str,
    indices_to_alter_str: str,
    seed: int,
    apply_attend_and_excite: bool,
) -> tuple[list[tuple[int, str]], PIL.Image.Image]:
    """Produce both outputs shown for an example row.

    Args:
        prompt: Text prompt of the example.
        indices_to_alter_str: Comma-separated token indices of the example.
        seed: Seed of the example.
        apply_attend_and_excite: Whether the example uses Attend-and-Excite.

    Returns:
        A pair of the token table from ``get_token_table(prompt)`` and the
        image from ``run`` called with the four example values (all other
        ``run`` parameters keep their defaults).
    """
    # The table is computed first, then the image.
    table = get_token_table(prompt)
    image = run(
        prompt=prompt,
        indices_to_alter_str=indices_to_alter_str,
        seed=seed,
        apply_attend_and_excite=apply_attend_and_excite,
    )
    return table, image
with gr.Blocks(css_paths="style.css") as demo:
    gr.Markdown(DESCRIPTION)

    # First row: a column of input controls and a column with the output image.
    with gr.Row():
        with gr.Column():
            prompt = gr.Text(
                label="Prompt",
                max_lines=1,
                placeholder="A pod of dolphins leaping out of the water in an ocean with a ship on the background",
            )
            # Collapsed helper for looking up the index of each prompt token.
            with gr.Accordion(label="Check token indices", open=False):
                show_token_indices_button = gr.Button("Show token indices")
                token_indices_table = gr.Dataframe(
                    label="Token indices",
                    headers=["Index", "Token"],
                    col_count=2,
                )
            token_indices_str = gr.Text(
                label="Token indices (a comma-separated list indices of the tokens you wish to alter)",
                max_lines=1,
                placeholder="4,16",
            )
            apply_attend_and_excite = gr.Checkbox(label="Apply Attend-and-Excite", value=True)
            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            num_inference_steps = gr.Slider(
                label="Number of inference steps",
                minimum=1,
                maximum=MAX_INFERENCE_STEPS,
                step=1,
                value=50,
            )
            guidance_scale = gr.Slider(label="Guidance scale", minimum=0, maximum=50, step=0.1, value=7.5)
            run_button = gr.Button("Generate")
        with gr.Column():
            result = gr.Image(label="Result")

    # Second row: examples. Each (prompt, indices, seed) triple is listed twice,
    # first with Attend-and-Excite enabled and then with it disabled.
    with gr.Row():
        example_cases = [
            ("A mouse and a red car", "2,6", 2098),
            ("A horse and a dog", "2,5", 123),
            ("A painting of an elephant with glasses", "5,7", 123),
            ("A playful kitten chasing a butterfly in a wildflower meadow", "3,6,10", 123),
            ("A grizzly bear catching a salmon in a crystal clear river surrounded by a forest", "2,6,15", 123),
            ("A pod of dolphins leaping out of the water in an ocean with a ship on the background", "4,16", 123),
        ]
        examples = [
            [example_prompt, example_indices, example_seed, use_attend_and_excite]
            for example_prompt, example_indices, example_seed in example_cases
            for use_attend_and_excite in (True, False)
        ]
        gr.Examples(
            examples=examples,
            inputs=[prompt, token_indices_str, seed, apply_attend_and_excite],
            outputs=[token_indices_table, result],
            fn=process_example,
            examples_per_page=20,
        )

    # Standalone token lookup from the accordion button.
    show_token_indices_button.click(
        fn=get_token_table,
        inputs=prompt,
        outputs=token_indices_table,
        queue=False,
        api_name="get-token-table",
    )

    # Generation chain: update the seed, refresh the token table, then run.
    seed_step = gr.on(
        triggers=[prompt.submit, token_indices_str.submit, run_button.click],
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False,
    )
    table_step = seed_step.then(
        fn=get_token_table,
        inputs=prompt,
        outputs=token_indices_table,
        queue=False,
        api_name=False,
    )
    table_step.then(
        fn=run,
        inputs=[
            prompt,
            token_indices_str,
            seed,
            apply_attend_and_excite,
            num_inference_steps,
            guidance_scale,
        ],
        outputs=result,
        api_name="run",
    )
if __name__ == "__main__":
    # Enable the request queue with max_size=20, then start the app.
    demo.queue(max_size=20)
    demo.launch()