radames commited on
Commit
809d7f5
Β·
1 Parent(s): 81a927a
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ __pycache__
2
+ sdxl_models/
3
+ gradio_cached_examples/
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: InstantStyle
3
  emoji: πŸ‘
4
  colorFrom: blue
5
  colorTo: purple
 
1
  ---
2
+ title: InstantStyle + SDXL Lightning
3
  emoji: πŸ‘
4
  colorFrom: blue
5
  colorTo: purple
app.py ADDED
@@ -0,0 +1,393 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import torch
3
+ import random
4
+ import numpy as np
5
+ from PIL import Image
6
+ from diffusers import (
7
+ ControlNetModel,
8
+ StableDiffusionXLControlNetPipeline,
9
+ UNet2DConditionModel,
10
+ EulerDiscreteScheduler,
11
+ )
12
+ import spaces
13
+ import gradio as gr
14
+ from huggingface_hub import hf_hub_download, snapshot_download
15
+ from ip_adapter import IPAdapterXL
16
+ from safetensors.torch import load_file
17
+
18
+ snapshot_download(
19
+ repo_id="h94/IP-Adapter", allow_patterns="sdxl_models/*", local_dir="."
20
+ )
21
+
22
+ # global variable
23
+ MAX_SEED = np.iinfo(np.int32).max
24
+ device = "cuda" if torch.cuda.is_available() else "cpu"
25
+ dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32
26
+
27
+ # initialization
28
+ base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
29
+ image_encoder_path = "sdxl_models/image_encoder"
30
+ ip_ckpt = "sdxl_models/ip-adapter_sdxl.bin"
31
+
32
+ controlnet_path = "diffusers/controlnet-canny-sdxl-1.0"
33
+ controlnet = ControlNetModel.from_pretrained(
34
+ controlnet_path, use_safetensors=False, torch_dtype=torch.float16
35
+ ).to(device)
36
+
37
+ # load SDXL lightnining
38
+
39
+ pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
40
+ base_model_path,
41
+ controlnet=controlnet,
42
+ torch_dtype=torch.float16,
43
+ variant="fp16",
44
+ add_watermarker=False,
45
+ ).to(device)
46
+ pipe.scheduler = EulerDiscreteScheduler.from_config(
47
+ pipe.scheduler.config, timestep_spacing="trailing", prediction_type="epsilon"
48
+ )
49
+ pipe.unet.load_state_dict(
50
+ load_file(
51
+ hf_hub_download(
52
+ "ByteDance/SDXL-Lightning", "sdxl_lightning_2step_unet.safetensors"
53
+ ),
54
+ device="cuda",
55
+ )
56
+ )
57
+
58
+ # load ip-adapter
59
+ # target_blocks=["block"] for original IP-Adapter
60
+ # target_blocks=["up_blocks.0.attentions.1"] for style blocks only
61
+ # target_blocks = ["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"] # for style+layout blocks
62
+ ip_model = IPAdapterXL(
63
+ pipe,
64
+ image_encoder_path,
65
+ ip_ckpt,
66
+ device,
67
+ target_blocks=["up_blocks.0.attentions.1"],
68
+ )
69
+
70
+
71
+ def resize_img(
72
+ input_image,
73
+ max_side=1280,
74
+ min_side=1024,
75
+ size=None,
76
+ pad_to_max_side=False,
77
+ mode=Image.BILINEAR,
78
+ base_pixel_number=64,
79
+ ):
80
+ w, h = input_image.size
81
+ if size is not None:
82
+ w_resize_new, h_resize_new = size
83
+ else:
84
+ ratio = min_side / min(h, w)
85
+ w, h = round(ratio * w), round(ratio * h)
86
+ ratio = max_side / max(h, w)
87
+ input_image = input_image.resize([round(ratio * w), round(ratio * h)], mode)
88
+ w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
89
+ h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
90
+ input_image = input_image.resize([w_resize_new, h_resize_new], mode)
91
+
92
+ if pad_to_max_side:
93
+ res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
94
+ offset_x = (max_side - w_resize_new) // 2
95
+ offset_y = (max_side - h_resize_new) // 2
96
+ res[offset_y : offset_y + h_resize_new, offset_x : offset_x + w_resize_new] = (
97
+ np.array(input_image)
98
+ )
99
+ input_image = Image.fromarray(res)
100
+ return input_image
101
+
102
+
103
+ examples = [
104
+ [
105
+ "./assets/0.jpg",
106
+ None,
107
+ "a cat, masterpiece, best quality, high quality",
108
+ 1.0,
109
+ 0.0,
110
+ ],
111
+ [
112
+ "./assets/1.jpg",
113
+ None,
114
+ "a cat, masterpiece, best quality, high quality",
115
+ 1.0,
116
+ 0.0,
117
+ ],
118
+ [
119
+ "./assets/2.jpg",
120
+ None,
121
+ "a cat, masterpiece, best quality, high quality",
122
+ 1.0,
123
+ 0.0,
124
+ ],
125
+ [
126
+ "./assets/3.jpg",
127
+ None,
128
+ "a cat, masterpiece, best quality, high quality",
129
+ 1.0,
130
+ 0.0,
131
+ ],
132
+ [
133
+ "./assets/2.jpg",
134
+ "./assets/yann-lecun.jpg",
135
+ "a man, masterpiece, best quality, high quality",
136
+ 1.0,
137
+ 0.6,
138
+ ],
139
+ ]
140
+
141
+
142
+ def run_for_examples(style_image, source_image, prompt, scale, control_scale):
143
+ return create_image(
144
+ image_pil=style_image,
145
+ input_image=source_image,
146
+ prompt=prompt,
147
+ n_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
148
+ scale=scale,
149
+ control_scale=control_scale,
150
+ guidance_scale=0.0,
151
+ num_inference_steps=2,
152
+ seed=42,
153
+ target="Load only style blocks",
154
+ neg_content_prompt="",
155
+ neg_content_scale=0,
156
+ )
157
+
158
+
159
+ @spaces.GPU(enable_queue=True)
160
+ def create_image(
161
+ image_pil,
162
+ input_image,
163
+ prompt,
164
+ n_prompt,
165
+ scale,
166
+ control_scale,
167
+ guidance_scale,
168
+ num_inference_steps,
169
+ seed,
170
+ target="Load only style blocks",
171
+ neg_content_prompt=None,
172
+ neg_content_scale=0,
173
+ ):
174
+ seed = random.randint(0, MAX_SEED) if seed == -1 else seed
175
+ if target == "Load original IP-Adapter":
176
+ # target_blocks=["blocks"] for original IP-Adapter
177
+ ip_model = IPAdapterXL(
178
+ pipe, image_encoder_path, ip_ckpt, device, target_blocks=["blocks"]
179
+ )
180
+ elif target == "Load only style blocks":
181
+ # target_blocks=["up_blocks.0.attentions.1"] for style blocks only
182
+ ip_model = IPAdapterXL(
183
+ pipe,
184
+ image_encoder_path,
185
+ ip_ckpt,
186
+ device,
187
+ target_blocks=["up_blocks.0.attentions.1"],
188
+ )
189
+ elif target == "Load style+layout block":
190
+ # target_blocks = ["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"] # for style+layout blocks
191
+ ip_model = IPAdapterXL(
192
+ pipe,
193
+ image_encoder_path,
194
+ ip_ckpt,
195
+ device,
196
+ target_blocks=["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"],
197
+ )
198
+
199
+ if input_image is not None:
200
+ input_image = resize_img(input_image, max_side=1024)
201
+ cv_input_image = pil_to_cv2(input_image)
202
+ detected_map = cv2.Canny(cv_input_image, 50, 200)
203
+ canny_map = Image.fromarray(cv2.cvtColor(detected_map, cv2.COLOR_BGR2RGB))
204
+ else:
205
+ canny_map = Image.new("RGB", (1024, 1024), color=(255, 255, 255))
206
+ control_scale = 0
207
+
208
+ if float(control_scale) == 0:
209
+ canny_map = canny_map.resize((1024, 1024))
210
+
211
+ if len(neg_content_prompt) > 0 and neg_content_scale != 0:
212
+ images = ip_model.generate(
213
+ pil_image=image_pil,
214
+ prompt=prompt,
215
+ negative_prompt=n_prompt,
216
+ scale=scale,
217
+ guidance_scale=guidance_scale,
218
+ num_samples=1,
219
+ num_inference_steps=num_inference_steps,
220
+ seed=seed,
221
+ image=canny_map,
222
+ controlnet_conditioning_scale=float(control_scale),
223
+ neg_content_prompt=neg_content_prompt,
224
+ neg_content_scale=neg_content_scale,
225
+ )
226
+ else:
227
+ images = ip_model.generate(
228
+ pil_image=image_pil,
229
+ prompt=prompt,
230
+ negative_prompt=n_prompt,
231
+ scale=scale,
232
+ guidance_scale=guidance_scale,
233
+ num_samples=1,
234
+ num_inference_steps=num_inference_steps,
235
+ seed=seed,
236
+ image=canny_map,
237
+ controlnet_conditioning_scale=float(control_scale),
238
+ )
239
+ return images
240
+
241
+
242
+ def pil_to_cv2(image_pil):
243
+ image_np = np.array(image_pil)
244
+ image_cv2 = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
245
+ return image_cv2
246
+
247
+
248
+ # Description
249
+ title = r"""
250
+ <h1 align="center">InstantStyle: Free Lunch towards Style-Preserving in Text-to-Image Generation</h1>
251
+ """
252
+
253
+ description = r"""
254
+ <b>Forked from <a href='https://github.com/InstantStyle/InstantStyle' target='_blank'><b>InstantStyle: Free Lunch towards Style-Preserving in Text-to-Image Generation</b></a>.<br>
255
+ <b>Model by <a href='https://huggingface.co/ByteDance/SDXL-Lightning' target='_blank'>SDXL Lightning</a> and <a href='https://huggingface.co/h94/IP-Adapter' target='_blank'>IP-Adapter</a>.</b><br>
256
+ """
257
+
258
+ article = r"""
259
+ ---
260
+ πŸ“ **Citation**
261
+ <br>
262
+ If our work is helpful for your research or applications, please cite us via:
263
+ ```bibtex
264
+ @article{wang2024instantstyle,
265
+ title={InstantStyle: Free Lunch towards Style-Preserving in Text-to-Image Generation},
266
+ author={Wang, Haofan and Wang, Qixun and Bai, Xu and Qin, Zekui and Chen, Anthony},
267
+ journal={arXiv preprint arXiv:2404.02733},
268
+ year={2024}
269
+ }
270
+ ```
271
+ πŸ“§ **Contact**
272
+ <br>
273
+ If you have any questions, please feel free to open an issue or directly reach us out at <b>[email protected]</b>.
274
+ """
275
+
276
+ block = gr.Blocks(css="footer {visibility: hidden}").queue(max_size=10, api_open=False)
277
+ with block:
278
+ # description
279
+ gr.Markdown(title)
280
+ gr.Markdown(description)
281
+
282
+ with gr.Tabs():
283
+ with gr.Row():
284
+ with gr.Column():
285
+ with gr.Row():
286
+ with gr.Column():
287
+ image_pil = gr.Image(label="Style Image", type="pil")
288
+
289
+ target = gr.Radio(
290
+ [
291
+ "Load only style blocks",
292
+ "Load style+layout block",
293
+ "Load original IP-Adapter",
294
+ ],
295
+ value="Load only style blocks",
296
+ label="Style mode",
297
+ )
298
+
299
+ prompt = gr.Textbox(
300
+ label="Prompt",
301
+ value="a cat, masterpiece, best quality, high quality",
302
+ )
303
+
304
+ scale = gr.Slider(
305
+ minimum=0, maximum=2.0, step=0.01, value=1.0, label="Scale"
306
+ )
307
+
308
+ with gr.Accordion(open=False, label="Advanced Options"):
309
+ with gr.Column():
310
+ src_image_pil = gr.Image(
311
+ label="Source Image (optional)", type="pil"
312
+ )
313
+ control_scale = gr.Slider(
314
+ minimum=0,
315
+ maximum=1.0,
316
+ step=0.01,
317
+ value=0.5,
318
+ label="Controlnet conditioning scale",
319
+ )
320
+
321
+ n_prompt = gr.Textbox(
322
+ label="Neg Prompt",
323
+ value="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
324
+ )
325
+
326
+ neg_content_prompt = gr.Textbox(
327
+ label="Neg Content Prompt", value=""
328
+ )
329
+ neg_content_scale = gr.Slider(
330
+ minimum=0,
331
+ maximum=1.0,
332
+ step=0.01,
333
+ value=0.5,
334
+ label="Neg Content Scale",
335
+ )
336
+
337
+ guidance_scale = gr.Slider(
338
+ minimum=1,
339
+ maximum=15.0,
340
+ step=0.01,
341
+ value=5.0,
342
+ label="guidance scale",
343
+ )
344
+ num_inference_steps = gr.Slider(
345
+ minimum=2,
346
+ maximum=50.0,
347
+ step=1.0,
348
+ value=2,
349
+ label="num inference steps",
350
+ )
351
+ seed = gr.Slider(
352
+ minimum=-1,
353
+ maximum=MAX_SEED,
354
+ value=-1,
355
+ step=1,
356
+ label="Seed Value",
357
+ )
358
+
359
+ generate_button = gr.Button("Generate Image")
360
+
361
+ with gr.Column():
362
+ generated_image = gr.Gallery(label="Generated Image")
363
+
364
+ generate_button.click(
365
+ fn=create_image,
366
+ inputs=[
367
+ image_pil,
368
+ src_image_pil,
369
+ prompt,
370
+ n_prompt,
371
+ scale,
372
+ control_scale,
373
+ guidance_scale,
374
+ num_inference_steps,
375
+ seed,
376
+ target,
377
+ neg_content_prompt,
378
+ neg_content_scale,
379
+ ],
380
+ outputs=[generated_image],
381
+ )
382
+
383
+ gr.Examples(
384
+ examples=examples,
385
+ inputs=[image_pil, src_image_pil, prompt, scale, control_scale],
386
+ fn=run_for_examples,
387
+ outputs=[generated_image],
388
+ cache_examples=True,
389
+ )
390
+
391
+ gr.Markdown(article)
392
+
393
+ block.launch()
assets/0.jpg ADDED
assets/1.jpg ADDED
assets/2.jpg ADDED
assets/3.jpg ADDED
assets/yann-lecun.jpg ADDED
ip_adapter/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from .ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL, IPAdapterFull
2
+
3
+ __all__ = [
4
+ "IPAdapter",
5
+ "IPAdapterPlus",
6
+ "IPAdapterPlusXL",
7
+ "IPAdapterXL",
8
+ "IPAdapterFull",
9
+ ]
ip_adapter/attention_processor.py ADDED
@@ -0,0 +1,562 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+
7
+ class AttnProcessor(nn.Module):
8
+ r"""
9
+ Default processor for performing attention-related computations.
10
+ """
11
+
12
+ def __init__(
13
+ self,
14
+ hidden_size=None,
15
+ cross_attention_dim=None,
16
+ ):
17
+ super().__init__()
18
+
19
+ def __call__(
20
+ self,
21
+ attn,
22
+ hidden_states,
23
+ encoder_hidden_states=None,
24
+ attention_mask=None,
25
+ temb=None,
26
+ ):
27
+ residual = hidden_states
28
+
29
+ if attn.spatial_norm is not None:
30
+ hidden_states = attn.spatial_norm(hidden_states, temb)
31
+
32
+ input_ndim = hidden_states.ndim
33
+
34
+ if input_ndim == 4:
35
+ batch_size, channel, height, width = hidden_states.shape
36
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
37
+
38
+ batch_size, sequence_length, _ = (
39
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
40
+ )
41
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
42
+
43
+ if attn.group_norm is not None:
44
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
45
+
46
+ query = attn.to_q(hidden_states)
47
+
48
+ if encoder_hidden_states is None:
49
+ encoder_hidden_states = hidden_states
50
+ elif attn.norm_cross:
51
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
52
+
53
+ key = attn.to_k(encoder_hidden_states)
54
+ value = attn.to_v(encoder_hidden_states)
55
+
56
+ query = attn.head_to_batch_dim(query)
57
+ key = attn.head_to_batch_dim(key)
58
+ value = attn.head_to_batch_dim(value)
59
+
60
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
61
+ hidden_states = torch.bmm(attention_probs, value)
62
+ hidden_states = attn.batch_to_head_dim(hidden_states)
63
+
64
+ # linear proj
65
+ hidden_states = attn.to_out[0](hidden_states)
66
+ # dropout
67
+ hidden_states = attn.to_out[1](hidden_states)
68
+
69
+ if input_ndim == 4:
70
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
71
+
72
+ if attn.residual_connection:
73
+ hidden_states = hidden_states + residual
74
+
75
+ hidden_states = hidden_states / attn.rescale_output_factor
76
+
77
+ return hidden_states
78
+
79
+
80
+ class IPAttnProcessor(nn.Module):
81
+ r"""
82
+ Attention processor for IP-Adapater.
83
+ Args:
84
+ hidden_size (`int`):
85
+ The hidden size of the attention layer.
86
+ cross_attention_dim (`int`):
87
+ The number of channels in the `encoder_hidden_states`.
88
+ scale (`float`, defaults to 1.0):
89
+ the weight scale of image prompt.
90
+ num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
91
+ The context length of the image features.
92
+ """
93
+
94
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False):
95
+ super().__init__()
96
+
97
+ self.hidden_size = hidden_size
98
+ self.cross_attention_dim = cross_attention_dim
99
+ self.scale = scale
100
+ self.num_tokens = num_tokens
101
+ self.skip = skip
102
+
103
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
104
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
105
+
106
+ def __call__(
107
+ self,
108
+ attn,
109
+ hidden_states,
110
+ encoder_hidden_states=None,
111
+ attention_mask=None,
112
+ temb=None,
113
+ ):
114
+ residual = hidden_states
115
+
116
+ if attn.spatial_norm is not None:
117
+ hidden_states = attn.spatial_norm(hidden_states, temb)
118
+
119
+ input_ndim = hidden_states.ndim
120
+
121
+ if input_ndim == 4:
122
+ batch_size, channel, height, width = hidden_states.shape
123
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
124
+
125
+ batch_size, sequence_length, _ = (
126
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
127
+ )
128
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
129
+
130
+ if attn.group_norm is not None:
131
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
132
+
133
+ query = attn.to_q(hidden_states)
134
+
135
+ if encoder_hidden_states is None:
136
+ encoder_hidden_states = hidden_states
137
+ else:
138
+ # get encoder_hidden_states, ip_hidden_states
139
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
140
+ encoder_hidden_states, ip_hidden_states = (
141
+ encoder_hidden_states[:, :end_pos, :],
142
+ encoder_hidden_states[:, end_pos:, :],
143
+ )
144
+ if attn.norm_cross:
145
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
146
+
147
+ key = attn.to_k(encoder_hidden_states)
148
+ value = attn.to_v(encoder_hidden_states)
149
+
150
+ query = attn.head_to_batch_dim(query)
151
+ key = attn.head_to_batch_dim(key)
152
+ value = attn.head_to_batch_dim(value)
153
+
154
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
155
+ hidden_states = torch.bmm(attention_probs, value)
156
+ hidden_states = attn.batch_to_head_dim(hidden_states)
157
+
158
+ if not self.skip:
159
+ # for ip-adapter
160
+ ip_key = self.to_k_ip(ip_hidden_states)
161
+ ip_value = self.to_v_ip(ip_hidden_states)
162
+
163
+ ip_key = attn.head_to_batch_dim(ip_key)
164
+ ip_value = attn.head_to_batch_dim(ip_value)
165
+
166
+ ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
167
+ self.attn_map = ip_attention_probs
168
+ ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
169
+ ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
170
+
171
+ hidden_states = hidden_states + self.scale * ip_hidden_states
172
+
173
+ # linear proj
174
+ hidden_states = attn.to_out[0](hidden_states)
175
+ # dropout
176
+ hidden_states = attn.to_out[1](hidden_states)
177
+
178
+ if input_ndim == 4:
179
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
180
+
181
+ if attn.residual_connection:
182
+ hidden_states = hidden_states + residual
183
+
184
+ hidden_states = hidden_states / attn.rescale_output_factor
185
+
186
+ return hidden_states
187
+
188
+
189
+ class AttnProcessor2_0(torch.nn.Module):
190
+ r"""
191
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
192
+ """
193
+
194
+ def __init__(
195
+ self,
196
+ hidden_size=None,
197
+ cross_attention_dim=None,
198
+ ):
199
+ super().__init__()
200
+ if not hasattr(F, "scaled_dot_product_attention"):
201
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
202
+
203
+ def __call__(
204
+ self,
205
+ attn,
206
+ hidden_states,
207
+ encoder_hidden_states=None,
208
+ attention_mask=None,
209
+ temb=None,
210
+ ):
211
+ residual = hidden_states
212
+
213
+ if attn.spatial_norm is not None:
214
+ hidden_states = attn.spatial_norm(hidden_states, temb)
215
+
216
+ input_ndim = hidden_states.ndim
217
+
218
+ if input_ndim == 4:
219
+ batch_size, channel, height, width = hidden_states.shape
220
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
221
+
222
+ batch_size, sequence_length, _ = (
223
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
224
+ )
225
+
226
+ if attention_mask is not None:
227
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
228
+ # scaled_dot_product_attention expects attention_mask shape to be
229
+ # (batch, heads, source_length, target_length)
230
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
231
+
232
+ if attn.group_norm is not None:
233
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
234
+
235
+ query = attn.to_q(hidden_states)
236
+
237
+ if encoder_hidden_states is None:
238
+ encoder_hidden_states = hidden_states
239
+ elif attn.norm_cross:
240
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
241
+
242
+ key = attn.to_k(encoder_hidden_states)
243
+ value = attn.to_v(encoder_hidden_states)
244
+
245
+ inner_dim = key.shape[-1]
246
+ head_dim = inner_dim // attn.heads
247
+
248
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
249
+
250
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
251
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
252
+
253
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
254
+ # TODO: add support for attn.scale when we move to Torch 2.1
255
+ hidden_states = F.scaled_dot_product_attention(
256
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
257
+ )
258
+
259
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
260
+ hidden_states = hidden_states.to(query.dtype)
261
+
262
+ # linear proj
263
+ hidden_states = attn.to_out[0](hidden_states)
264
+ # dropout
265
+ hidden_states = attn.to_out[1](hidden_states)
266
+
267
+ if input_ndim == 4:
268
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
269
+
270
+ if attn.residual_connection:
271
+ hidden_states = hidden_states + residual
272
+
273
+ hidden_states = hidden_states / attn.rescale_output_factor
274
+
275
+ return hidden_states
276
+
277
+
278
+ class IPAttnProcessor2_0(torch.nn.Module):
279
+ r"""
280
+ Attention processor for IP-Adapater for PyTorch 2.0.
281
+ Args:
282
+ hidden_size (`int`):
283
+ The hidden size of the attention layer.
284
+ cross_attention_dim (`int`):
285
+ The number of channels in the `encoder_hidden_states`.
286
+ scale (`float`, defaults to 1.0):
287
+ the weight scale of image prompt.
288
+ num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
289
+ The context length of the image features.
290
+ """
291
+
292
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False):
293
+ super().__init__()
294
+
295
+ if not hasattr(F, "scaled_dot_product_attention"):
296
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
297
+
298
+ self.hidden_size = hidden_size
299
+ self.cross_attention_dim = cross_attention_dim
300
+ self.scale = scale
301
+ self.num_tokens = num_tokens
302
+ self.skip = skip
303
+
304
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
305
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
306
+
307
+ def __call__(
308
+ self,
309
+ attn,
310
+ hidden_states,
311
+ encoder_hidden_states=None,
312
+ attention_mask=None,
313
+ temb=None,
314
+ ):
315
+ residual = hidden_states
316
+
317
+ if attn.spatial_norm is not None:
318
+ hidden_states = attn.spatial_norm(hidden_states, temb)
319
+
320
+ input_ndim = hidden_states.ndim
321
+
322
+ if input_ndim == 4:
323
+ batch_size, channel, height, width = hidden_states.shape
324
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
325
+
326
+ batch_size, sequence_length, _ = (
327
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
328
+ )
329
+
330
+ if attention_mask is not None:
331
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
332
+ # scaled_dot_product_attention expects attention_mask shape to be
333
+ # (batch, heads, source_length, target_length)
334
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
335
+
336
+ if attn.group_norm is not None:
337
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
338
+
339
+ query = attn.to_q(hidden_states)
340
+
341
+ if encoder_hidden_states is None:
342
+ encoder_hidden_states = hidden_states
343
+ else:
344
+ # get encoder_hidden_states, ip_hidden_states
345
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
346
+ encoder_hidden_states, ip_hidden_states = (
347
+ encoder_hidden_states[:, :end_pos, :],
348
+ encoder_hidden_states[:, end_pos:, :],
349
+ )
350
+ if attn.norm_cross:
351
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
352
+
353
+ key = attn.to_k(encoder_hidden_states)
354
+ value = attn.to_v(encoder_hidden_states)
355
+
356
+ inner_dim = key.shape[-1]
357
+ head_dim = inner_dim // attn.heads
358
+
359
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
360
+
361
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
362
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
363
+
364
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
365
+ # TODO: add support for attn.scale when we move to Torch 2.1
366
+ hidden_states = F.scaled_dot_product_attention(
367
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
368
+ )
369
+
370
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
371
+ hidden_states = hidden_states.to(query.dtype)
372
+
373
+ if not self.skip:
374
+ # for ip-adapter
375
+ ip_key = self.to_k_ip(ip_hidden_states)
376
+ ip_value = self.to_v_ip(ip_hidden_states)
377
+
378
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
379
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
380
+
381
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
382
+ # TODO: add support for attn.scale when we move to Torch 2.1
383
+ ip_hidden_states = F.scaled_dot_product_attention(
384
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
385
+ )
386
+ with torch.no_grad():
387
+ self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
388
+ #print(self.attn_map.shape)
389
+
390
+ ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
391
+ ip_hidden_states = ip_hidden_states.to(query.dtype)
392
+
393
+ hidden_states = hidden_states + self.scale * ip_hidden_states
394
+
395
+ # linear proj
396
+ hidden_states = attn.to_out[0](hidden_states)
397
+ # dropout
398
+ hidden_states = attn.to_out[1](hidden_states)
399
+
400
+ if input_ndim == 4:
401
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
402
+
403
+ if attn.residual_connection:
404
+ hidden_states = hidden_states + residual
405
+
406
+ hidden_states = hidden_states / attn.rescale_output_factor
407
+
408
+ return hidden_states
409
+
410
+
411
+ ## for controlnet
412
+ class CNAttnProcessor:
413
+ r"""
414
+ Default processor for performing attention-related computations.
415
+ """
416
+
417
+ def __init__(self, num_tokens=4):
418
+ self.num_tokens = num_tokens
419
+
420
+ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
421
+ residual = hidden_states
422
+
423
+ if attn.spatial_norm is not None:
424
+ hidden_states = attn.spatial_norm(hidden_states, temb)
425
+
426
+ input_ndim = hidden_states.ndim
427
+
428
+ if input_ndim == 4:
429
+ batch_size, channel, height, width = hidden_states.shape
430
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
431
+
432
+ batch_size, sequence_length, _ = (
433
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
434
+ )
435
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
436
+
437
+ if attn.group_norm is not None:
438
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
439
+
440
+ query = attn.to_q(hidden_states)
441
+
442
+ if encoder_hidden_states is None:
443
+ encoder_hidden_states = hidden_states
444
+ else:
445
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
446
+ encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
447
+ if attn.norm_cross:
448
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
449
+
450
+ key = attn.to_k(encoder_hidden_states)
451
+ value = attn.to_v(encoder_hidden_states)
452
+
453
+ query = attn.head_to_batch_dim(query)
454
+ key = attn.head_to_batch_dim(key)
455
+ value = attn.head_to_batch_dim(value)
456
+
457
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
458
+ hidden_states = torch.bmm(attention_probs, value)
459
+ hidden_states = attn.batch_to_head_dim(hidden_states)
460
+
461
+ # linear proj
462
+ hidden_states = attn.to_out[0](hidden_states)
463
+ # dropout
464
+ hidden_states = attn.to_out[1](hidden_states)
465
+
466
+ if input_ndim == 4:
467
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
468
+
469
+ if attn.residual_connection:
470
+ hidden_states = hidden_states + residual
471
+
472
+ hidden_states = hidden_states / attn.rescale_output_factor
473
+
474
+ return hidden_states
475
+
476
+
477
+ class CNAttnProcessor2_0:
478
+ r"""
479
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
480
+ """
481
+
482
+ def __init__(self, num_tokens=4):
483
+ if not hasattr(F, "scaled_dot_product_attention"):
484
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
485
+ self.num_tokens = num_tokens
486
+
487
+ def __call__(
488
+ self,
489
+ attn,
490
+ hidden_states,
491
+ encoder_hidden_states=None,
492
+ attention_mask=None,
493
+ temb=None,
494
+ ):
495
+ residual = hidden_states
496
+
497
+ if attn.spatial_norm is not None:
498
+ hidden_states = attn.spatial_norm(hidden_states, temb)
499
+
500
+ input_ndim = hidden_states.ndim
501
+
502
+ if input_ndim == 4:
503
+ batch_size, channel, height, width = hidden_states.shape
504
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
505
+
506
+ batch_size, sequence_length, _ = (
507
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
508
+ )
509
+
510
+ if attention_mask is not None:
511
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
512
+ # scaled_dot_product_attention expects attention_mask shape to be
513
+ # (batch, heads, source_length, target_length)
514
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
515
+
516
+ if attn.group_norm is not None:
517
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
518
+
519
+ query = attn.to_q(hidden_states)
520
+
521
+ if encoder_hidden_states is None:
522
+ encoder_hidden_states = hidden_states
523
+ else:
524
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
525
+ encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
526
+ if attn.norm_cross:
527
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
528
+
529
+ key = attn.to_k(encoder_hidden_states)
530
+ value = attn.to_v(encoder_hidden_states)
531
+
532
+ inner_dim = key.shape[-1]
533
+ head_dim = inner_dim // attn.heads
534
+
535
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
536
+
537
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
538
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
539
+
540
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
541
+ # TODO: add support for attn.scale when we move to Torch 2.1
542
+ hidden_states = F.scaled_dot_product_attention(
543
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
544
+ )
545
+
546
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
547
+ hidden_states = hidden_states.to(query.dtype)
548
+
549
+ # linear proj
550
+ hidden_states = attn.to_out[0](hidden_states)
551
+ # dropout
552
+ hidden_states = attn.to_out[1](hidden_states)
553
+
554
+ if input_ndim == 4:
555
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
556
+
557
+ if attn.residual_connection:
558
+ hidden_states = hidden_states + residual
559
+
560
+ hidden_states = hidden_states / attn.rescale_output_factor
561
+
562
+ return hidden_states
ip_adapter/ip_adapter.py ADDED
@@ -0,0 +1,461 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+ import torch
5
+ from diffusers import StableDiffusionPipeline
6
+ from diffusers.pipelines.controlnet import MultiControlNetModel
7
+ from PIL import Image
8
+ from safetensors import safe_open
9
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
10
+
11
+ from .utils import is_torch2_available, get_generator
12
+
13
+ if is_torch2_available():
14
+ from .attention_processor import (
15
+ AttnProcessor2_0 as AttnProcessor,
16
+ )
17
+ from .attention_processor import (
18
+ CNAttnProcessor2_0 as CNAttnProcessor,
19
+ )
20
+ from .attention_processor import (
21
+ IPAttnProcessor2_0 as IPAttnProcessor,
22
+ )
23
+ else:
24
+ from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor
25
+ from .resampler import Resampler
26
+
27
+
28
+ class ImageProjModel(torch.nn.Module):
29
+ """Projection Model"""
30
+
31
+ def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
32
+ super().__init__()
33
+
34
+ self.generator = None
35
+ self.cross_attention_dim = cross_attention_dim
36
+ self.clip_extra_context_tokens = clip_extra_context_tokens
37
+ self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
38
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
39
+
40
+ def forward(self, image_embeds):
41
+ embeds = image_embeds
42
+ clip_extra_context_tokens = self.proj(embeds).reshape(
43
+ -1, self.clip_extra_context_tokens, self.cross_attention_dim
44
+ )
45
+ clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
46
+ return clip_extra_context_tokens
47
+
48
+
49
+ class MLPProjModel(torch.nn.Module):
50
+ """SD model with image prompt"""
51
+ def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
52
+ super().__init__()
53
+
54
+ self.proj = torch.nn.Sequential(
55
+ torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
56
+ torch.nn.GELU(),
57
+ torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
58
+ torch.nn.LayerNorm(cross_attention_dim)
59
+ )
60
+
61
+ def forward(self, image_embeds):
62
+ clip_extra_context_tokens = self.proj(image_embeds)
63
+ return clip_extra_context_tokens
64
+
65
+
66
+ class IPAdapter:
67
+ def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4, target_blocks=["block"]):
68
+ self.device = device
69
+ self.image_encoder_path = image_encoder_path
70
+ self.ip_ckpt = ip_ckpt
71
+ self.num_tokens = num_tokens
72
+ self.target_blocks = target_blocks
73
+
74
+ self.pipe = sd_pipe.to(self.device)
75
+ self.set_ip_adapter()
76
+
77
+ # load image encoder
78
+ self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
79
+ self.device, dtype=torch.float16
80
+ )
81
+ self.clip_image_processor = CLIPImageProcessor()
82
+ # image proj model
83
+ self.image_proj_model = self.init_proj()
84
+
85
+ self.load_ip_adapter()
86
+
87
+
88
+ def init_proj(self):
89
+ image_proj_model = ImageProjModel(
90
+ cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
91
+ clip_embeddings_dim=self.image_encoder.config.projection_dim,
92
+ clip_extra_context_tokens=self.num_tokens,
93
+ ).to(self.device, dtype=torch.float16)
94
+ return image_proj_model
95
+
96
+ def set_ip_adapter(self):
97
+ unet = self.pipe.unet
98
+ attn_procs = {}
99
+ for name in unet.attn_processors.keys():
100
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
101
+ if name.startswith("mid_block"):
102
+ hidden_size = unet.config.block_out_channels[-1]
103
+ elif name.startswith("up_blocks"):
104
+ block_id = int(name[len("up_blocks.")])
105
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
106
+ elif name.startswith("down_blocks"):
107
+ block_id = int(name[len("down_blocks.")])
108
+ hidden_size = unet.config.block_out_channels[block_id]
109
+ if cross_attention_dim is None:
110
+ attn_procs[name] = AttnProcessor()
111
+ else:
112
+ selected = False
113
+ for block_name in self.target_blocks:
114
+ if block_name in name:
115
+ selected = True
116
+ break
117
+ if selected:
118
+ attn_procs[name] = IPAttnProcessor(
119
+ hidden_size=hidden_size,
120
+ cross_attention_dim=cross_attention_dim,
121
+ scale=1.0,
122
+ num_tokens=self.num_tokens,
123
+ ).to(self.device, dtype=torch.float16)
124
+ else:
125
+ attn_procs[name] = IPAttnProcessor(
126
+ hidden_size=hidden_size,
127
+ cross_attention_dim=cross_attention_dim,
128
+ scale=1.0,
129
+ num_tokens=self.num_tokens,
130
+ skip=True
131
+ ).to(self.device, dtype=torch.float16)
132
+ unet.set_attn_processor(attn_procs)
133
+ if hasattr(self.pipe, "controlnet"):
134
+ if isinstance(self.pipe.controlnet, MultiControlNetModel):
135
+ for controlnet in self.pipe.controlnet.nets:
136
+ controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
137
+ else:
138
+ self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
139
+
140
+ def load_ip_adapter(self):
141
+ if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
142
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
143
+ with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
144
+ for key in f.keys():
145
+ if key.startswith("image_proj."):
146
+ state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
147
+ elif key.startswith("ip_adapter."):
148
+ state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
149
+ else:
150
+ state_dict = torch.load(self.ip_ckpt, map_location="cpu")
151
+ self.image_proj_model.load_state_dict(state_dict["image_proj"])
152
+ ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
153
+ ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)
154
+
155
+ @torch.inference_mode()
156
+ def get_image_embeds(self, pil_image=None, clip_image_embeds=None, content_prompt_embeds=None):
157
+ if pil_image is not None:
158
+ if isinstance(pil_image, Image.Image):
159
+ pil_image = [pil_image]
160
+ clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
161
+ clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
162
+ else:
163
+ clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
164
+
165
+ if content_prompt_embeds is not None:
166
+ clip_image_embeds = clip_image_embeds - content_prompt_embeds
167
+
168
+ image_prompt_embeds = self.image_proj_model(clip_image_embeds)
169
+ uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
170
+ return image_prompt_embeds, uncond_image_prompt_embeds
171
+
172
+ def set_scale(self, scale):
173
+ for attn_processor in self.pipe.unet.attn_processors.values():
174
+ if isinstance(attn_processor, IPAttnProcessor):
175
+ attn_processor.scale = scale
176
+
177
+ def generate(
178
+ self,
179
+ pil_image=None,
180
+ clip_image_embeds=None,
181
+ prompt=None,
182
+ negative_prompt=None,
183
+ scale=1.0,
184
+ num_samples=4,
185
+ seed=None,
186
+ guidance_scale=7.5,
187
+ num_inference_steps=30,
188
+ neg_content_emb=None,
189
+ **kwargs,
190
+ ):
191
+ self.set_scale(scale)
192
+
193
+ if pil_image is not None:
194
+ num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
195
+ else:
196
+ num_prompts = clip_image_embeds.size(0)
197
+
198
+ if prompt is None:
199
+ prompt = "best quality, high quality"
200
+ if negative_prompt is None:
201
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
202
+
203
+ if not isinstance(prompt, List):
204
+ prompt = [prompt] * num_prompts
205
+ if not isinstance(negative_prompt, List):
206
+ negative_prompt = [negative_prompt] * num_prompts
207
+
208
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
209
+ pil_image=pil_image, clip_image_embeds=clip_image_embeds, content_prompt_embeds=neg_content_emb
210
+ )
211
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
212
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
213
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
214
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
215
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
216
+
217
+ with torch.inference_mode():
218
+ prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
219
+ prompt,
220
+ device=self.device,
221
+ num_images_per_prompt=num_samples,
222
+ do_classifier_free_guidance=True,
223
+ negative_prompt=negative_prompt,
224
+ )
225
+ prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
226
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
227
+
228
+ generator = get_generator(seed, self.device)
229
+
230
+ images = self.pipe(
231
+ prompt_embeds=prompt_embeds,
232
+ negative_prompt_embeds=negative_prompt_embeds,
233
+ guidance_scale=guidance_scale,
234
+ num_inference_steps=num_inference_steps,
235
+ generator=generator,
236
+ **kwargs,
237
+ ).images
238
+
239
+ return images
240
+
241
+
242
+ class IPAdapterXL(IPAdapter):
243
+ """SDXL"""
244
+
245
+ def generate(
246
+ self,
247
+ pil_image,
248
+ prompt=None,
249
+ negative_prompt=None,
250
+ scale=1.0,
251
+ num_samples=4,
252
+ seed=None,
253
+ num_inference_steps=30,
254
+ neg_content_emb=None,
255
+ neg_content_prompt=None,
256
+ neg_content_scale=1.0,
257
+ **kwargs,
258
+ ):
259
+ self.set_scale(scale)
260
+
261
+ num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
262
+
263
+ if prompt is None:
264
+ prompt = "best quality, high quality"
265
+ if negative_prompt is None:
266
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
267
+
268
+ if not isinstance(prompt, List):
269
+ prompt = [prompt] * num_prompts
270
+ if not isinstance(negative_prompt, List):
271
+ negative_prompt = [negative_prompt] * num_prompts
272
+
273
+ if neg_content_emb is None:
274
+ if neg_content_prompt is not None:
275
+ with torch.inference_mode():
276
+ (
277
+ prompt_embeds_, # torch.Size([1, 77, 2048])
278
+ negative_prompt_embeds_,
279
+ pooled_prompt_embeds_, # torch.Size([1, 1280])
280
+ negative_pooled_prompt_embeds_,
281
+ ) = self.pipe.encode_prompt(
282
+ neg_content_prompt,
283
+ num_images_per_prompt=num_samples,
284
+ do_classifier_free_guidance=True,
285
+ negative_prompt=negative_prompt,
286
+ )
287
+ pooled_prompt_embeds_ *= neg_content_scale
288
+ else:
289
+ pooled_prompt_embeds_ = neg_content_emb
290
+ else:
291
+ pooled_prompt_embeds_ = None
292
+
293
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image, content_prompt_embeds=pooled_prompt_embeds_)
294
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
295
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
296
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
297
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
298
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
299
+
300
+ with torch.inference_mode():
301
+ (
302
+ prompt_embeds,
303
+ negative_prompt_embeds,
304
+ pooled_prompt_embeds,
305
+ negative_pooled_prompt_embeds,
306
+ ) = self.pipe.encode_prompt(
307
+ prompt,
308
+ num_images_per_prompt=num_samples,
309
+ do_classifier_free_guidance=True,
310
+ negative_prompt=negative_prompt,
311
+ )
312
+ prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
313
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
314
+
315
+ self.generator = get_generator(seed, self.device)
316
+
317
+ images = self.pipe(
318
+ prompt_embeds=prompt_embeds,
319
+ negative_prompt_embeds=negative_prompt_embeds,
320
+ pooled_prompt_embeds=pooled_prompt_embeds,
321
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
322
+ num_inference_steps=num_inference_steps,
323
+ generator=self.generator,
324
+ **kwargs,
325
+ ).images
326
+
327
+ return images
328
+
329
+
330
+ class IPAdapterPlus(IPAdapter):
331
+ """IP-Adapter with fine-grained features"""
332
+
333
+ def init_proj(self):
334
+ image_proj_model = Resampler(
335
+ dim=self.pipe.unet.config.cross_attention_dim,
336
+ depth=4,
337
+ dim_head=64,
338
+ heads=12,
339
+ num_queries=self.num_tokens,
340
+ embedding_dim=self.image_encoder.config.hidden_size,
341
+ output_dim=self.pipe.unet.config.cross_attention_dim,
342
+ ff_mult=4,
343
+ ).to(self.device, dtype=torch.float16)
344
+ return image_proj_model
345
+
346
+ @torch.inference_mode()
347
+ def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
348
+ if isinstance(pil_image, Image.Image):
349
+ pil_image = [pil_image]
350
+ clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
351
+ clip_image = clip_image.to(self.device, dtype=torch.float16)
352
+ clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
353
+ image_prompt_embeds = self.image_proj_model(clip_image_embeds)
354
+ uncond_clip_image_embeds = self.image_encoder(
355
+ torch.zeros_like(clip_image), output_hidden_states=True
356
+ ).hidden_states[-2]
357
+ uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
358
+ return image_prompt_embeds, uncond_image_prompt_embeds
359
+
360
+
361
+ class IPAdapterFull(IPAdapterPlus):
362
+ """IP-Adapter with full features"""
363
+
364
+ def init_proj(self):
365
+ image_proj_model = MLPProjModel(
366
+ cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
367
+ clip_embeddings_dim=self.image_encoder.config.hidden_size,
368
+ ).to(self.device, dtype=torch.float16)
369
+ return image_proj_model
370
+
371
+
372
+ class IPAdapterPlusXL(IPAdapter):
373
+ """SDXL"""
374
+
375
+ def init_proj(self):
376
+ image_proj_model = Resampler(
377
+ dim=1280,
378
+ depth=4,
379
+ dim_head=64,
380
+ heads=20,
381
+ num_queries=self.num_tokens,
382
+ embedding_dim=self.image_encoder.config.hidden_size,
383
+ output_dim=self.pipe.unet.config.cross_attention_dim,
384
+ ff_mult=4,
385
+ ).to(self.device, dtype=torch.float16)
386
+ return image_proj_model
387
+
388
+ @torch.inference_mode()
389
+ def get_image_embeds(self, pil_image):
390
+ if isinstance(pil_image, Image.Image):
391
+ pil_image = [pil_image]
392
+ clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
393
+ clip_image = clip_image.to(self.device, dtype=torch.float16)
394
+ clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
395
+ image_prompt_embeds = self.image_proj_model(clip_image_embeds)
396
+ uncond_clip_image_embeds = self.image_encoder(
397
+ torch.zeros_like(clip_image), output_hidden_states=True
398
+ ).hidden_states[-2]
399
+ uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
400
+ return image_prompt_embeds, uncond_image_prompt_embeds
401
+
402
+ def generate(
403
+ self,
404
+ pil_image,
405
+ prompt=None,
406
+ negative_prompt=None,
407
+ scale=1.0,
408
+ num_samples=4,
409
+ seed=None,
410
+ num_inference_steps=30,
411
+ **kwargs,
412
+ ):
413
+ self.set_scale(scale)
414
+
415
+ num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
416
+
417
+ if prompt is None:
418
+ prompt = "best quality, high quality"
419
+ if negative_prompt is None:
420
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
421
+
422
+ if not isinstance(prompt, List):
423
+ prompt = [prompt] * num_prompts
424
+ if not isinstance(negative_prompt, List):
425
+ negative_prompt = [negative_prompt] * num_prompts
426
+
427
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
428
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
429
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
430
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
431
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
432
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
433
+
434
+ with torch.inference_mode():
435
+ (
436
+ prompt_embeds,
437
+ negative_prompt_embeds,
438
+ pooled_prompt_embeds,
439
+ negative_pooled_prompt_embeds,
440
+ ) = self.pipe.encode_prompt(
441
+ prompt,
442
+ num_images_per_prompt=num_samples,
443
+ do_classifier_free_guidance=True,
444
+ negative_prompt=negative_prompt,
445
+ )
446
+ prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
447
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
448
+
449
+ generator = get_generator(seed, self.device)
450
+
451
+ images = self.pipe(
452
+ prompt_embeds=prompt_embeds,
453
+ negative_prompt_embeds=negative_prompt_embeds,
454
+ pooled_prompt_embeds=pooled_prompt_embeds,
455
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
456
+ num_inference_steps=num_inference_steps,
457
+ generator=generator,
458
+ **kwargs,
459
+ ).images
460
+
461
+ return images
ip_adapter/resampler.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2
+ # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
3
+
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from einops import rearrange
9
+ from einops.layers.torch import Rearrange
10
+
11
+
12
+ # FFN
13
+ def FeedForward(dim, mult=4):
14
+ inner_dim = int(dim * mult)
15
+ return nn.Sequential(
16
+ nn.LayerNorm(dim),
17
+ nn.Linear(dim, inner_dim, bias=False),
18
+ nn.GELU(),
19
+ nn.Linear(inner_dim, dim, bias=False),
20
+ )
21
+
22
+
23
+ def reshape_tensor(x, heads):
24
+ bs, length, width = x.shape
25
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
26
+ x = x.view(bs, length, heads, -1)
27
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
28
+ x = x.transpose(1, 2)
29
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
30
+ x = x.reshape(bs, heads, length, -1)
31
+ return x
32
+
33
+
34
+ class PerceiverAttention(nn.Module):
35
+ def __init__(self, *, dim, dim_head=64, heads=8):
36
+ super().__init__()
37
+ self.scale = dim_head**-0.5
38
+ self.dim_head = dim_head
39
+ self.heads = heads
40
+ inner_dim = dim_head * heads
41
+
42
+ self.norm1 = nn.LayerNorm(dim)
43
+ self.norm2 = nn.LayerNorm(dim)
44
+
45
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
46
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
47
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
48
+
49
+ def forward(self, x, latents):
50
+ """
51
+ Args:
52
+ x (torch.Tensor): image features
53
+ shape (b, n1, D)
54
+ latent (torch.Tensor): latent features
55
+ shape (b, n2, D)
56
+ """
57
+ x = self.norm1(x)
58
+ latents = self.norm2(latents)
59
+
60
+ b, l, _ = latents.shape
61
+
62
+ q = self.to_q(latents)
63
+ kv_input = torch.cat((x, latents), dim=-2)
64
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
65
+
66
+ q = reshape_tensor(q, self.heads)
67
+ k = reshape_tensor(k, self.heads)
68
+ v = reshape_tensor(v, self.heads)
69
+
70
+ # attention
71
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
72
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
73
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
74
+ out = weight @ v
75
+
76
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
77
+
78
+ return self.to_out(out)
79
+
80
+
81
+ class Resampler(nn.Module):
82
+ def __init__(
83
+ self,
84
+ dim=1024,
85
+ depth=8,
86
+ dim_head=64,
87
+ heads=16,
88
+ num_queries=8,
89
+ embedding_dim=768,
90
+ output_dim=1024,
91
+ ff_mult=4,
92
+ max_seq_len: int = 257, # CLIP tokens + CLS token
93
+ apply_pos_emb: bool = False,
94
+ num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence
95
+ ):
96
+ super().__init__()
97
+ self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
98
+
99
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
100
+
101
+ self.proj_in = nn.Linear(embedding_dim, dim)
102
+
103
+ self.proj_out = nn.Linear(dim, output_dim)
104
+ self.norm_out = nn.LayerNorm(output_dim)
105
+
106
+ self.to_latents_from_mean_pooled_seq = (
107
+ nn.Sequential(
108
+ nn.LayerNorm(dim),
109
+ nn.Linear(dim, dim * num_latents_mean_pooled),
110
+ Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
111
+ )
112
+ if num_latents_mean_pooled > 0
113
+ else None
114
+ )
115
+
116
+ self.layers = nn.ModuleList([])
117
+ for _ in range(depth):
118
+ self.layers.append(
119
+ nn.ModuleList(
120
+ [
121
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
122
+ FeedForward(dim=dim, mult=ff_mult),
123
+ ]
124
+ )
125
+ )
126
+
127
+ def forward(self, x):
128
+ if self.pos_emb is not None:
129
+ n, device = x.shape[1], x.device
130
+ pos_emb = self.pos_emb(torch.arange(n, device=device))
131
+ x = x + pos_emb
132
+
133
+ latents = self.latents.repeat(x.size(0), 1, 1)
134
+
135
+ x = self.proj_in(x)
136
+
137
+ if self.to_latents_from_mean_pooled_seq:
138
+ meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
139
+ meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
140
+ latents = torch.cat((meanpooled_latents, latents), dim=-2)
141
+
142
+ for attn, ff in self.layers:
143
+ latents = attn(x, latents) + latents
144
+ latents = ff(latents) + latents
145
+
146
+ latents = self.proj_out(latents)
147
+ return self.norm_out(latents)
148
+
149
+
150
+ def masked_mean(t, *, dim, mask=None):
151
+ if mask is None:
152
+ return t.mean(dim=dim)
153
+
154
+ denom = mask.sum(dim=dim, keepdim=True)
155
+ mask = rearrange(mask, "b n -> b n 1")
156
+ masked_t = t.masked_fill(~mask, 0.0)
157
+
158
+ return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
ip_adapter/utils.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+ from PIL import Image
5
+
6
+ attn_maps = {}
7
+ def hook_fn(name):
8
+ def forward_hook(module, input, output):
9
+ if hasattr(module.processor, "attn_map"):
10
+ attn_maps[name] = module.processor.attn_map
11
+ del module.processor.attn_map
12
+
13
+ return forward_hook
14
+
15
+ def register_cross_attention_hook(unet):
16
+ for name, module in unet.named_modules():
17
+ if name.split('.')[-1].startswith('attn2'):
18
+ module.register_forward_hook(hook_fn(name))
19
+
20
+ return unet
21
+
22
+ def upscale(attn_map, target_size):
23
+ attn_map = torch.mean(attn_map, dim=0)
24
+ attn_map = attn_map.permute(1,0)
25
+ temp_size = None
26
+
27
+ for i in range(0,5):
28
+ scale = 2 ** i
29
+ if ( target_size[0] // scale ) * ( target_size[1] // scale) == attn_map.shape[1]*64:
30
+ temp_size = (target_size[0]//(scale*8), target_size[1]//(scale*8))
31
+ break
32
+
33
+ assert temp_size is not None, "temp_size cannot is None"
34
+
35
+ attn_map = attn_map.view(attn_map.shape[0], *temp_size)
36
+
37
+ attn_map = F.interpolate(
38
+ attn_map.unsqueeze(0).to(dtype=torch.float32),
39
+ size=target_size,
40
+ mode='bilinear',
41
+ align_corners=False
42
+ )[0]
43
+
44
+ attn_map = torch.softmax(attn_map, dim=0)
45
+ return attn_map
46
+ def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
47
+
48
+ idx = 0 if instance_or_negative else 1
49
+ net_attn_maps = []
50
+
51
+ for name, attn_map in attn_maps.items():
52
+ attn_map = attn_map.cpu() if detach else attn_map
53
+ attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()
54
+ attn_map = upscale(attn_map, image_size)
55
+ net_attn_maps.append(attn_map)
56
+
57
+ net_attn_maps = torch.mean(torch.stack(net_attn_maps,dim=0),dim=0)
58
+
59
+ return net_attn_maps
60
+
61
+ def attnmaps2images(net_attn_maps):
62
+
63
+ #total_attn_scores = 0
64
+ images = []
65
+
66
+ for attn_map in net_attn_maps:
67
+ attn_map = attn_map.cpu().numpy()
68
+ #total_attn_scores += attn_map.mean().item()
69
+
70
+ normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255
71
+ normalized_attn_map = normalized_attn_map.astype(np.uint8)
72
+ #print("norm: ", normalized_attn_map.shape)
73
+ image = Image.fromarray(normalized_attn_map)
74
+
75
+ #image = fix_save_attn_map(attn_map)
76
+ images.append(image)
77
+
78
+ #print(total_attn_scores)
79
+ return images
80
+ def is_torch2_available():
81
+ return hasattr(F, "scaled_dot_product_attention")
82
+
83
+ def get_generator(seed, device):
84
+
85
+ if seed is not None:
86
+ if isinstance(seed, list):
87
+ generator = [torch.Generator(device).manual_seed(seed_item) for seed_item in seed]
88
+ else:
89
+ generator = torch.Generator(device).manual_seed(seed)
90
+ else:
91
+ generator = None
92
+
93
+ return generator
requirements.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diffusers==0.27.2
2
+ torch>=2.0.0
3
+ torchvision>=0.15.1
4
+ transformers>=4.37.1
5
+ accelerate
6
+ safetensors
7
+ einops
8
+ spaces>=0.19.4
9
+ omegaconf
10
+ peft
11
+ huggingface-hub>=0.20.2
12
+ opencv-python
13
+ gradio
14
+ controlnet_aux
15
+ gdown
16
+ peft