from huggingface_hub import HfApi
import os

# Hub client authenticated with the Space's secret token (currently unused)
api = HfApi(token=os.environ['HUGGING_FACE_HUB_TOKEN'])

import math
import time
from threading import Lock
from typing import Any, List, Tuple
import argparse

import numpy as np
from diffusers import StableDiffusionPipeline
from matplotlib import pyplot as plt
import gradio as gr
import torch
from spacy import displacy

from daam import trace
from daam.utils import set_seed, cached_nlp, auto_autocast


def dependency(text):
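    """Render the spaCy dependency parse of ``text`` as an SVG string."""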
    doc = cached_nlp(text)
    svg = displacy.render(doc, style='dep', options={'compact': True, 'distance': 100})

    return svg


def get_tokenizing_mapping(prompt: str, tokenizer: Any) -> Tuple[List[List[int]], List[str]]:
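    """Group subword token indices by the word they belong to.

    Returns ``(merge_idxs, words)``, where ``merge_idxs[i]`` lists the positions
    of the subword tokens that make up ``words[i]``. For example (illustrative
    only; the exact BPE splits depend on the tokenizer), 'skateboarding' might
    tokenize as ['skate', 'boarding</w>'], yielding ``merge_idxs == [[1, 2]]``
    and ``words == ['skateboarding']``.
    """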
    tokens = tokenizer.tokenize(prompt)
    merge_idxs = []
    words = []
    curr_idxs = []
    curr_word = ''

    for i, token in enumerate(tokens):
        curr_idxs.append(i + 1)  # shift by one to account for the start-of-text token
        curr_word += token

        if '</w>' in token:
            merge_idxs.append(curr_idxs)
            curr_idxs = []
            words.append(curr_word[:-4])  # strip the trailing '</w>' marker
            curr_word = ''

    return merge_idxs, words


def get_args():
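    """Parse command-line arguments and resolve the model shorthand to a Hub ID.

    Example invocation (assuming this script is saved as ``app.py``):
    ``python app.py -m v2-base -p 7860 --no-cuda``.
    """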
    model_id_map = {
        'v1': 'runwayml/stable-diffusion-v1-5',
        'v2-base': 'stabilityai/stable-diffusion-2-base',
        'v2-large': 'stabilityai/stable-diffusion-2',
        'v2-1-base': 'stabilityai/stable-diffusion-2-1-base',
        'v2-1-large': 'stabilityai/stable-diffusion-2-1',
    }

    parser = argparse.ArgumentParser()
    parser.add_argument('--model', '-m', type=str, default='v2-1-base', choices=list(model_id_map.keys()), help='which diffusion model to use')
    parser.add_argument('--seed', '-s', type=int, default=0, help='the random seed')
    parser.add_argument('--port', '-p', type=int, default=8080, help='the port to launch the demo on')
    parser.add_argument('--no-cuda', action='store_true', help='use CPUs instead of GPUs')
    args = parser.parse_args()
    args.model = model_id_map[args.model]

    return args


def main():
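    """Load the pipeline, build the Gradio demo, and serve it until interrupted."""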
    args = get_args()
    plt.switch_backend('agg')

    device = 'cpu' if args.no_cuda else 'cuda'
    pipe = StableDiffusionPipeline.from_pretrained(args.model, use_auth_token=True).to(device)
    lock = Lock()  # serializes access to the pipeline across Gradio requests

    @torch.no_grad()
    def update_dropdown(prompt):
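        """Refresh the adjective dropdown and dependency parse when the prompt changes."""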
        tokens = [''] + [x.text for x in cached_nlp(prompt) if x.pos_ == 'ADJ']
        return gr.Dropdown.update(choices=tokens), dependency(prompt)

    @torch.no_grad()
    def plot(prompt, choice, replaced_word, inf_steps, is_random_seed):
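        """Generate the image(s) and DAAM heat maps for ``prompt``.

        If ``choice`` is set, that adjective is swapped for ``replaced_word``
        and a second image is generated for side-by-side comparison.
        """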
        new_prompt = prompt.replace(',', ', ').replace('.', '. ')

        if choice:
            if not replaced_word:
                replaced_word = '.'

            new_prompt = [replaced_word if tok.text == choice else tok.text for tok in cached_nlp(prompt)]
            new_prompt = ' '.join(new_prompt)

        merge_idxs, words = get_tokenizing_mapping(prompt, pipe.tokenizer)

        with auto_autocast(dtype=torch.float16), lock:
            try:
                plt.close('all')
                plt.clf()
            except Exception:
                pass

            seed = int(time.time()) if is_random_seed else args.seed
            gen = set_seed(seed)
            prompt = prompt.replace(',', ', ').replace('.', '. ')  # hacky fix to address later
            if choice:
                new_prompt = new_prompt.replace(',', ', ').replace('.', '. ')  # hacky fix to address later

                # save the attention heads when the prompt was perturbed, so the
                # second generation below can replay them
                with trace(pipe, save_heads=new_prompt != prompt) as tc:
                    out = pipe(prompt, num_inference_steps=inf_steps, generator=gen)
                    image = np.array(out.images[0]) / 255
                    heat_map = tc.compute_global_heat_map()

                if new_prompt == prompt:
                    image2 = image
                else:
                    gen = set_seed(seed)

                    # replay the saved heads so only the swapped word differs
                    with trace(pipe, load_heads=True) as tc:
                        out2 = pipe(new_prompt, num_inference_steps=inf_steps, generator=gen)
                        image2 = np.array(out2.images[0]) / 255
            else:
                with trace(pipe, load_heads=False, save_heads=False) as tc:
                    out = pipe(prompt, num_inference_steps=inf_steps, generator=gen)
                    image = np.array(out.images[0]) / 255
                    heat_map = tc.compute_global_heat_map()
            # the main image
            if new_prompt == prompt:
                fig, ax = plt.subplots()
                ax.imshow(image)
                ax.set_xticks([])
                ax.set_yticks([])
            else:
                fig, ax = plt.subplots(1, 2)
                ax[0].imshow(image)

                if choice:
                    ax[1].imshow(image2)

                ax[0].set_title(choice)
                ax[0].set_xticks([])
                ax[0].set_yticks([])
                ax[1].set_title(replaced_word)
                ax[1].set_xticks([])
                ax[1].set_yticks([])
            # the per-word heat maps, laid out in a grid of `num_cells` columns
            num_cells = 4
            w = int(num_cells * 3.5)
            h = math.ceil(len(words) / num_cells * 4.5)
            fig_soft, axs_soft = plt.subplots(math.ceil(len(words) / num_cells), num_cells, figsize=(w, h))
            axs_soft = axs_soft.flatten()

            with torch.cuda.amp.autocast(dtype=torch.float32):
                for idx, parsed_map in enumerate(heat_map.parsed_heat_maps()):
                    word_ax_soft = axs_soft[idx]
                    word_ax_soft.set_xticks([])
                    word_ax_soft.set_yticks([])
                    parsed_map.word_heat_map.plot_overlay(out.images[0], ax=word_ax_soft)
                    word_ax_soft.set_title(parsed_map.word_heat_map.word, fontsize=12)

            # remove the unused cells at the end of the grid
            for idx in range(len(words), len(axs_soft)):
                fig_soft.delaxes(axs_soft[idx])
        return fig, fig_soft

    with gr.Blocks(css='scrollbar.css') as demo:
        md = '''# DAAM: Attention Maps for Interpreting Stable Diffusion

Check out the paper: [What the DAAM: Interpreting Stable Diffusion Using Cross Attention](http://arxiv.org/abs/2210.04885).

See our (much cleaner) [DAAM codebase](https://github.com/castorini/daam) on GitHub.
'''
        gr.Markdown(md)

        with gr.Row():
            with gr.Column():
                dropdown = gr.Dropdown([
                    'An angry, bald man doing research',
                    'A bear and a moose',
                    'A blue car driving through the city',
                    'Monkey walking with hat',
                    'Doing research at Comcast Applied AI labs',
                    'Professor Jimmy Lin from the modern University of Waterloo',
                    'Yann Lecun teaching machine learning on a green chalkboard',
                    'A brown cat eating yummy cake for her birthday',
                    'A brown fox, a white dog, and a blue wolf in a green field',
                ], label='Examples', value='An angry, bald man doing research')

                text = gr.Textbox(label='Prompt', value='An angry, bald man doing research')

                with gr.Row():
                    doc = cached_nlp('An angry, bald man doing research')
                    tokens = [''] + [x.text for x in doc if x.pos_ == 'ADJ']
                    dropdown2 = gr.Dropdown(tokens, label='Adjective to replace', interactive=True)
                    text2 = gr.Textbox(label='New adjective', value='')

                checkbox = gr.Checkbox(value=False, label='Random seed')
                slider1 = gr.Slider(15, 30, value=25, interactive=True, step=1, label='Inference steps')
                submit_btn = gr.Button('Submit', elem_id='submit-btn')
                viz = gr.HTML(dependency('An angry, bald man doing research'), elem_id='viz')
            with gr.Column():
                with gr.Tab('Images'):
                    p0 = gr.Plot()

                with gr.Tab('DAAM Maps'):
                    p1 = gr.Plot()
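
        # wire up the events: editing the prompt refreshes the adjective
        # dropdown and parse, submitting runs generation, and picking an
        # example copies it into the prompt box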
        text.change(fn=update_dropdown, inputs=[text], outputs=[dropdown2, viz])
        submit_btn.click(
            fn=plot,
            inputs=[text, dropdown2, text2, slider1, checkbox],
            outputs=[p0, p1])
        dropdown.change(lambda prompt: prompt, dropdown, text)
        dropdown.update()

    # relaunch on transient OSErrors from the server; stop cleanly on Ctrl-C
    while True:
        try:
            demo.launch()
        except OSError:
            gr.close_all()
        except KeyboardInterrupt:
            gr.close_all()
            break


if __name__ == '__main__':
    main()