import argparse
import math
import os
import time
from threading import Lock
from typing import Any, List, Tuple

import gradio as gr
import numpy as np
import torch
from diffusers import StableDiffusionPipeline
from huggingface_hub import HfApi
from matplotlib import pyplot as plt
from spacy import displacy

from daam import trace
from daam.utils import set_seed, cached_nlp, auto_autocast

# Hugging Face Hub client, authenticated via the HUGGING_FACE_HUB_TOKEN environment variable.
api = HfApi(token=os.environ['HUGGING_FACE_HUB_TOKEN'])


def dependency(text):
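    """Render the spaCy dependency parse of ``text`` as an SVG snippet for the demo UI."""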
    doc = cached_nlp(text)
    svg = displacy.render(doc, style='dep', options={'compact': True, 'distance': 100})

    return svg


def get_tokenizing_mapping(prompt: str, tokenizer: Any) -> Tuple[List[List[int]], List[str]]:
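    """Group CLIP subword token indices into whole words.

    Returns one list of token indices per word (offset by one for the
    start-of-text token) along with the reconstructed words.
    """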
    tokens = tokenizer.tokenize(prompt)
    merge_idxs = []
    words = []
    curr_idxs = []
    curr_word = ''

    for i, token in enumerate(tokens):
        curr_idxs.append(i + 1)  # offset by one for the start-of-text token the tokenizer prepends
        curr_word += token
        if '</w>' in token:
            merge_idxs.append(curr_idxs)
            curr_idxs = []
            words.append(curr_word[:-4])  # drop the '</w>' end-of-word marker
            curr_word = ''

    return merge_idxs, words


def get_args():
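    """Parse command-line flags, resolving the model shorthand to a Hub model ID."""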
    model_id_map = {
        'v1': 'runwayml/stable-diffusion-v1-5',
        'v2-base': 'stabilityai/stable-diffusion-2-base',
        'v2-large': 'stabilityai/stable-diffusion-2',
        'v2-1-base': 'stabilityai/stable-diffusion-2-1-base',
        'v2-1-large': 'stabilityai/stable-diffusion-2-1',
    }

    parser = argparse.ArgumentParser()
    parser.add_argument('--model', '-m', type=str, default='v2-1-base', choices=list(model_id_map.keys()), help="which diffusion model to use")
    parser.add_argument('--seed', '-s', type=int, default=0, help="the random seed")
    parser.add_argument('--port', '-p', type=int, default=8080, help="the port to serve the demo on")
    parser.add_argument('--no-cuda', action='store_true', help="use the CPU instead of the GPU")
    args = parser.parse_args()
    args.model = model_id_map[args.model]
    return args


def main():
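    """Load the diffusion pipeline, build the Gradio interface, and serve the demo."""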
    args = get_args()
    plt.switch_backend('agg')  # headless backend: figures are returned to Gradio, not displayed

    device = "cpu" if args.no_cuda else "cuda"
    pipe = StableDiffusionPipeline.from_pretrained(args.model, use_auth_token=True).to(device)
    lock = Lock()  # serialize generation so concurrent requests don't contend for the pipeline

    @torch.no_grad()
    def update_dropdown(prompt):
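        """Refresh the adjective dropdown and the dependency-parse SVG when the prompt changes."""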
        tokens = [''] + [x.text for x in cached_nlp(prompt) if x.pos_ == 'ADJ']
        return gr.Dropdown.update(choices=tokens), dependency(prompt)

    @torch.no_grad()
    def plot(prompt, choice, replaced_word, inf_steps, is_random_seed):
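        """Generate an image and per-word DAAM heat maps for the prompt.

        If an adjective replacement is chosen, a second image is generated with
        the swapped word (reusing the saved attention heads) for comparison.
        """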
        new_prompt = prompt.replace(',', ', ').replace('.', '. ')

        if choice:
            if not replaced_word:
                replaced_word = '.'

            new_prompt = [replaced_word if tok.text == choice else tok.text for tok in cached_nlp(prompt)]
            new_prompt = ' '.join(new_prompt)

        merge_idxs, words = get_tokenizing_mapping(prompt, pipe.tokenizer)
        with auto_autocast(dtype=torch.float16), lock:
            try:
                plt.close('all')
                plt.clf()
            except Exception:
                pass

            seed = int(time.time()) if is_random_seed else args.seed
            gen = set_seed(seed)
            prompt = prompt.replace(',', ', ').replace('.', '. ')  # hacky fix to address later

            if choice:
                new_prompt = new_prompt.replace(',', ', ').replace('.', '. ')  # hacky fix to address later

                # Save the cross-attention heads from this run so the comparison run can reuse them.
                with trace(pipe, save_heads=(new_prompt != prompt)) as tc:
                    out = pipe(prompt, num_inference_steps=inf_steps, generator=gen)
                    image = np.array(out.images[0]) / 255
                    heat_map = tc.compute_global_heat_map()

                if new_prompt == prompt:
                    image2 = image
                else:
                    gen = set_seed(seed)

                    # Replay the attention heads saved from the first run so only the swapped word changes.
                    with trace(pipe, load_heads=True) as tc:
                        out2 = pipe(new_prompt, num_inference_steps=inf_steps, generator=gen)
                        image2 = np.array(out2.images[0]) / 255
            else:
                with trace(pipe, load_heads=False, save_heads=False) as tc:
                    out = pipe(prompt, num_inference_steps=inf_steps, generator=gen)
                    image = np.array(out.images[0]) / 255
                    heat_map = tc.compute_global_heat_map()

        # the main image
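        # One panel for the original prompt; two side-by-side panels when a word was swapped.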
        if new_prompt == prompt:
            fig, ax = plt.subplots()
            ax.imshow(image)
            ax.set_xticks([])
            ax.set_yticks([])
        else:
            fig, ax = plt.subplots(1, 2)
            ax[0].imshow(image)

            if choice:
                ax[1].imshow(image2)

            ax[0].set_title(choice)
            ax[0].set_xticks([])
            ax[0].set_yticks([])
            ax[1].set_title(replaced_word)
            ax[1].set_xticks([])
            ax[1].set_yticks([])

        # the heat maps
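        # One attention overlay per word, arranged in a grid num_cells wide.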
        num_cells = 4
        w = int(num_cells * 3.5)
        h = math.ceil(len(words) / num_cells * 4.5)
        fig_soft, axs_soft = plt.subplots(math.ceil(len(words) / num_cells), num_cells, figsize=(w, h))
        axs_soft = axs_soft.flatten()
        with torch.cuda.amp.autocast(dtype=torch.float32):  # render the overlays in full precision
            for idx, parsed_map in enumerate(heat_map.parsed_heat_maps()):
                word_ax_soft = axs_soft[idx]
                word_ax_soft.set_xticks([])
                word_ax_soft.set_yticks([])
                parsed_map.word_heat_map.plot_overlay(out.images[0], ax=word_ax_soft)
                word_ax_soft.set_title(parsed_map.word_heat_map.word, fontsize=12)

        # Remove unused grid cells.
        for idx in range(len(words), len(axs_soft)):
            fig_soft.delaxes(axs_soft[idx])

        return fig, fig_soft

    with gr.Blocks(css='scrollbar.css') as demo:
        md = '''# DAAM: Attention Maps for Interpreting Stable Diffusion
        Check out the paper: [What the DAAM: Interpreting Stable Diffusion Using Cross Attention](http://arxiv.org/abs/2210.04885).
        See our (much cleaner) [DAAM codebase](https://github.com/castorini/daam) on GitHub.
        '''
        gr.Markdown(md)

        with gr.Row():
            with gr.Column():
                dropdown = gr.Dropdown([
                    'An angry, bald man doing research',
                    'A bear and a moose',
                    'A blue car driving through the city',
                    'Monkey walking with hat',
                    'Doing research at Comcast Applied AI labs',
                    'Professor Jimmy Lin from the modern University of Waterloo',
                    'Yann Lecun teaching machine learning on a green chalkboard',
                    'A brown cat eating yummy cake for her birthday',
                    'A brown fox, a white dog, and a blue wolf in a green field',
                ], label='Examples', value='An angry, bald man doing research')

                text = gr.Textbox(label='Prompt', value='An angry, bald man doing research')

                with gr.Row():
                    doc = cached_nlp('An angry, bald man doing research')
                    tokens = [''] + [x.text for x in doc if x.pos_ == 'ADJ']
                    dropdown2 = gr.Dropdown(tokens, label='Adjective to replace', interactive=True)
                    text2 = gr.Textbox(label='New adjective', value='')

                checkbox = gr.Checkbox(value=False, label='Random seed')
                slider1 = gr.Slider(15, 30, value=25, interactive=True, step=1, label='Inference steps')
                submit_btn = gr.Button('Submit', elem_id='submit-btn')
                viz = gr.HTML(dependency('An angry, bald man doing research'), elem_id='viz')

            with gr.Column():
                with gr.Tab('Images'):
                    p0 = gr.Plot()

                with gr.Tab('DAAM Maps'):
                    p1 = gr.Plot()

            text.change(fn=update_dropdown, inputs=[text], outputs=[dropdown2, viz])
            
            submit_btn.click(
                fn=plot,
                inputs=[text, dropdown2, text2, slider1, checkbox],
                outputs=[p0, p1])
            dropdown.change(lambda prompt: prompt, dropdown, text)

    # Relaunch if the port is already in use; exit cleanly on Ctrl-C.
    while True:
        try:
            demo.launch(server_port=args.port)
        except OSError:
            gr.close_all()
        except KeyboardInterrupt:
            gr.close_all()
            break


if __name__ == '__main__':
    main()