jmanhype committed on
Commit
0a72c84
·
0 Parent(s):

Initial Space setup

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. Dockerfile +1 -0
  2. README.md +496 -0
  3. musev/__init__.py +9 -0
  4. musev/auto_prompt/__init__.py +0 -0
  5. musev/auto_prompt/attributes/__init__.py +8 -0
  6. musev/auto_prompt/attributes/attr2template.py +127 -0
  7. musev/auto_prompt/attributes/attributes.py +227 -0
  8. musev/auto_prompt/attributes/human.py +424 -0
  9. musev/auto_prompt/attributes/render.py +33 -0
  10. musev/auto_prompt/attributes/style.py +12 -0
  11. musev/auto_prompt/human.py +40 -0
  12. musev/auto_prompt/load_template.py +37 -0
  13. musev/auto_prompt/util.py +25 -0
  14. musev/data/__init__.py +0 -0
  15. musev/data/data_util.py +681 -0
  16. musev/logging.conf +32 -0
  17. musev/models/__init__.py +3 -0
  18. musev/models/attention.py +431 -0
  19. musev/models/attention_processor.py +750 -0
  20. musev/models/controlnet.py +399 -0
  21. musev/models/embeddings.py +87 -0
  22. musev/models/facein_loader.py +120 -0
  23. musev/models/ip_adapter_face_loader.py +179 -0
  24. musev/models/ip_adapter_loader.py +340 -0
  25. musev/models/referencenet.py +1216 -0
  26. musev/models/referencenet_loader.py +124 -0
  27. musev/models/resnet.py +135 -0
  28. musev/models/super_model.py +253 -0
  29. musev/models/temporal_transformer.py +308 -0
  30. musev/models/text_model.py +40 -0
  31. musev/models/transformer_2d.py +445 -0
  32. musev/models/unet_2d_blocks.py +1537 -0
  33. musev/models/unet_3d_blocks.py +1413 -0
  34. musev/models/unet_3d_condition.py +1740 -0
  35. musev/models/unet_loader.py +273 -0
  36. musev/pipelines/__init__.py +0 -0
  37. musev/pipelines/context.py +149 -0
  38. musev/pipelines/pipeline_controlnet.py +0 -0
  39. musev/pipelines/pipeline_controlnet_predictor.py +1290 -0
  40. musev/schedulers/__init__.py +6 -0
  41. musev/schedulers/scheduling_ddim.py +302 -0
  42. musev/schedulers/scheduling_ddpm.py +262 -0
  43. musev/schedulers/scheduling_dpmsolver_multistep.py +815 -0
  44. musev/schedulers/scheduling_euler_ancestral_discrete.py +356 -0
  45. musev/schedulers/scheduling_euler_discrete.py +293 -0
  46. musev/schedulers/scheduling_lcm.py +312 -0
  47. musev/utils/__init__.py +0 -0
  48. musev/utils/attention_util.py +74 -0
  49. musev/utils/convert_from_ckpt.py +963 -0
  50. musev/utils/convert_lora_safetensor_to_diffusers.py +154 -0
Dockerfile ADDED
@@ -0,0 +1 @@
1
+
README.md ADDED
@@ -0,0 +1,496 @@
1
+ ---
2
+ title: MuseV
3
+ emoji: 🎨
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: docker
7
+ app_port: 7860
8
+ pinned: false
9
+ ---
10
+
11
+ # MuseV [English](README.md) [中文](README-zh.md)
12
+
13
+ <font size=5>MuseV: Infinite-length and High Fidelity Virtual Human Video Generation with Visual Conditioned Parallel Denoising
14
+ </br>
15
+ Zhiqiang Xia <sup>\*</sup>,
16
+ Zhaokang Chen<sup>\*</sup>,
17
+ Bin Wu<sup>†</sup>,
18
+ Chao Li,
19
+ Kwok-Wai Hung,
20
+ Chao Zhan,
21
+ Yingjie He,
22
+ Wenjiang Zhou
23
+ (<sup>*</sup>co-first author, <sup>†</sup>Corresponding Author, [email protected])
24
+ </font>
25
+
26
+ **[github](https://github.com/TMElyralab/MuseV)** **[huggingface](https://huggingface.co/TMElyralab/MuseV)** **[HuggingfaceSpace](https://huggingface.co/spaces/AnchorFake/MuseVDemo)** **[project](https://tmelyralab.github.io/MuseV_Page/)** **Technical report (coming soon)**
27
+
28
+
29
+ We have held **the world-simulator vision since March 2023, believing that diffusion models can simulate the world**. `MuseV` was a milestone achieved around **July 2023**. Amazed by the progress of Sora, we decided to open-source `MuseV`, hoping it will benefit the community. Next, we will move on to the promising diffusion+transformer scheme.
30
+
31
+
32
+ Update: We have released <a href="https://github.com/TMElyralab/MuseTalk" style="font-size:24px; color:red;">MuseTalk</a>, a real-time, high-quality lip-sync model, which can be combined with MuseV as a complete virtual human generation solution.
33
+
34
+ # Overview
35
+ `MuseV` is a diffusion-based virtual human video generation framework, which:
36
+ 1. supports **infinite length** generation using a novel **Visual Conditioned Parallel Denoising scheme**.
37
+ 2. provides a checkpoint for virtual human video generation, trained on a human dataset.
38
+ 3. supports Image2Video, Text2Image2Video, Video2Video.
39
+ 4. compatible with the **Stable Diffusion ecosystem**, including `base_model`, `lora`, `controlnet`, etc.
40
+ 5. supports multi-reference-image technology, including `IPAdapter`, `ReferenceOnly`, `ReferenceNet`, `IPAdapterFaceID`.
41
+ 6. training code (coming very soon).
42
+
43
+ # Important bug fixes
44
+ 1. `musev_referencenet_pose`: the `model_name` of `unet` and `ip_adapter` in the command was incorrect; please use `musev_referencenet_pose` instead of `musev_referencenet`.
45
+
46
+ # News
47
+ - [03/27/2024] Released the `MuseV` project and the trained models `musev` and `musev_referencenet`.
48
+ - [03/30/2024] Added a Hugging Face Space Gradio demo to generate videos in a GUI.
49
+
50
+ ## Model
51
+ ### Overview of model structure
52
+ ![model_structure](./data/models/musev_structure.png)
53
+ ### Parallel denoising
54
+ ![parallel_denoise](./data/models/parallel_denoise.png)
55
+
56
+ ## Cases
57
+ All frames were generated directly by the text2video model, without any post-processing.
58
+ More cases, including **1-2 minute videos**, are on the **[project](https://tmelyralab.github.io/MuseV_Page/)** page.
59
+
60
+ <!-- # TODO: // use youtu video link? -->
61
+ The examples below can be found in `configs/tasks/example.yaml`.
62
+
63
+
64
+ ### Text/Image2Video
65
+
66
+ #### Human
67
+
68
+ <table class="center">
69
+ <tr style="font-weight: bolder;text-align:center;">
70
+ <td width="50%">image</td>
71
+ <td width="45%">video </td>
72
+ <td width="5%">prompt</td>
73
+ </tr>
74
+
75
+ <tr>
76
+ <td>
77
+ <img src=./data/images/yongen.jpeg width="400">
78
+ </td>
79
+ <td >
80
+ <video src="https://github.com/TMElyralab/MuseV/assets/163980830/732cf1fd-25e7-494e-b462-969c9425d277" width="100" controls preload></video>
81
+ </td>
82
+ <td>(masterpiece, best quality, highres:1),(1boy, solo:1),(eye blinks:1.8),(head wave:1.3)
83
+ </td>
84
+ </tr>
85
+
86
+ <tr>
87
+ <td>
88
+ <img src=./data/images/seaside4.jpeg width="400">
89
+ </td>
90
+ <td>
91
+ <video src="https://github.com/TMElyralab/MuseV/assets/163980830/9b75a46c-f4e6-45ef-ad02-05729f091c8f" width="100" controls preload></video>
92
+ </td>
93
+ <td>
94
+ (masterpiece, best quality, highres:1), peaceful beautiful sea scene
95
+ </td>
96
+ </tr>
97
+ <tr>
98
+ <td>
99
+ <img src=./data/images/seaside_girl.jpeg width="400">
100
+ </td>
101
+ <td>
102
+ <video src="https://github.com/TMElyralab/MuseV/assets/163980830/d0f3b401-09bf-4018-81c3-569ec24a4de9" width="100" controls preload></video>
103
+ </td>
104
+ <td>
105
+ (masterpiece, best quality, highres:1), peaceful beautiful sea scene
106
+ </td>
107
+ </tr>
108
+ <!-- guitar -->
109
+ <tr>
110
+ <td>
111
+ <img src=./data/images/boy_play_guitar.jpeg width="400">
112
+ </td>
113
+ <td>
114
+ <video src="https://github.com/TMElyralab/MuseV/assets/163980830/61bf955e-7161-44c8-a498-8811c4f4eb4f" width="100" controls preload></video>
115
+ </td>
116
+ <td>
117
+ (masterpiece, best quality, highres:1), playing guitar
118
+ </td>
119
+ </tr>
120
+ <tr>
121
+ <td>
122
+ <img src=./data/images/girl_play_guitar2.jpeg width="400">
123
+ </td>
124
+ <td>
125
+ <video src="https://github.com/TMElyralab/MuseV/assets/163980830/40982aa7-9f6a-4e44-8ef6-3f185d284e6a" width="100" controls preload></video>
126
+ </td>
127
+ <td>
128
+ (masterpiece, best quality, highres:1), playing guitar
129
+ </td>
130
+ </tr>
131
+ <!-- famous people -->
132
+ <tr>
133
+ <td>
134
+ <img src=./data/images/dufu.jpeg width="400">
135
+ </td>
136
+ <td>
137
+ <video src="https://github.com/TMElyralab/MuseV/assets/163980830/28294baa-b996-420f-b1fb-046542adf87d" width="100" controls preload></video>
138
+ </td>
139
+ <td>
140
+ (masterpiece, best quality, highres:1),(1man, solo:1),(eye blinks:1.8),(head wave:1.3),Chinese ink painting style
141
+ </td>
142
+ </tr>
143
+
144
+ <tr>
145
+ <td>
146
+ <img src=./data/images/Mona_Lisa.jpg width="400">
147
+ </td>
148
+ <td>
149
+ <video src="https://github.com/TMElyralab/MuseV/assets/163980830/1ce11da6-14c6-4dcd-b7f9-7a5f060d71fb" width="100" controls preload></video>
150
+ </td>
151
+ <td>
152
+ (masterpiece, best quality, highres:1),(1girl, solo:1),(beautiful face,
153
+ soft skin, costume:1),(eye blinks:{eye_blinks_factor}),(head wave:1.3)
154
+ </td>
155
+ </tr>
156
+ </table >
157
+
158
+ #### Scene
159
+
160
+ <table class="center">
161
+ <tr style="font-weight: bolder;text-align:center;">
162
+ <td width="35%">image</td>
163
+ <td width="50%">video</td>
164
+ <td width="15%">prompt</td>
165
+ </tr>
166
+
167
+ <tr>
168
+ <td>
169
+ <img src=./data/images/waterfall4.jpeg width="400">
170
+ </td>
171
+ <td>
172
+ <video src="https://github.com/TMElyralab/MuseV/assets/163980830/852daeb6-6b58-4931-81f9-0dddfa1b4ea5" width="100" controls preload></video>
173
+ </td>
174
+ <td>
175
+ (masterpiece, best quality, highres:1), peaceful beautiful waterfall, an
176
+ endless waterfall
177
+ </td>
178
+ </tr>
179
+
180
+ <tr>
181
+ <td>
182
+ <img src=./data/images/seaside2.jpeg width="400">
183
+ </td>
184
+ <td>
185
+ <video src="https://github.com/TMElyralab/MuseV/assets/163980830/4a4d527a-6203-411f-afe9-31c992d26816" width="100" controls preload></video>
186
+ </td>
187
+ <td>(masterpiece, best quality, highres:1), peaceful beautiful sea scene
188
+ </td>
189
+ </tr>
190
+ </table >
191
+
192
+ ### VideoMiddle2Video
193
+
194
+ **pose2video**
195
+ In `duffy` mode, the pose of the visual condition frame is not aligned with the first frame of the control video; `posealign` solves this problem.
196
+
197
+ <table class="center">
198
+ <tr style="font-weight: bolder;text-align:center;">
199
+ <td width="25%">image</td>
200
+ <td width="65%">video</td>
201
+ <td width="10%">prompt</td>
202
+ </tr>
203
+ <tr>
204
+ <td>
205
+ <img src=./data/images/spark_girl.png width="200">
206
+ <img src=./data/images/cyber_girl.png width="200">
207
+ </td>
208
+ <td>
209
+ <video src="https://github.com/TMElyralab/MuseV/assets/163980830/484cc69d-c316-4464-a55b-3df929780a8e" width="400" controls preload></video>
210
+ </td>
211
+ <td>
212
+ (masterpiece, best quality, highres:1) , a girl is dancing, animation
213
+ </td>
214
+ </tr>
215
+ <tr>
216
+ <td>
217
+ <img src=./data/images/duffy.png width="400">
218
+ </td>
219
+ <td>
220
+ <video src="https://github.com/TMElyralab/MuseV/assets/163980830/c44682e6-aafc-4730-8fc1-72825c1bacf2" width="400" controls preload></video>
221
+ </td>
222
+ <td>
223
+ (masterpiece, best quality, highres:1), is dancing, animation
224
+ </td>
225
+ </tr>
226
+ </table >
227
+
228
+ ### MuseTalk
229
+ The talking character, `Sun Xinying`, is a supermodel KOL. You can follow her on [douyin](https://www.douyin.com/user/MS4wLjABAAAAWDThbMPN_6Xmm_JgXexbOii1K-httbu2APdG8DvDyM8).
230
+
231
+ <table class="center">
232
+ <tr style="font-weight: bolder;">
233
+ <td width="35%">name</td>
234
+ <td width="50%">video</td>
235
+ </tr>
236
+
237
+ <tr>
238
+ <td>
239
+ talk
240
+ </td>
241
+ <td>
242
+ <video src="https://github.com/TMElyralab/MuseV/assets/163980830/951188d1-4731-4e7f-bf40-03cacba17f2f" width="100" controls preload></video>
243
+ </td>
244
+ <tr>
245
+ <td>
246
+ sing
247
+ </td>
248
+ <td>
249
+ <video src="https://github.com/TMElyralab/MuseV/assets/163980830/50b8ffab-9307-4836-99e5-947e6ce7d112" width="100" controls preload></video>
250
+ </td>
251
+ </tr>
252
+ </table >
253
+
254
+
255
+ # TODO:
256
+ - [ ] technical report (coming soon).
257
+ - [ ] training codes.
258
+ - [ ] release a pretrained unet model trained with controlnet, referencenet, and IPAdapter, which performs better on pose2video.
259
+ - [ ] support diffusion transformer generation framework.
260
+ - [ ] release `posealign` module
261
+
262
+ # Quickstart
263
+ Prepare a Python environment and install extra packages such as `diffusers`, `controlnet_aux`, and `mmcm`.
264
+
265
+ ## Third party integration
266
+ Thanks to the third-party integrations, which make installation and use more convenient for everyone.
267
+ Please note that we have not verified, maintained, or updated these third-party integrations; refer to each project for its specific results.
268
+
269
+ ### [ComfyUI](https://github.com/chaojie/ComfyUI-MuseV)
270
+ ### [One click integration package in windows](https://www.bilibili.com/video/BV1ux4y1v7pF/?vd_source=fe03b064abab17b79e22a692551405c3)
271
+ netdisk:https://www.123pan.com/s/Pf5Yjv-Bb9W3.html
272
+
273
+ code: glut
274
+
275
+ ## Prepare environment
276
+ We recommend using `docker` to prepare the Python environment.
277
+ ### prepare python env
278
+ **Attention**: we have only tested with docker; there may be problems with the conda environment or the pip requirements, which we will try to fix. Please use `docker`.
279
+
280
+ #### Method 1: docker
281
+ 1. pull docker image
282
+ ```bash
283
+ docker pull anchorxia/musev:latest
284
+ ```
285
+ 2. run docker
286
+ ```bash
287
+ docker run --gpus all -it --entrypoint /bin/bash anchorxia/musev:latest
288
+ ```
289
+ The default conda env is `musev`.
290
+
291
+ #### Method 2: conda
292
+ Create a conda environment from `environment.yml`:
293
+ ```bash
294
+ conda env create --name musev --file ./environment.yml
295
+ ```
296
+ #### Method 3: pip requirements
297
+ ```bash
298
+ pip install -r requirements.txt
299
+ ```
300
+ #### Prepare mmlab package
301
+ If you do not use docker, the mmlab packages need to be installed additionally:
302
+ ```bash
303
+ pip install --no-cache-dir -U openmim
304
+ mim install mmengine
305
+ mim install "mmcv>=2.0.1"
306
+ mim install "mmdet>=3.1.0"
307
+ mim install "mmpose>=1.1.0"
308
+ ```
309
+
310
+ ### Prepare custom package / modified package
311
+ #### clone
312
+ ```bash
313
+ git clone --recursive https://github.com/TMElyralab/MuseV.git
314
+ ```
315
+ #### prepare PYTHONPATH
316
+ ```bash
317
+ current_dir=$(pwd)
318
+ export PYTHONPATH=${PYTHONPATH}:${current_dir}/MuseV
319
+ export PYTHONPATH=${PYTHONPATH}:${current_dir}/MuseV/MMCM
320
+ export PYTHONPATH=${PYTHONPATH}:${current_dir}/MuseV/diffusers/src
321
+ export PYTHONPATH=${PYTHONPATH}:${current_dir}/MuseV/controlnet_aux/src
322
+ cd MuseV
323
+ ```
324
+
325
+ 1. `MMCM`: multimedia, cross-modal processing package.
326
+ 1. `diffusers`: modified diffusers package based on [diffusers](https://github.com/huggingface/diffusers)
327
+ 1. `controlnet_aux`: modified based on [controlnet_aux](https://github.com/TMElyralab/controlnet_aux)
328
+
329
+
330
+ ## Download models
331
+ ```bash
332
+ git clone https://huggingface.co/TMElyralab/MuseV ./checkpoints
333
+ ```
334
+ - `motion`: text2video models, trained on tiny `ucf101` and tiny `webvid` datasets, approximately 60K video-text pairs. GPU memory consumption tested at `resolution`$=512*512$, `time_size=12`.
335
+ - `musev/unet`: contains and trains only the `unet` motion module. `GPU memory consumption` $\approx 8G$.
336
+ - `musev_referencenet`: trains the `unet` motion module, `referencenet`, and `IPAdapter`. `GPU memory consumption` $\approx 12G$.
337
+ - `unet`: `motion` module, whose `to_k` and `to_v` in the `Attention` layers refer to `IPAdapter`
338
+ - `referencenet`: similar to `AnimateAnyone`
339
+ - `ip_adapter_image_proj.bin`: image CLIP embedding projection layer, refer to `IPAdapter`
340
+ - `musev_referencenet_pose`: based on `musev_referencenet`, with `referencenet` and `controlnet_pose` frozen; trains the `unet` motion module and `IPAdapter`. `GPU memory consumption` $\approx 12G$
341
+ - `t2i/sd1.5`: text2image model, whose parameters are frozen when training the motion module. Different `t2i` base_models have a significant impact, and it can be replaced with other t2i bases.
342
+ - `majicmixRealv6Fp16`: example, download from [majicmixRealv6Fp16](https://civitai.com/models/43331?modelVersionId=94640)
343
+ - `fantasticmix_v10`: example, download from [fantasticmix_v10](https://civitai.com/models/22402?modelVersionId=26744)
344
+ - `IP-Adapter/models`: download from [IPAdapter](https://huggingface.co/h94/IP-Adapter/tree/main)
345
+ - `image_encoder`: vision clip model.
346
+ - `ip-adapter_sd15.bin`: original IPAdapter model checkpoint.
347
+ - `ip-adapter-faceid_sd15.bin`: original IPAdapter model checkpoint.
348
+
349
+ ## Inference
350
+
351
+ ### Prepare model_path
352
+ Skip this step when running the example tasks with the example inference commands.
353
+ Set the model path and its abbreviation in the config files so the abbreviation can be used in the inference scripts (see the hedged sketch after this list).
354
+ - T2I SD: refer to `musev/configs/model/T2I_all_model.py`
355
+ - Motion Unet: refer to `musev/configs/model/motion_model.py`
356
+ - Task: refer to `musev/configs/tasks/example.yaml`
357
+
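Below is a minimal, hypothetical sketch of what such an abbreviation-to-path mapping could look like. The variable names and structure are illustrative only and are not copied from `musev/configs/model/T2I_all_model.py`; check that file for the real format.

```python
# Hypothetical sketch only: maps a model-name abbreviation (used on the command line,
# e.g. --sd_model_name majicmixRealv6Fp16) to a checkpoint path under ./checkpoints.
# The actual keys and fields live in musev/configs/model/T2I_all_model.py.
import os

T2I_DIR = "./checkpoints/t2i/sd1.5"  # assumed layout from the "Download models" step

MODEL_CFG = {
    "majicmixRealv6Fp16": {"sd": os.path.join(T2I_DIR, "majicmixRealv6Fp16")},
    "fantasticmix_v10": {"sd": os.path.join(T2I_DIR, "fantasticmix_v10")},
}
```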
358
+ ### musev_referencenet
359
+ #### text2video
360
+ ```bash
361
+ python scripts/inference/text2video.py --sd_model_name majicmixRealv6Fp16 --unet_model_name musev_referencenet --referencenet_model_name musev_referencenet --ip_adapter_model_name musev_referencenet -test_data_path ./configs/tasks/example.yaml --output_dir ./output --n_batch 1 --target_datas yongen --vision_clip_extractor_class_name ImageClipVisionFeatureExtractor --vision_clip_model_path ./checkpoints/IP-Adapter/models/image_encoder --time_size 12 --fps 12
362
+ ```
363
+ **common parameters**:
364
+ - `test_data_path`: path of the task file in yaml format
365
+ - `target_datas`: comma-separated; subtasks are sampled if the `name` in `test_data_path` is in `target_datas`.
366
+ - `sd_model_cfg_path`: T2I sd model path, model config path, or model path.
367
+ - `sd_model_name`: sd model name, used to choose the full model path in `sd_model_cfg_path`. Multiple model names can be given, separated by `,`, or `all`.
368
+ - `unet_model_cfg_path`: motion unet model config path or model path.
369
+ - `unet_model_name`: unet model name, used to get the model path in `unet_model_cfg_path` and to init the unet class instance in `musev/models/unet_loader.py`. Multiple model names can be given, separated by `,`, or `all`. If `unet_model_cfg_path` is a model path, `unet_name` must be supported in `musev/models/unet_loader.py`.
370
+ - `time_size`: number of frames per diffusion denoise generation. default=`12`.
371
+ - `n_batch`: number of generated shots, $total\_frames=n\_batch * time\_size + n\_viscond$, default=`1`.
372
+ - `context_frames`: number of context frames. If `time_size` > `context_frames`, the `time_size` window is split into several sub-windows for parallel denoising. default=`12`.
373
+
374
+ **To generate long videos**, there are two ways:
375
+ 1. `visual conditioned parallel denoise`: set `n_batch=1`, `time_size` = all frames you want.
376
+ 1. `traditional end-to-end`: set `time_size` = `context_frames` = frames of a shot (`12`), `context_overlap` = 0;
377
+
378
+
379
+ **model parameters**:
380
+ supports `referencenet`, `IPAdapter`, `IPAdapterFaceID`, `Facein`.
381
+ - referencenet_model_name: `referencenet` model name.
382
+ - ImageClipVisionFeatureExtractor: `ImageEmbExtractor` name, which extracts the vision CLIP embedding used in `IPAdapter`.
383
+ - vision_clip_model_path: `ImageClipVisionFeatureExtractor` model path.
384
+ - ip_adapter_model_name: from `IPAdapter`, it is the `ImagePromptEmbProj`, used together with `ImageEmbExtractor`.
385
+ - ip_adapter_face_model_name: `IPAdapterFaceID`, from `IPAdapter`, to keep the face id; `face_image_path` should be set.
386
+
387
+ **Some parameters that affect the motion range and generation results**:
388
+ - `video_guidance_scale`: similar to text2image, controls the influence between cond and uncond. default=`3.5`.
389
+ - `use_condition_image`: whether to use the given first frame for video generation; if not, vision condition frames are generated first. Default=`True`.
390
+ - `redraw_condition_image`: Whether to redraw the given first frame image.
391
+ - `video_negative_prompt`: Abbreviation of full `negative_prompt` in config path. default=`V2`.
392
+
393
+
394
+ #### video2video
395
+ `t2i` base_model has a significant impact. In this case, `fantasticmix_v10` performs better than `majicmixRealv6Fp16`.
396
+ ```bash
397
+ python scripts/inference/video2video.py --sd_model_name fantasticmix_v10 --unet_model_name musev_referencenet --referencenet_model_name musev_referencenet --ip_adapter_model_name musev_referencenet -test_data_path ./configs/tasks/example.yaml --vision_clip_extractor_class_name ImageClipVisionFeatureExtractor --vision_clip_model_path ./checkpoints/IP-Adapter/models/image_encoder --output_dir ./output --n_batch 1 --controlnet_name dwpose_body_hand --which2video "video_middle" --target_datas dance1 --fps 12 --time_size 12
398
+ ```
399
+ **important parameters**
400
+
401
+ Most of the parameters are the same as in `musev_text2video`. Special parameters of `video2video` are:
402
+ 1. `video_path` needs to be set as the reference video in `test_data`. The reference video currently supports `rgb video` and `controlnet_middle_video`.
403
+ - `which2video`: whether the `rgb` video influences the initial noise; the influence of `rgb` is stronger than that of the controlnet condition.
404
+ - `controlnet_name`: whether to use a `controlnet condition`, such as `dwpose,depth`.
405
+ - `video_is_middle`: whether `video_path` is an `rgb video` or a `controlnet_middle_video`. Can be set for every `test_data` in `test_data_path`.
406
+ - `video_has_condition`: whether `condition_images` is aligned with the first frame of `video_path`. If not, the condition of `condition_images` is extracted first and then aligned by concatenation. Set in `test_data`.
407
+
408
+ All supported controlnet_names are listed in [mmcm](https://github.com/TMElyralab/MMCM/blob/main/mmcm/vision/feature_extractor/controlnet.py#L513):
409
+ ```python
410
+ ['pose', 'pose_body', 'pose_hand', 'pose_face', 'pose_hand_body', 'pose_hand_face', 'dwpose', 'dwpose_face', 'dwpose_hand', 'dwpose_body', 'dwpose_body_hand', 'canny', 'tile', 'hed', 'hed_scribble', 'depth', 'pidi', 'normal_bae', 'lineart', 'lineart_anime', 'zoe', 'sam', 'mobile_sam', 'leres', 'content', 'face_detector']
411
+ ```
412
+
413
+ ### musev_referencenet_pose
414
+ Only used for `pose2video`.
415
+ Trained based on `musev_referencenet`, with `referencenet`, `pose-controlnet`, and `T2I` frozen; the `motion` module and `IPAdapter` are trained.
416
+
417
+ `t2i` base_model has a significant impact. In this case, `fantasticmix_v10` performs better than `majicmixRealv6Fp16`.
418
+
419
+ ```bash
420
+ python scripts/inference/video2video.py --sd_model_name fantasticmix_v10 --unet_model_name musev_referencenet_pose --referencenet_model_name musev_referencenet --ip_adapter_model_name musev_referencenet_pose -test_data_path ./configs/tasks/example.yaml --vision_clip_extractor_class_name ImageClipVisionFeatureExtractor --vision_clip_model_path ./checkpoints/IP-Adapter/models/image_encoder --output_dir ./output --n_batch 1 --controlnet_name dwpose_body_hand --which2video "video_middle" --target_datas dance1 --fps 12 --time_size 12
421
+ ```
422
+
423
+ ### musev
424
+ Has only the motion module and no referencenet, so it requires less GPU memory.
425
+ #### text2video
426
+ ```bash
427
+ python scripts/inference/text2video.py --sd_model_name majicmixRealv6Fp16 --unet_model_name musev -test_data_path ./configs/tasks/example.yaml --output_dir ./output --n_batch 1 --target_datas yongen --time_size 12 --fps 12
428
+ ```
429
+ #### video2video
430
+ ##### pose align
431
+ ```bash
432
+ python ./pose_align/pose_align.py --max_frame 200 --vidfn ./data/source_video/dance.mp4 --imgfn_refer ./data/images/man.jpg --outfn_ref_img_pose ./data/pose_align_results/ref_img_pose.jpg --outfn_align_pose_video ./data/pose_align_results/align_pose_video.mp4 --outfn ./data/pose_align_results/align_demo.mp4
433
+ ```
434
+ - `max_frame`: how many frames to align (count from the first frame)
435
+ - `vidfn`: real dance video in RGB
436
+ - `imgfn_refer`: reference image path
437
+ - `outfn_ref_img_pose`: output path of the pose of the reference image
438
+ - `outfn_align_pose_video`: output path of the pose video aligned to the reference image
439
+ - `outfn`: output path of the alignment visualization
440
+
441
+
442
+
443
+ https://github.com/TMElyralab/MuseV/assets/47803475/787d7193-ec69-43f4-a0e5-73986a808f51
444
+
445
+
446
+
447
+
448
+ Then you can use the aligned pose video `outfn_align_pose_video` for pose-guided generation. You may need to modify the example in the config file `./configs/tasks/example.yaml`.
449
+ ##### generation
450
+ ```bash
451
+ python scripts/inference/video2video.py --sd_model_name fantasticmix_v10 --unet_model_name musev -test_data_path ./configs/tasks/example.yaml --output_dir ./output --n_batch 1 --controlnet_name dwpose_body_hand --which2video "video_middle" --target_datas dance1 --fps 12 --time_size 12
452
+ ```
453
+
454
+ ### Gradio demo
455
+ MuseV provides a Gradio script that launches a GUI on a local machine so videos can be generated conveniently.
456
+
457
+ ```bash
458
+ cd scripts/gradio
459
+ python app.py
460
+ ```
461
+
462
+
463
+ # Acknowledgements
464
+
465
+ 1. MuseV draws heavily on [TuneAVideo](https://github.com/showlab/Tune-A-Video), [diffusers](https://github.com/huggingface/diffusers), [Moore-AnimateAnyone](https://github.com/MooreThreads/Moore-AnimateAnyone/tree/master/src/pipelines), [animatediff](https://github.com/guoyww/AnimateDiff), [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter), [AnimateAnyone](https://arxiv.org/abs/2311.17117), [VideoFusion](https://arxiv.org/abs/2303.08320), and [insightface](https://github.com/deepinsight/insightface).
466
+ 2. MuseV has been built on `ucf101` and `webvid` datasets.
467
+
468
+ Thanks for open-sourcing!
469
+
470
+ # Limitation
471
+ There are still many limitations, including
472
+
473
+ 1. Lack of generalization ability. Some visual condition images perform well while others perform poorly, and some t2i pretrained models perform well while others perform poorly.
474
+ 1. Limited types of video generation and limited motion range, partly because of limited types of training data. The released `MuseV` has been trained on approximately 60K human text-video pairs at resolution `512*320`. `MuseV` produces a greater motion range but lower video quality at lower resolutions, and tends to generate a smaller motion range with higher video quality. Training on a larger, higher-resolution, higher-quality text-video dataset may make `MuseV` better.
475
+ 1. Watermarks may appear because of `webvid`. A cleaner dataset without watermarks may solve this issue.
476
+ 1. Limited types of long video generation. Visual Conditioned Parallel Denoising can solve the accumulated error of video generation, but the current method is only suitable for relatively fixed-camera scenes.
477
+ 1. Undertrained referencenet and IP-Adapter, because of limited time and resources.
478
+ 1. Under-structured code. `MuseV` supports rich and dynamic features, but the code is complex and unrefactored; it takes time to become familiar with it.
479
+
480
+
481
+ <!-- # Contribution: open-source collaboration is not being organized for now -->
482
+ # Citation
483
+ ```bib
484
+ @article{musev,
485
+ title={MuseV: Infinite-length and High Fidelity Virtual Human Video Generation with Visual Conditioned Parallel Denoising},
486
+ author={Xia, Zhiqiang and Chen, Zhaokang and Wu, Bin and Li, Chao and Hung, Kwok-Wai and Zhan, Chao and He, Yingjie and Zhou, Wenjiang},
487
+ journal={arxiv},
488
+ year={2024}
489
+ }
490
+ ```
491
+ # Disclaimer/License
492
+ 1. `code`: The code of MuseV is released under the MIT License. There is no limitation on either academic or commercial usage.
493
+ 1. `model`: The trained models are available for non-commercial research purposes only.
494
+ 1. `other opensource model`: Other open-source models used must comply with their licenses, such as `insightface`, `IP-Adapter`, `ft-mse-vae`, etc.
495
+ 1. The test data are collected from the internet and are available for non-commercial research purposes only.
496
+ 1. `AIGC`: This project strives to impact the domain of AI-driven video generation positively. Users are granted the freedom to create videos using this tool, but they are expected to comply with local laws and utilize it responsibly. The developers do not assume any responsibility for potential misuse by users.
musev/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ import os
2
+ import logging
3
+ import logging.config
4
+
5
+ # Read the logging configuration file
6
+ logging.config.fileConfig(os.path.join(os.path.dirname(__file__), "logging.conf"))
7
+
8
+ # Create a logger
9
+ logger = logging.getLogger("musev")
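A minimal usage sketch of the package-level logger configured above (illustrative only):

```python
# Any module can reuse the logger configured by musev/__init__.py via logging.conf.
from musev import logger

logger.info("MuseV logging is configured from musev/logging.conf")
```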
musev/auto_prompt/__init__.py ADDED
File without changes
musev/auto_prompt/attributes/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ from ...utils.register import Register
2
+
3
+ AttrRegister = Register(registry_name="attributes")
4
+
5
+ # must import like below to ensure that each class is registered with AttrRegister:
6
+ from .human import *
7
+ from .render import *
8
+ from .style import *
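For illustration, a new attribute could be registered the same way the built-in ones are in `human.py`; the `Hobby` class below is hypothetical and not part of this commit:

```python
# Hypothetical example following the @AttrRegister.register pattern used in
# musev/auto_prompt/attributes/human.py; `Hobby` is illustrative only.
# Note: the module defining the class must be imported somewhere (see the comment above)
# for the registration to take effect.
from musev.auto_prompt.attributes import AttrRegister
from musev.auto_prompt.attributes.attributes import AttriributeIsText


@AttrRegister.register
class Hobby(AttriributeIsText):
    # the attribute key expected in the input dict / the {hobby} template placeholder
    name = "hobby"
```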
musev/auto_prompt/attributes/attr2template.py ADDED
@@ -0,0 +1,127 @@
1
+ r"""
2
+ 中文
3
+ 该模块将关键词字典转化为描述文本,生成完整的提词,从而降低对比实验成本、提升控制能力和效率。
4
+ 提词(prompy)对比实验会需要控制关键属性发生变化、其他属性不变的文本对。当需要控制的属性变量发生较大变化时,靠人为复制粘贴进行完成文本撰写工作量会非常大。
5
+ 该模块主要有三种类,分别是:
6
+ 1. `BaseAttribute2Text`: 单属性文本转换类
7
+ 2. `MultiAttr2Text` 多属性文本转化类,输出`List[Tuple[str, str]`。具体如何转换为文本在 `MultiAttr2PromptTemplate`中实现。
8
+ 3. `MultiAttr2PromptTemplate`:先将2生成的多属性文本字典列表转化为完整的文本,然后再使用内置的模板`template`拼接。拼接后的文本作为实际模型输入的提词。
9
+ 1. `template`字段若没有{},且有字符,则认为输入就是完整输入网络的`prompt`;
10
+ 2. `template`字段若含有{key},则认为是带关键词的字符串目标,多个属性由`template`字符串中顺序完全决定。关键词内容由表格中相关列通过`attr2text`转化而来;
11
+ 3. `template`字段有且只含有一个{},如`a portrait of {}`,则相关内容由 `PresetMultiAttr2PromptTemplate`中预定义好的`attrs`列表指定先后顺序;
12
+
13
+ English
14
+ This module converts a keyword dictionary into descriptive text, generating complete prompts to reduce the cost of comparison experiments, and improve control and efficiency.
15
+
16
+ Prompt-based comparison experiments require text pairs where the key attributes are controlled while other attributes remain constant. When the variable attributes to be controlled undergo significant changes, manually copying and pasting to write text can be very time-consuming.
17
+
18
+ This module mainly consists of three classes:
19
+
20
+ BaseAttribute2Text: A class for converting single attribute text.
21
+ MultiAttr2Text: A class for converting multi-attribute text, outputting List[Tuple[str, str]]. The specific implementation of how to convert to text is implemented in MultiAttr2PromptTemplate.
22
+ MultiAttr2PromptTemplate: First, the list of multi-attribute text dictionaries generated by 2 is converted into complete text, and then the built-in template template is used for concatenation. The concatenated text serves as the prompt for the actual model input.
23
+ If the template field does not contain {}, and there are characters, the input is considered the complete prompt for the network.
24
+ If the template field contains {key}, it is considered a string target with keywords, and the order of multiple attributes is completely determined by the template string. The keyword content is generated by attr2text from the relevant columns in the table.
25
+ If the template field contains only one {}, such as a portrait of {}, the relevant content is specified in the order defined by the attrs list predefined in PresetMultiAttr2PromptTemplate.
26
+ """
27
+
28
+ from typing import List, Tuple, Union
29
+
30
+ from mmcm.utils.str_util import (
31
+ has_key_brace,
32
+ merge_near_same_char,
33
+ get_word_from_key_brace_string,
34
+ )
35
+
36
+ from .attributes import MultiAttr2Text, merge_multi_attrtext, AttriributeIsText
37
+ from . import AttrRegister
38
+
39
+
40
+ class MultiAttr2PromptTemplate(object):
41
+ """
42
+ 将多属性转化为模型输入文本的实际类
43
+ The actual class that converts multiple attributes into model input text is
44
+ """
45
+
46
+ def __init__(
47
+ self,
48
+ template: str,
49
+ attr2text: MultiAttr2Text,
50
+ name: str,
51
+ ) -> None:
52
+ """
53
+ Args:
54
+ template (str): 提词模板, prompt template.
55
+ 如果`template`含有{key},则根据key来取值。 if the template field contains {key}, it means that the actual value for that part of the prompt will be determined by the corresponding key
56
+ 如果`template`有且只有1个{},则根据先后顺序对texts中的值进行拼接。if the template field in MultiAttr2PromptTemplate contains only one {} placeholder, such as "a portrait of {}", the order of the attributes is determined by the attrs list predefined in PresetMultiAttr2PromptTemplate. The values of the attributes in the texts list are concatenated in the order specified by the attrs list.
57
+ attr2text (MultiAttr2Text): 多属性转换类。Class for converting multiple attributes into text prompt.
58
+ name (str): 该多属性文本模板类的名字,便于记忆. Class Instance name
59
+ """
60
+ self.attr2text = attr2text
61
+ self.name = name
62
+ if template == "":
63
+ template = "{}"
64
+ self.template = template
65
+ self.template_has_key_brace = has_key_brace(template)
66
+
67
+ def __call__(self, attributes: dict) -> Union[str, List[str]]:
68
+ texts = self.attr2text(attributes)
69
+ if not isinstance(texts, list):
70
+ texts = [texts]
71
+ prompts = [merge_multi_attrtext(text, self.template) for text in texts]
72
+ prompts = [merge_near_same_char(prompt) for prompt in prompts]
73
+ if len(prompts) == 1:
74
+ prompts = prompts[0]
75
+ return prompts
76
+
77
+
78
+ class KeywordMultiAttr2PromptTemplate(MultiAttr2PromptTemplate):
79
+ def __init__(self, template: str, name: str = "keywords") -> None:
80
+ """关键词模板属性2文本转化类
81
+ 1. 获取关键词模板字符串中的关键词属性;
82
+ 2. 从import * 存储在locals()中变量中获取对应的类;
83
+ 3. 将集成了多属性转换类的`MultiAttr2Text`
84
+ Args:
85
+ template (str): 含有{key}的模板字符串
86
+ name (str, optional): 该模板字符串名字,暂无实际用处. Defaults to "keywords".
87
+
88
+ class for converting keyword template attributes to text
89
+ 1. Get the keyword attributes in the keyword template string;
90
+ 2. Get the corresponding class from the variables stored in locals() by import *;
91
+ 3. The `MultiAttr2Text` integrated with multiple attribute conversion classes
92
+ Args:
93
+ template (str): template string containing {key}
94
+ name (str, optional): the name of the template string, no actual use. Defaults to "keywords".
95
+ """
96
+ assert has_key_brace(
97
+ template
98
+ ), "template should have key brace, but given {}".format(template)
99
+ keywords = get_word_from_key_brace_string(template)
100
+ funcs = []
101
+ for word in keywords:
102
+ if word in AttrRegister:
103
+ func = AttrRegister[word](name=word)
104
+ else:
105
+ func = AttriributeIsText(name=word)
106
+ funcs.append(func)
107
+ attr2text = MultiAttr2Text(funcs, name=name)
108
+ super().__init__(template, attr2text, name)
109
+
110
+
111
+ class OnlySpacePromptTemplate(MultiAttr2PromptTemplate):
112
+ def __init__(self, template: str, name: str = "space_prompt") -> None:
113
+ """纯空模板,无论输入啥,都只返回空格字符串作为prompt。
114
+ Args:
115
+ template (str): 符合只输出空格字符串的模板,
116
+ name (str, optional): 该模板字符串名字,暂无实际用处. Defaults to "space_prompt".
117
+
118
+ Pure empty template, no matter what the input is, it will only return a space string as the prompt.
119
+ Args:
120
+ template (str): template that only outputs a space string,
121
+ name (str, optional): the name of the template string, no actual use. Defaults to "space_prompt".
122
+ """
123
+ attr2text = None
124
+ super().__init__(template, attr2text, name)
125
+
126
+ def __call__(self, attributes: dict) -> Union[str, List[str]]:
127
+ return ""
musev/auto_prompt/attributes/attributes.py ADDED
@@ -0,0 +1,227 @@
1
+ from copy import deepcopy
2
+ from typing import List, Tuple, Dict
3
+
4
+ from mmcm.utils.str_util import has_key_brace
5
+
6
+
7
+ class BaseAttribute2Text(object):
8
+ """
9
+ 属性转化为文本的基类,该类作用就是输入属性,转化为描述文本。
10
+ Base class for converting attributes to text which converts attributes to prompt text.
11
+ """
12
+
13
+ name = "base_attribute"
14
+
15
+ def __init__(self, name: str = None) -> None:
16
+ """这里类实例初始化设置`name`参数,主要是为了便于一些没有提前实现、通过字符串参数实现的新属性。
17
+ Theses class instances are initialized with the `name` parameter to facilitate the implementation of new attributes that are not implemented in advance and are implemented through string parameters.
18
+
19
+ Args:
20
+ name (str, optional): _description_. Defaults to None.
21
+ """
22
+ if name is not None:
23
+ self.name = name
24
+
25
+ def __call__(self, attributes) -> str:
26
+ raise NotImplementedError
27
+
28
+
29
+ class AttributeIsTextAndName(BaseAttribute2Text):
30
+ """
31
+ 属性文本转换功能类,将key和value拼接在一起作为文本.
32
+ class for converting attributes to text which concatenates the key and value together as text.
33
+ """
34
+
35
+ name = "attribute_is_text_name"
36
+
37
+ def __call__(self, attributes) -> str:
38
+ if attributes == "" or attributes is None:
39
+ return ""
40
+ attributes = attributes.split(",")
41
+ text = ", ".join(
42
+ [
43
+ "{} {}".format(attr, self.name) if attr != "" else ""
44
+ for attr in attributes
45
+ ]
46
+ )
47
+ return text
48
+
49
+
50
+ class AttriributeIsText(BaseAttribute2Text):
51
+ """
52
+ 属性文本转换功能类,将value作为文本.
53
+ class for converting attributes to text which only uses the value as text.
54
+ """
55
+
56
+ name = "attribute_is_text"
57
+
58
+ def __call__(self, attributes: str) -> str:
59
+ if attributes == "" or attributes is None:
60
+ return ""
61
+ attributes = str(attributes)
62
+ attributes = attributes.split(",")
63
+ text = ", ".join(["{}".format(attr) for attr in attributes])
64
+ return text
65
+
66
+
67
+ class MultiAttr2Text(object):
68
+ """将多属性组成的字典转换成完整的文本描述,目前采用简单的前后拼接方式,以`, `作为拼接符号
69
+ class for converting a dictionary of multiple attributes into a complete text description. Currently, a simple front and back splicing method is used, with `, ` as the splicing symbol.
70
+
71
+ Args:
72
+ object (_type_): _description_
73
+ """
74
+
75
+ def __init__(self, funcs: list, name) -> None:
76
+ """
77
+ Args:
78
+ funcs (list): 继承`BaseAttribute2Text`并实现了`__call__`函数的类. Inherited `BaseAttribute2Text` and implemented the `__call__` function of the class.
79
+ name (_type_): 该多属性的一个名字,可通过该类方便了解对应相关属性都是关于啥的。 name of the multi-attribute, which can be used to easily understand what the corresponding related attributes are about.
80
+ """
81
+ if not isinstance(funcs, list):
82
+ funcs = [funcs]
83
+ self.funcs = funcs
84
+ self.name = name
85
+
86
+ def __call__(
87
+ self, dct: dict, ignored_blank_str: bool = False
88
+ ) -> List[Tuple[str, str]]:
89
+ """
90
+ 有时候一个属性可能会返回多个文本,如 style cartoon会返回宫崎骏和皮克斯两种风格,采用外积增殖成多个字典。
91
+ sometimes an attribute may return multiple texts, such as style cartoon will return two styles, Miyazaki and Pixar, which are multiplied into multiple dictionaries by the outer product.
92
+ Args:
93
+ dct (dict): 多属性组成的字典,可能有self.funcs关注的属性也可能没有,self.funcs按照各自的名字按需提取关注的属性和值,并转化成文本.
94
+ Dict of multiple attributes, may or may not have the attributes that self.funcs is concerned with. self.funcs extracts the attributes and values of interest according to their respective names and converts them into text.
95
+ ignored_blank_str (bool): 如果某个attr2text返回的是空字符串,是否要过滤掉该属性。默认`False`.
96
+ If the text returned by an attr2text is an empty string, whether to filter out the attribute. Defaults to `False`.
97
+ Returns:
98
+ Union[List[List[Tuple[str, str]]], List[Tuple[str, str]]: 多组多属性文本字典列表. Multiple sets of multi-attribute text dictionaries.
99
+ """
100
+ attrs_lst = [[]]
101
+ for func in self.funcs:
102
+ if func.name in dct:
103
+ attrs = func(dct[func.name])
104
+ if isinstance(attrs, str):
105
+ for i in range(len(attrs_lst)):
106
+ attrs_lst[i].append((func.name, attrs))
107
+ else:
108
+ # an attribute may return multiple texts
109
+ n_attrs = len(attrs)
110
+ new_attrs_lst = []
111
+ for n in range(n_attrs):
112
+ attrs_lst_cp = deepcopy(attrs_lst)
113
+ for i in range(len(attrs_lst_cp)):
114
+ attrs_lst_cp[i].append((func.name, attrs[n]))
115
+ new_attrs_lst.extend(attrs_lst_cp)
116
+ attrs_lst = new_attrs_lst
117
+
118
+ texts = [
119
+ [
120
+ (attr, text)
121
+ for (attr, text) in attrs
122
+ if not (text == "" and ignored_blank_str)
123
+ ]
124
+ for attrs in attrs_lst
125
+ ]
126
+ return texts
127
+
128
+
129
+ def format_tuple_texts(template: str, texts: Tuple[str, str]) -> str:
130
+ """使用含有"{}" 的模板对多属性文本元组进行拼接,形成新文本
131
+ concatenate multiple attribute text tuples using a template containing "{}" to form a new text
132
+ Args:
133
+ template (str):
134
+ texts (Tuple[str, str]): 多属性文本元组. multiple attribute text tuples
135
+
136
+ Returns:
137
+ str: 拼接后的新文本, merged new text
138
+ """
139
+ merged_text = ", ".join([text[1] for text in texts if text[1] != ""])
140
+ merged_text = template.format(merged_text)
141
+ return merged_text
142
+
143
+
144
+ def format_dct_texts(template: str, texts: Dict[str, str]) -> str:
145
+ """使用含有"{key}" 的模板对多属性文本字典进行拼接,形成新文本
146
+ concatenate multiple attribute text dictionaries using a template containing "{key}" to form a new text
147
+ Args:
148
+ template (str):
149
+ texts (Tuple[str, str]): 多属性文本字典. multiple attribute text dictionaries
150
+
151
+ Returns:
152
+ str: 拼接后的新文本, merged new text
153
+ """
154
+ merged_text = template.format(**texts)
155
+ return merged_text
156
+
157
+
158
+ def merge_multi_attrtext(texts: List[Tuple[str, str]], template: str = None) -> str:
159
+ """对多属性文本元组进行拼接,形成新文本。
160
+ 如果`template`含有{key},则根据key来取值;
161
+ 如果`template`有且只有1个{},则根据先后顺序对texts中的值进行拼接。
162
+
163
+ concatenate multiple attribute text tuples to form a new text.
164
+ if `template` contains {key}, the value is taken according to the key;
165
+ if `template` contains only one {}, the values in texts are concatenated in order.
166
+ Args:
167
+ texts (List[Tuple[str, str]]): Tuple[str, str]第一个str是属性名,第二个str是属性转化的文本.
168
+ Tuple[str, str] The first str is the attribute name, and the second str is the text of the attribute conversion.
169
+ template (str, optional): template . Defaults to None.
170
+
171
+ Returns:
172
+ str: 拼接后的新文本, merged new text
173
+ """
174
+ if not isinstance(texts, List):
175
+ texts = [texts]
176
+ if template is None or template == "":
177
+ template = "{}"
178
+ if has_key_brace(template):
179
+ texts = {k: v for k, v in texts}
180
+ merged_text = format_dct_texts(template, texts)
181
+ else:
182
+ merged_text = format_tuple_texts(template, texts)
183
+ return merged_text
184
+
185
+
186
+ class PresetMultiAttr2Text(MultiAttr2Text):
187
+ """预置了多种关注属性转换的类,方便维护
188
+ class for multiple attribute conversion with multiple attention attributes preset for easy maintenance
189
+
190
+ """
191
+
192
+ preset_attributes = []
193
+
194
+ def __init__(
195
+ self, funcs: List = None, use_preset: bool = True, name: str = "preset"
196
+ ) -> None:
197
+ """虽然预置了关注的属性列表和转换类,但也允许定义示例时,进行更新。
198
+ 注意`self.preset_attributes`的元素只是类名字,以便减少实例化的资源消耗。而funcs是实例化后的属性转换列表。
199
+
200
+ Although the list of attention attributes and conversion classes is preset, it is also allowed to be updated when defining an instance.
201
+ Note that the elements of `self.preset_attributes` are only class names, in order to reduce the resource consumption of instantiation. And funcs is a list of instantiated attribute conversions.
202
+
203
+ Args:
204
+ funcs (List, optional): list of funcs . Defaults to None.
205
+ use_preset (bool, optional): _description_. Defaults to True.
206
+ name (str, optional): _description_. Defaults to "preset".
207
+ """
208
+ if use_preset:
209
+ preset_funcs = self.preset()
210
+ else:
211
+ preset_funcs = []
212
+ if funcs is None:
213
+ funcs = []
214
+ if not isinstance(funcs, list):
215
+ funcs = [funcs]
216
+ funcs_names = [func.name for func in funcs]
217
+ preset_funcs = [
218
+ preset_func
219
+ for preset_func in preset_funcs
220
+ if preset_func.name not in funcs_names
221
+ ]
222
+ funcs = funcs + preset_funcs
223
+ super().__init__(funcs, name)
224
+
225
+ def preset(self):
226
+ funcs = [cls() for cls in self.preset_attributes]
227
+ return funcs
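A worked example of `merge_multi_attrtext` covering the two template styles documented above, assuming the `mmcm` string helpers behave as their names suggest:

```python
# merge_multi_attrtext joins (attribute, text) pairs either positionally ("{}")
# or by attribute name ("{key}").
from musev.auto_prompt.attributes.attributes import merge_multi_attrtext

pairs = [("hair", "black hair"), ("eyes", "big eyes")]

print(merge_multi_attrtext(pairs, "a portrait of {}"))
# -> "a portrait of black hair, big eyes"

print(merge_multi_attrtext(pairs, "a girl with {hair} and {eyes}"))
# -> "a girl with black hair and big eyes"
```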
musev/auto_prompt/attributes/human.py ADDED
@@ -0,0 +1,424 @@
1
+ from copy import deepcopy
2
+ import numpy as np
3
+ import random
4
+ import json
5
+
6
+ from .attributes import (
7
+ MultiAttr2Text,
8
+ AttriributeIsText,
9
+ AttributeIsTextAndName,
10
+ PresetMultiAttr2Text,
11
+ )
12
+ from .style import Style
13
+ from .render import Render
14
+ from . import AttrRegister
15
+
16
+
17
+ __all__ = [
18
+ "Age",
19
+ "Sex",
20
+ "Singing",
21
+ "Country",
22
+ "Lighting",
23
+ "Headwear",
24
+ "Eyes",
25
+ "Irises",
26
+ "Hair",
27
+ "Skin",
28
+ "Face",
29
+ "Smile",
30
+ "Expression",
31
+ "Clothes",
32
+ "Nose",
33
+ "Mouth",
34
+ "Beard",
35
+ "Necklace",
36
+ "KeyWords",
37
+ "InsightFace",
38
+ "Caption",
39
+ "Env",
40
+ "Decoration",
41
+ "Festival",
42
+ "SpringHeadwear",
43
+ "SpringClothes",
44
+ "Animal",
45
+ ]
46
+
47
+
48
+ @AttrRegister.register
49
+ class Sex(AttriributeIsText):
50
+ name = "sex"
51
+
52
+ def __init__(self, name: str = None) -> None:
53
+ super().__init__(name)
54
+
55
+
56
+ @AttrRegister.register
57
+ class Headwear(AttriributeIsText):
58
+ name = "headwear"
59
+
60
+ def __init__(self, name: str = None) -> None:
61
+ super().__init__(name)
62
+
63
+
64
+ @AttrRegister.register
65
+ class Expression(AttriributeIsText):
66
+ name = "expression"
67
+
68
+ def __init__(self, name: str = None) -> None:
69
+ super().__init__(name)
70
+
71
+
72
+ @AttrRegister.register
73
+ class KeyWords(AttriributeIsText):
74
+ name = "keywords"
75
+
76
+ def __init__(self, name: str = None) -> None:
77
+ super().__init__(name)
78
+
79
+
80
+ @AttrRegister.register
81
+ class Singing(AttriributeIsText):
82
+ def __init__(self, name: str = "singing") -> None:
83
+ super().__init__(name)
84
+
85
+
86
+ @AttrRegister.register
87
+ class Country(AttriributeIsText):
88
+ name = "country"
89
+
90
+ def __init__(self, name: str = None) -> None:
91
+ super().__init__(name)
92
+
93
+
94
+ @AttrRegister.register
95
+ class Clothes(AttriributeIsText):
96
+ name = "clothes"
97
+
98
+ def __init__(self, name: str = None) -> None:
99
+ super().__init__(name)
100
+
101
+
102
+ @AttrRegister.register
103
+ class Age(AttributeIsTextAndName):
104
+ name = "age"
105
+
106
+ def __init__(self, name: str = None) -> None:
107
+ super().__init__(name)
108
+
109
+ def __call__(self, attributes: str) -> str:
110
+ if not isinstance(attributes, str):
111
+ attributes = str(attributes)
112
+ attributes = attributes.split(",")
113
+ text = ", ".join(
114
+ ["{}-year-old".format(attr) if attr != "" else "" for attr in attributes]
115
+ )
116
+ return text
117
+
118
+
119
+ @AttrRegister.register
120
+ class Eyes(AttributeIsTextAndName):
121
+ name = "eyes"
122
+
123
+ def __init__(self, name: str = None) -> None:
124
+ super().__init__(name)
125
+
126
+
127
+ @AttrRegister.register
128
+ class Hair(AttributeIsTextAndName):
129
+ name = "hair"
130
+
131
+ def __init__(self, name: str = None) -> None:
132
+ super().__init__(name)
133
+
134
+
135
+ @AttrRegister.register
136
+ class Background(AttributeIsTextAndName):
137
+ name = "background"
138
+
139
+ def __init__(self, name: str = None) -> None:
140
+ super().__init__(name)
141
+
142
+
143
+ @AttrRegister.register
144
+ class Skin(AttributeIsTextAndName):
145
+ name = "skin"
146
+
147
+ def __init__(self, name: str = None) -> None:
148
+ super().__init__(name)
149
+
150
+
151
+ @AttrRegister.register
152
+ class Face(AttributeIsTextAndName):
153
+ name = "face"
154
+
155
+ def __init__(self, name: str = None) -> None:
156
+ super().__init__(name)
157
+
158
+
159
+ @AttrRegister.register
160
+ class Smile(AttributeIsTextAndName):
161
+ name = "smile"
162
+
163
+ def __init__(self, name: str = None) -> None:
164
+ super().__init__(name)
165
+
166
+
167
+ @AttrRegister.register
168
+ class Nose(AttributeIsTextAndName):
169
+ name = "nose"
170
+
171
+ def __init__(self, name: str = None) -> None:
172
+ super().__init__(name)
173
+
174
+
175
+ @AttrRegister.register
176
+ class Mouth(AttributeIsTextAndName):
177
+ name = "mouth"
178
+
179
+ def __init__(self, name: str = None) -> None:
180
+ super().__init__(name)
181
+
182
+
183
+ @AttrRegister.register
184
+ class Beard(AttriributeIsText):
185
+ name = "beard"
186
+
187
+ def __init__(self, name: str = None) -> None:
188
+ super().__init__(name)
189
+
190
+
191
+ @AttrRegister.register
192
+ class Necklace(AttributeIsTextAndName):
193
+ name = "necklace"
194
+
195
+ def __init__(self, name: str = None) -> None:
196
+ super().__init__(name)
197
+
198
+
199
+ @AttrRegister.register
200
+ class Irises(AttributeIsTextAndName):
201
+ name = "irises"
202
+
203
+ def __init__(self, name: str = None) -> None:
204
+ super().__init__(name)
205
+
206
+
207
+ @AttrRegister.register
208
+ class Lighting(AttributeIsTextAndName):
209
+ name = "lighting"
210
+
211
+ def __init__(self, name: str = None) -> None:
212
+ super().__init__(name)
213
+
214
+
215
+ PresetPortraitAttributes = [
216
+ Age,
217
+ Sex,
218
+ Singing,
219
+ Country,
220
+ Lighting,
221
+ Headwear,
222
+ Eyes,
223
+ Irises,
224
+ Hair,
225
+ Skin,
226
+ Face,
227
+ Smile,
228
+ Expression,
229
+ Clothes,
230
+ Nose,
231
+ Mouth,
232
+ Beard,
233
+ Necklace,
234
+ Style,
235
+ KeyWords,
236
+ Render,
237
+ ]
238
+
239
+
240
+ class PortraitMultiAttr2Text(PresetMultiAttr2Text):
241
+ preset_attributes = PresetPortraitAttributes
242
+
243
+ def __init__(self, funcs: list = None, use_preset=True, name="portrait") -> None:
244
+ super().__init__(funcs, use_preset, name)
245
+
246
+
247
+ @AttrRegister.register
248
+ class InsightFace(AttriributeIsText):
249
+ name = "insight_face"
250
+ face_render_dict = {
251
+ "boy": "handsome,elegant",
252
+ "girl": "gorgeous,kawaii,colorful",
253
+ }
254
+ key_words = "delicate face,beautiful eyes"
255
+
256
+ def __call__(self, attributes: str) -> str:
257
+ """将insight faces 检测的结果转化成prompt
258
+ convert the results of insight faces detection to prompt
259
+ Args:
260
+ face_list (_type_): _description_
261
+
262
+ Returns:
263
+ _type_: _description_
264
+ """
265
+ attributes = json.loads(attributes)
266
+ face_list = attributes["info"]
267
+ if len(face_list) == 0:
268
+ return ""
269
+
270
+ if attributes["image_type"] == "body":
271
+ for face in face_list:
272
+ if "black" in face and face["black"]:
273
+ return "african,dark skin"
274
+ return ""
275
+
276
+ gender_dict = {"girl": 0, "boy": 0}
277
+ face_render_list = []
278
+ black = False
279
+
280
+ for face in face_list:
281
+ if face["ratio"] < 0.02:
282
+ continue
283
+
284
+ if face["gender"] == 0:
285
+ gender_dict["girl"] += 1
286
+ face_render_list.append(self.face_render_dict["girl"])
287
+ else:
288
+ gender_dict["boy"] += 1
289
+ face_render_list.append(self.face_render_dict["boy"])
290
+
291
+ if "black" in face and face["black"]:
292
+ black = True
293
+
294
+ if len(face_render_list) == 0:
295
+ return ""
296
+ elif len(face_render_list) == 1:
297
+ solo = True
298
+ else:
299
+ solo = False
300
+
301
+         gender = ""
+         for g, num in gender_dict.items():
+             if num > 0:
+                 if gender:
+                     gender += ", "
+                 gender += "{}{}".format(num, g)
+                 if num > 1:
+                     gender += "s"
+
+         face_render_list = ",".join(face_render_list)
+         face_render_list = face_render_list.split(",")
+         face_render = list(set(face_render_list))
+         face_render.sort(key=face_render_list.index)
+         face_render = ",".join(face_render)
+         if gender_dict["girl"] == 0:
+             face_render = "male focus," + face_render
+
+         insightface_prompt = "{},{},{}".format(gender, face_render, self.key_words)
+
+         if solo:
+             insightface_prompt += ",solo"
+         if black:
+             insightface_prompt = "african,dark skin," + insightface_prompt
+
+         return insightface_prompt
+
+
+ @AttrRegister.register
+ class Caption(AttriributeIsText):
+     name = "caption"
+
+
+ @AttrRegister.register
+ class Env(AttriributeIsText):
+     name = "env"
+     envs_list = [
+         "east asian architecture",
+         "fireworks",
+         "snow, snowflakes",
+         "snowing, snowflakes",
+     ]
+
+     def __call__(self, attributes: str = None) -> str:
+         if attributes != "" and attributes != " " and attributes is not None:
+             return attributes
+         else:
+             return random.choice(self.envs_list)
+
+
+ @AttrRegister.register
+ class Decoration(AttriributeIsText):
+     name = "decoration"
+
+     def __init__(self, name: str = None) -> None:
+         self.decoration_list = [
+             "chinese knot",
+             "flowers",
+             "food",
+             "lanterns",
+             "red envelop",
+         ]
+         super().__init__(name)
+
+     def __call__(self, attributes: str = None) -> str:
+         if attributes != "" and attributes != " " and attributes is not None:
+             return attributes
+         else:
+             return random.choice(self.decoration_list)
+
+
+ @AttrRegister.register
+ class Festival(AttriributeIsText):
+     name = "festival"
+     festival_list = ["new year"]
+
+     def __init__(self, name: str = None) -> None:
+         super().__init__(name)
+
+     def __call__(self, attributes: str = None) -> str:
+         if attributes != "" and attributes != " " and attributes is not None:
+             return attributes
+         else:
+             return random.choice(self.festival_list)
+
+
+ @AttrRegister.register
+ class SpringHeadwear(AttriributeIsText):
+     name = "spring_headwear"
+     headwear_list = ["rabbit ears", "rabbit ears, fur hat"]
+
+     def __call__(self, attributes: str = None) -> str:
+         if attributes != "" and attributes != " " and attributes is not None:
+             return attributes
+         else:
+             return random.choice(self.headwear_list)
+
+
+ @AttrRegister.register
+ class SpringClothes(AttriributeIsText):
+     name = "spring_clothes"
+     clothes_list = [
+         "mittens,chinese clothes",
+         "mittens,fur trim",
+         "mittens,red scarf",
+         "mittens,winter clothes",
+     ]
+
+     def __call__(self, attributes: str = None) -> str:
+         if attributes != "" and attributes != " " and attributes is not None:
+             return attributes
+         else:
+             return random.choice(self.clothes_list)
+
+
+ @AttrRegister.register
+ class Animal(AttriributeIsText):
+     name = "animal"
+     animal_list = ["rabbit", "holding rabbits"]
+
+     def __call__(self, attributes: str = None) -> str:
+         if attributes != "" and attributes != " " and attributes is not None:
+             return attributes
+         else:
+             return random.choice(self.animal_list)
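The attribute classes above all share one contract: when the caller supplies a non-empty value it is returned unchanged, otherwise a random entry from the class's preset list is used. A hedged usage sketch, assuming the musev package and its dependencies are installed:

# Hypothetical usage sketch; assumes the musev package and its dependencies are installed.
from musev.auto_prompt.attributes.human import Env, Decoration

env = Env()
print(env("fireworks"))  # an explicit value is passed through unchanged
print(env(""))           # an empty value falls back to a random entry from envs_list

decoration = Decoration()
print(decoration(None))  # None also triggers the random fallback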
musev/auto_prompt/attributes/render.py ADDED
@@ -0,0 +1,33 @@
+ from mmcm.utils.util import flatten
+
+ from .attributes import BaseAttribute2Text
+ from . import AttrRegister
+
+ __all__ = ["Render"]
+
+ RenderMap = {
+     "Epic": "artstation, epic environment, highly detailed, 8k, HD",
+     "HD": "8k, highly detailed",
+     "EpicHD": "hyper detailed, beautiful lighting, epic environment, octane render, cinematic, 8k",
+     "Digital": "detailed illustration, crisp lines, digital art, 8k, trending on artstation",
+     "Unreal1": "artstation, concept art, smooth, sharp focus, illustration, unreal engine 5, 8k",
+     "Unreal2": "concept art, octane render, artstation, epic environment, highly detailed, 8k",
+ }
+
+
+ @AttrRegister.register
+ class Render(BaseAttribute2Text):
+     name = "render"
+
+     def __init__(self, name: str = None) -> None:
+         super().__init__(name)
+
+     def __call__(self, attributes: str) -> str:
+         if attributes == "" or attributes is None:
+             return ""
+         attributes = attributes.split(",")
+         render = [RenderMap[attr] for attr in attributes if attr in RenderMap]
+         render = flatten(render, ignored_iterable_types=[str])
+         if len(render) == 1:
+             render = render[0]
+         return render
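Render resolves each comma-separated key through RenderMap and flattens the result, returning a bare string only when exactly one key matched. A hedged sketch of the expected behaviour, assuming musev and its mmcm dependency are installed:

# Hypothetical usage sketch for the Render attribute above.
from musev.auto_prompt.attributes.render import Render

render = Render()
print(render("HD"))       # one matching key -> the mapped string "8k, highly detailed"
print(render("Epic,HD"))  # several keys -> a flattened list of mapped strings
print(render(""))         # empty input -> ""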
musev/auto_prompt/attributes/style.py ADDED
@@ -0,0 +1,12 @@
+ from .attributes import AttriributeIsText
+ from . import AttrRegister
+
+ __all__ = ["Style"]
+
+
+ @AttrRegister.register
+ class Style(AttriributeIsText):
+     name = "style"
+
+     def __init__(self, name: str = None) -> None:
+         super().__init__(name)
musev/auto_prompt/human.py ADDED
@@ -0,0 +1,40 @@
+ """Converts human/portrait-related attributes into prompt text."""
+ from typing import List
+
+ from .attributes.human import PortraitMultiAttr2Text
+ from .attributes.attributes import BaseAttribute2Text
+ from .attributes.attr2template import MultiAttr2PromptTemplate
+
+
+ class PortraitAttr2PromptTemplate(MultiAttr2PromptTemplate):
+     """Template class that converts a task dictionary into a portrait prompt."""
+
+     templates = "a portrait of {}"
+
+     def __init__(
+         self, templates: str = None, attr2text: List = None, name: str = "portrait"
+     ) -> None:
+         """
+         Args:
+             templates (str, optional): portrait prompt template; if None, the default
+                 class attribute is used. Defaults to None.
+             attr2text (List, optional): list of attributes to add or update for the portrait
+                 class; by default the portrait attributes defined in PortraitMultiAttr2Text
+                 are used. Defaults to None.
+             name (str, optional): name of this template instance. Defaults to "portrait".
+         """
+         if (
+             attr2text is None
+             or isinstance(attr2text, list)
+             or isinstance(attr2text, BaseAttribute2Text)
+         ):
+             attr2text = PortraitMultiAttr2Text(funcs=attr2text)
+         if templates is None:
+             templates = self.templates
+         super().__init__(templates, attr2text, name=name)
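PortraitAttr2PromptTemplate only wires the default "a portrait of {}" template to the PortraitMultiAttr2Text converter. A hedged construction sketch, assuming the package and its dependencies are installed:

# Hypothetical sketch; assumes musev and its dependencies are installed.
from musev.auto_prompt.human import PortraitAttr2PromptTemplate

# No arguments: the default template and the default attribute converter are used.
template = PortraitAttr2PromptTemplate()
print(PortraitAttr2PromptTemplate.templates)  # "a portrait of {}"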
musev/auto_prompt/load_template.py ADDED
@@ -0,0 +1,37 @@
+ from mmcm.utils.str_util import has_key_brace
+
+ from .human import PortraitAttr2PromptTemplate
+ from .attributes.attr2template import (
+     KeywordMultiAttr2PromptTemplate,
+     OnlySpacePromptTemplate,
+ )
+
+
+ def get_template_by_name(template: str, name: str = None):
+     """Choose the prompt-generator class according to the template and its name.
+
+     Args:
+         name (str): short template name, used to pick a preset template.
+
+     Raises:
+         ValueError: if name is not in the supported list.
+
+     Returns:
+         MultiAttr2PromptTemplate: a callable class that converts task dictionaries into prompts.
+     """
+     if template == "" or template is None:
+         template = OnlySpacePromptTemplate(template=template)
+     elif has_key_brace(template):
+         template = KeywordMultiAttr2PromptTemplate(template=template)
+     else:
+         if name == "portrait":
+             template = PortraitAttr2PromptTemplate(templates=template)
+         else:
+             raise ValueError(
+                 "PresetAttr2PromptTemplate only support one of [portrait], but given {}".format(
+                     name
+                 )
+             )
+     return template
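The function above dispatches on three cases: an empty template, a template containing {keyword} braces, and a named preset. The sketch below reproduces only that branching logic in standalone form; has_key_brace is reimplemented here as an assumption, since the mmcm implementation is not part of this diff:

# Standalone sketch of the dispatch rule in get_template_by_name.
import re

def has_key_brace(template: str) -> bool:
    # assumed behaviour: True when the string contains a "{keyword}" style placeholder
    return re.search(r"\{[^{}]+\}", template) is not None

def choose_template_kind(template: str, name: str = None) -> str:
    if template == "" or template is None:
        return "OnlySpacePromptTemplate"
    elif has_key_brace(template):
        return "KeywordMultiAttr2PromptTemplate"
    elif name == "portrait":
        return "PortraitAttr2PromptTemplate"
    raise ValueError(f"unsupported preset template name: {name}")

print(choose_template_kind(""))                      # OnlySpacePromptTemplate
print(choose_template_kind("{style},{env}"))         # KeywordMultiAttr2PromptTemplate
print(choose_template_kind("new year", "portrait"))  # PortraitAttr2PromptTemplate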
musev/auto_prompt/util.py ADDED
@@ -0,0 +1,25 @@
+ from copy import deepcopy
+ from typing import Dict, List
+
+ from .load_template import get_template_by_name
+
+
+ def generate_prompts(tasks: List[Dict]) -> List[Dict]:
+     new_tasks = []
+     for task in tasks:
+         task["origin_prompt"] = deepcopy(task["prompt"])
+         # If the prompt contains a template placeholder "{}" or is empty (treated as an empty
+         # template), the original prompt value is used as the template to generate new prompts;
+         # otherwise the task is kept as-is.
+         if "{" not in task["prompt"] and len(task["prompt"]) != 0:
+             new_tasks.append(task)
+         else:
+             template = get_template_by_name(
+                 template=task["prompt"], name=task.get("template_name", None)
+             )
+             prompts = template(task)
+             if not isinstance(prompts, list) and isinstance(prompts, str):
+                 prompts = [prompts]
+             for prompt in prompts:
+                 task_cp = deepcopy(task)
+                 task_cp["prompt"] = prompt
+                 new_tasks.append(task_cp)
+     return new_tasks
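generate_prompts keeps tasks whose prompt is already concrete and regenerates the rest from a template, emitting one output task per generated prompt. The pass-through rule is the only subtle part, so here is a standalone sketch of just that rule:

# Standalone sketch of the pass-through rule used by generate_prompts.
def needs_template(prompt: str) -> bool:
    # templated ("{" present) or empty prompts are regenerated; everything else is kept
    return "{" in prompt or len(prompt) == 0

print(needs_template("a boy walking on the street"))  # False -> task kept unchanged
print(needs_template("{animal} in {env}"))            # True  -> expanded via a keyword template
print(needs_template(""))                             # True  -> handled by the empty-template path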
musev/data/__init__.py ADDED
File without changes
musev/data/data_util.py ADDED
@@ -0,0 +1,681 @@
1
+ from typing import List, Dict, Literal, Union, Tuple
2
+ import os
3
+ import string
4
+ import logging
5
+
6
+ import torch
7
+ import numpy as np
8
+ from einops import rearrange, repeat
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ def generate_tasks_of_dir(
14
+ path: str,
15
+ output_dir: str,
16
+ exts: Tuple[str],
17
+ same_dir_name: bool = False,
18
+ **kwargs,
19
+ ) -> List[Dict]:
20
+ """covert video directory into tasks
21
+
22
+ Args:
23
+ path (str): _description_
24
+ output_dir (str): _description_
25
+ exts (Tuple[str]): _description_
26
+ same_dir_name (bool, optional): 存储路径是否保留和源视频相同的父文件名. Defaults to False.
27
+ whether keep the same parent dir name as the source video
28
+ Returns:
29
+ List[Dict]: _description_
30
+ """
31
+ tasks = []
32
+ for rootdir, dirs, files in os.walk(path):
33
+ for basename in files:
34
+ if basename.lower().endswith(exts):
35
+ video_path = os.path.join(rootdir, basename)
36
+ filename, ext = basename.split(".")
37
+ rootdir_name = os.path.basename(rootdir)
38
+ if same_dir_name:
39
+ save_path = os.path.join(
40
+ output_dir, rootdir_name, f"{filename}.h5py"
41
+ )
42
+ save_dir = os.path.join(output_dir, rootdir_name)
43
+ else:
44
+ save_path = os.path.join(output_dir, f"{filename}.h5py")
45
+ save_dir = output_dir
46
+ task = {
47
+ "video_path": video_path,
48
+ "output_path": save_path,
49
+ "output_dir": save_dir,
50
+ "filename": filename,
51
+ "ext": ext,
52
+ }
53
+ task.update(kwargs)
54
+ tasks.append(task)
55
+ return tasks
56
+
57
+
58
+ def sample_by_idx(
59
+ T: int,
60
+ n_sample: int,
61
+ sample_rate: int,
62
+ sample_start_idx: int = None,
63
+ change_sample_rate: bool = False,
64
+ seed: int = None,
65
+ whether_random: bool = True,
66
+ n_independent: int = 0,
67
+ ) -> List[int]:
68
+ """given a int to represent candidate list, sample n_sample with sample_rate from the candidate list
69
+
70
+ Args:
71
+ T (int): _description_
72
+ n_sample (int): 目标采样数目. sample number
73
+ sample_rate (int): 采样率, 每隔sample_rate个采样一个. sample interval, pick one per sample_rate number
74
+ sample_start_idx (int, optional): 采样开始位置的选择. start position to sample . Defaults to 0.
75
+ change_sample_rate (bool, optional): 是否可以通过降低sample_rate的方式来完成采样. whether allow changing sample_rate to finish sample process. Defaults to False.
76
+ whether_random (bool, optional): 是否最后随机选择开始点. whether randomly choose sample start position. Defaults to False.
77
+
78
+ Raises:
79
+ ValueError: T / sample_rate should be larger than n_sample
80
+ Returns:
81
+ List[int]: 采样的索引位置. sampled index position
82
+ """
83
+ if T < n_sample:
84
+ raise ValueError(f"T({T}) < n_sample({n_sample})")
85
+ else:
86
+ if T / sample_rate < n_sample:
87
+ if not change_sample_rate:
88
+ raise ValueError(
89
+ f"T({T}) / sample_rate({sample_rate}) < n_sample({n_sample})"
90
+ )
91
+ else:
92
+ while T / sample_rate < n_sample:
93
+ sample_rate -= 1
94
+ logger.error(
95
+ f"sample_rate{sample_rate+1} is too large, decrease to {sample_rate}"
96
+ )
97
+ if sample_rate == 0:
98
+ raise ValueError("T / sample_rate < n_sample")
99
+
100
+ if sample_start_idx is None:
101
+ if whether_random:
102
+ sample_start_idx_candidates = np.arange(T - n_sample * sample_rate)
103
+ if seed is not None:
104
+ np.random.seed(seed)
105
+ sample_start_idx = np.random.choice(sample_start_idx_candidates, 1)[0]
106
+
107
+ else:
108
+ sample_start_idx = 0
109
+ sample_end_idx = sample_start_idx + sample_rate * n_sample
110
+ sample = list(range(sample_start_idx, sample_end_idx, sample_rate))
111
+ if n_independent == 0:
112
+ n_independent_sample = None
113
+ else:
114
+ left_candidate = np.array(
115
+ list(range(0, sample_start_idx)) + list(range(sample_end_idx, T))
116
+ )
117
+ if len(left_candidate) >= n_independent:
118
+ # 使用两端的剩余空间采样, use the left space to sample
119
+ n_independent_sample = np.random.choice(left_candidate, n_independent)
120
+ else:
121
+ # 当两端没有剩余采样空间时,使用任意不是sample中的帧
122
+ # if no enough space to sample, use any frame not in sample
123
+ left_candidate = np.array(list(set(range(T) - set(sample))))
124
+ n_independent_sample = np.random.choice(left_candidate, n_independent)
125
+
126
+ return sample, sample_rate, n_independent_sample
127
+
128
+
129
+ def sample_tensor_by_idx(
130
+ tensor: Union[torch.Tensor, np.ndarray],
131
+ n_sample: int,
132
+ sample_rate: int,
133
+ sample_start_idx: int = 0,
134
+ change_sample_rate: bool = False,
135
+ seed: int = None,
136
+ dim: int = 0,
137
+ return_type: Literal["numpy", "torch"] = "torch",
138
+ whether_random: bool = True,
139
+ n_independent: int = 0,
140
+ ) -> Tuple[torch.Tensor, torch.Tensor, int, torch.Tensor, torch.Tensor]:
141
+ """sample sub_tensor
142
+
143
+ Args:
144
+ tensor (Union[torch.Tensor, np.ndarray]): _description_
145
+ n_sample (int): _description_
146
+ sample_rate (int): _description_
147
+ sample_start_idx (int, optional): _description_. Defaults to 0.
148
+ change_sample_rate (bool, optional): _description_. Defaults to False.
149
+ seed (int, optional): _description_. Defaults to None.
150
+ dim (int, optional): _description_. Defaults to 0.
151
+ return_type (Literal[&quot;numpy&quot;, &quot;torch&quot;], optional): _description_. Defaults to "torch".
152
+ whether_random (bool, optional): _description_. Defaults to True.
153
+ n_independent (int, optional): 独立于n_sample的采样数量. Defaults to 0.
154
+ n_independent sample number that is independent of n_sample
155
+
156
+ Returns:
157
+ Tuple[torch.Tensor, torch.Tensor, int, torch.Tensor, torch.Tensor]: sampled tensor
158
+ """
159
+ if isinstance(tensor, np.ndarray):
160
+ tensor = torch.from_numpy(tensor)
161
+ T = tensor.shape[dim]
162
+ sample_idx, sample_rate, independent_sample_idx = sample_by_idx(
163
+ T,
164
+ n_sample,
165
+ sample_rate,
166
+ sample_start_idx,
167
+ change_sample_rate,
168
+ seed,
169
+ whether_random=whether_random,
170
+ n_independent=n_independent,
171
+ )
172
+ sample_idx = torch.LongTensor(sample_idx)
173
+ sample = torch.index_select(tensor, dim, sample_idx)
174
+ if independent_sample_idx is not None:
175
+ independent_sample_idx = torch.LongTensor(independent_sample_idx)
176
+ independent_sample = torch.index_select(tensor, dim, independent_sample_idx)
177
+ else:
178
+ independent_sample = None
179
+ independent_sample_idx = None
180
+ if return_type == "numpy":
181
+ sample = sample.cpu().numpy()
182
+ return sample, sample_idx, sample_rate, independent_sample, independent_sample_idx
183
+
184
+
185
+ def concat_two_tensor(
186
+ data1: torch.Tensor,
187
+ data2: torch.Tensor,
188
+ dim: int,
189
+ method: Literal[
190
+ "first_in_first_out", "first_in_last_out", "intertwine", "index"
191
+ ] = "first_in_first_out",
192
+ data1_index: torch.long = None,
193
+ data2_index: torch.long = None,
194
+ return_index: bool = False,
195
+ ):
196
+ """concat two tensor along dim with given method
197
+
198
+ Args:
199
+ data1 (torch.Tensor): first in data
200
+ data2 (torch.Tensor): last in data
201
+ dim (int): _description_
202
+ method (Literal[ &quot;first_in_first_out&quot;, &quot;first_in_last_out&quot;, &quot;intertwine&quot; ], optional): _description_. Defaults to "first_in_first_out".
203
+
204
+ Raises:
205
+ NotImplementedError: unsupported method
206
+ ValueError: unsupported method
207
+
208
+ Returns:
209
+ _type_: _description_
210
+ """
211
+ len_data1 = data1.shape[dim]
212
+ len_data2 = data2.shape[dim]
213
+
214
+ if method == "first_in_first_out":
215
+ res = torch.concat([data1, data2], dim=dim)
216
+ data1_index = range(len_data1)
217
+ data2_index = [len_data1 + x for x in range(len_data2)]
218
+ elif method == "first_in_last_out":
219
+ res = torch.concat([data2, data1], dim=dim)
220
+ data2_index = range(len_data2)
221
+ data1_index = [len_data2 + x for x in range(len_data1)]
222
+ elif method == "intertwine":
223
+ raise NotImplementedError("intertwine")
224
+ elif method == "index":
225
+ res = concat_two_tensor_with_index(
226
+ data1=data1,
227
+ data1_index=data1_index,
228
+ data2=data2,
229
+ data2_index=data2_index,
230
+ dim=dim,
231
+ )
232
+ else:
233
+ raise ValueError(
234
+ "only support first_in_first_out, first_in_last_out, intertwine, index"
235
+ )
236
+ if return_index:
237
+ return res, data1_index, data2_index
238
+ else:
239
+ return res
240
+
241
+
242
+ def concat_two_tensor_with_index(
243
+ data1: torch.Tensor,
244
+ data1_index: torch.LongTensor,
245
+ data2: torch.Tensor,
246
+ data2_index: torch.LongTensor,
247
+ dim: int,
248
+ ) -> torch.Tensor:
249
+ """_summary_
250
+
251
+ Args:
252
+ data1 (torch.Tensor): b1*c1*h1*w1*...
253
+ data1_index (torch.LongTensor): N, if dim=1, N=c1
254
+ data2 (torch.Tensor): b2*c2*h2*w2*...
255
+ data2_index (torch.LongTensor): M, if dim=1, M=c2
256
+ dim (int): int
257
+
258
+ Returns:
259
+ torch.Tensor: b*c*h*w*..., if dim=1, b=b1=b2, c=c1+c2, h=h1=h2, w=w1=w2,...
260
+ """
261
+ shape1 = list(data1.shape)
262
+ shape2 = list(data2.shape)
263
+ target_shape = list(shape1)
264
+ target_shape[dim] = shape1[dim] + shape2[dim]
265
+ target = torch.zeros(target_shape, device=data1.device, dtype=data1.dtype)
266
+ target = batch_index_copy(target, dim=dim, index=data1_index, source=data1)
267
+ target = batch_index_copy(target, dim=dim, index=data2_index, source=data2)
268
+ return target
269
+
270
+
271
+ def repeat_index_to_target_size(
272
+ index: torch.LongTensor, target_size: int
273
+ ) -> torch.LongTensor:
274
+ if len(index.shape) == 1:
275
+ index = repeat(index, "n -> b n", b=target_size)
276
+ if len(index.shape) == 2:
277
+ remainder = target_size % index.shape[0]
278
+ assert (
279
+ remainder == 0
280
+ ), f"target_size % index.shape[0] must be zero, but give {target_size % index.shape[0]}"
281
+ index = repeat(index, "b n -> (b c) n", c=int(target_size / index.shape[0]))
282
+ return index
283
+
284
+
285
+ def batch_concat_two_tensor_with_index(
286
+ data1: torch.Tensor,
287
+ data1_index: torch.LongTensor,
288
+ data2: torch.Tensor,
289
+ data2_index: torch.LongTensor,
290
+ dim: int,
291
+ ) -> torch.Tensor:
292
+ return concat_two_tensor_with_index(data1, data1_index, data2, data2_index, dim)
293
+
294
+
295
+ def interwine_two_tensor(
296
+ data1: torch.Tensor,
297
+ data2: torch.Tensor,
298
+ dim: int,
299
+ return_index: bool = False,
300
+ ) -> torch.Tensor:
301
+ shape1 = list(data1.shape)
302
+ shape2 = list(data2.shape)
303
+ target_shape = list(shape1)
304
+ target_shape[dim] = shape1[dim] + shape2[dim]
305
+ target = torch.zeros(target_shape, device=data1.device, dtype=data1.dtype)
306
+ data1_reshape = torch.swapaxes(data1, 0, dim)
307
+ data2_reshape = torch.swapaxes(data2, 0, dim)
308
+ target = torch.swapaxes(target, 0, dim)
309
+ total_index = set(range(target_shape[dim]))
310
+ data1_index = range(0, 2 * shape1[dim], 2)
311
+ data2_index = sorted(list(set(total_index) - set(data1_index)))
312
+ data1_index = torch.LongTensor(data1_index)
313
+ data2_index = torch.LongTensor(data2_index)
314
+ target[data1_index, ...] = data1_reshape
315
+ target[data2_index, ...] = data2_reshape
316
+ target = torch.swapaxes(target, 0, dim)
317
+ if return_index:
318
+ return target, data1_index, data2_index
319
+ else:
320
+ return target
321
+
322
+
323
+ def split_index(
324
+ indexs: torch.Tensor,
325
+ n_first: int = None,
326
+ n_last: int = None,
327
+ method: Literal[
328
+ "first_in_first_out", "first_in_last_out", "intertwine", "index", "random"
329
+ ] = "first_in_first_out",
330
+ ):
331
+ """_summary_
332
+
333
+ Args:
334
+ indexs (List): _description_
335
+ n_first (int): _description_
336
+ n_last (int): _description_
337
+ method (Literal[ &quot;first_in_first_out&quot;, &quot;first_in_last_out&quot;, &quot;intertwine&quot;, &quot;index&quot; ], optional): _description_. Defaults to "first_in_first_out".
338
+
339
+ Raises:
340
+ NotImplementedError: _description_
341
+
342
+ Returns:
343
+ first_index: _description_
344
+ last_index:
345
+ """
346
+ # assert (
347
+ # n_first is None and n_last is None
348
+ # ), "must assign one value for n_first or n_last"
349
+ n_total = len(indexs)
350
+ if n_first is None:
351
+ n_first = n_total - n_last
352
+ if n_last is None:
353
+ n_last = n_total - n_first
354
+ assert len(indexs) == n_first + n_last
355
+ if method == "first_in_first_out":
356
+ first_index = indexs[:n_first]
357
+ last_index = indexs[n_first:]
358
+ elif method == "first_in_last_out":
359
+ first_index = indexs[n_last:]
360
+ last_index = indexs[:n_last]
361
+ elif method == "intertwine":
362
+ raise NotImplementedError
363
+ elif method == "random":
364
+ idx_ = torch.randperm(len(indexs))
365
+ first_index = indexs[idx_[:n_first]]
366
+ last_index = indexs[idx_[n_first:]]
367
+ return first_index, last_index
368
+
369
+
370
+ def split_tensor(
371
+ tensor: torch.Tensor,
372
+ dim: int,
373
+ n_first=None,
374
+ n_last=None,
375
+ method: Literal[
376
+ "first_in_first_out", "first_in_last_out", "intertwine", "index", "random"
377
+ ] = "first_in_first_out",
378
+ need_return_index: bool = False,
379
+ ):
380
+ device = tensor.device
381
+ total = tensor.shape[dim]
382
+ if n_first is None:
383
+ n_first = total - n_last
384
+ if n_last is None:
385
+ n_last = total - n_first
386
+ indexs = torch.arange(
387
+ total,
388
+ dtype=torch.long,
389
+ device=device,
390
+ )
391
+ (
392
+ first_index,
393
+ last_index,
394
+ ) = split_index(
395
+ indexs=indexs,
396
+ n_first=n_first,
397
+ method=method,
398
+ )
399
+ first_tensor = torch.index_select(tensor, dim=dim, index=first_index)
400
+ last_tensor = torch.index_select(tensor, dim=dim, index=last_index)
401
+ if need_return_index:
402
+ return (
403
+ first_tensor,
404
+ last_tensor,
405
+ first_index,
406
+ last_index,
407
+ )
408
+ else:
409
+ return (first_tensor, last_tensor)
410
+
411
+
412
+ # TODO: 待确定batch_index_select的优化
413
+ def batch_index_select(
414
+ tensor: torch.Tensor, index: torch.LongTensor, dim: int
415
+ ) -> torch.Tensor:
416
+ """_summary_
417
+
418
+ Args:
419
+ tensor (torch.Tensor): D1*D2*D3*D4...
420
+ index (torch.LongTensor): D1*N or N, N<= tensor.shape[dim]
421
+ dim (int): dim to select
422
+
423
+ Returns:
424
+ torch.Tensor: D1*...*N*...
425
+ """
426
+ # TODO: now only support N same for every d1
427
+ if len(index.shape) == 1:
428
+ return torch.index_select(tensor, dim=dim, index=index)
429
+ else:
430
+ index = repeat_index_to_target_size(index, tensor.shape[0])
431
+ out = []
432
+ for i in torch.arange(tensor.shape[0]):
433
+ sub_tensor = tensor[i]
434
+ sub_index = index[i]
435
+ d = torch.index_select(sub_tensor, dim=dim - 1, index=sub_index)
436
+ out.append(d)
437
+ return torch.stack(out).to(dtype=tensor.dtype)
438
+
439
+
440
+ def batch_index_copy(
441
+ tensor: torch.Tensor, dim: int, index: torch.LongTensor, source: torch.Tensor
442
+ ) -> torch.Tensor:
443
+ """_summary_
444
+
445
+ Args:
446
+ tensor (torch.Tensor): b*c*h
447
+ dim (int):
448
+ index (torch.LongTensor): b*d,
449
+ source (torch.Tensor):
450
+ b*d*h*..., if dim=1
451
+ b*c*d*..., if dim=2
452
+
453
+ Returns:
454
+ torch.Tensor: b*c*d*...
455
+ """
456
+ if len(index.shape) == 1:
457
+ tensor.index_copy_(dim=dim, index=index, source=source)
458
+ else:
459
+ index = repeat_index_to_target_size(index, tensor.shape[0])
460
+
461
+ batch_size = tensor.shape[0]
462
+ for b in torch.arange(batch_size):
463
+ sub_index = index[b]
464
+ sub_source = source[b]
465
+ sub_tensor = tensor[b]
466
+ sub_tensor.index_copy_(dim=dim - 1, index=sub_index, source=sub_source)
467
+ tensor[b] = sub_tensor
468
+ return tensor
469
+
470
+
471
+ def batch_index_fill(
472
+ tensor: torch.Tensor,
473
+ dim: int,
474
+ index: torch.LongTensor,
475
+ value: Literal[torch.Tensor, torch.float],
476
+ ) -> torch.Tensor:
477
+ """_summary_
478
+
479
+ Args:
480
+ tensor (torch.Tensor): b*c*h
481
+ dim (int):
482
+ index (torch.LongTensor): b*d,
483
+ value (torch.Tensor): b
484
+
485
+ Returns:
486
+ torch.Tensor: b*c*d*...
487
+ """
488
+ index = repeat_index_to_target_size(index, tensor.shape[0])
489
+ batch_size = tensor.shape[0]
490
+ for b in torch.arange(batch_size):
491
+ sub_index = index[b]
492
+ sub_value = value[b] if isinstance(value, torch.Tensor) else value
493
+ sub_tensor = tensor[b]
494
+ sub_tensor.index_fill_(dim - 1, sub_index, sub_value)
495
+ tensor[b] = sub_tensor
496
+ return tensor
497
+
498
+
499
+ def adaptive_instance_normalization(
500
+ src: torch.Tensor,
501
+ dst: torch.Tensor,
502
+ eps: float = 1e-6,
503
+ ):
504
+ """
505
+ Args:
506
+ src (torch.Tensor): b c t h w
507
+ dst (torch.Tensor): b c t h w
508
+ """
509
+ ndim = src.ndim
510
+ if ndim == 5:
511
+ dim = (2, 3, 4)
512
+ elif ndim == 4:
513
+ dim = (2, 3)
514
+ elif ndim == 3:
515
+ dim = 2
516
+ else:
517
+ raise ValueError("only support ndim in [3,4,5], but given {ndim}")
518
+ var, mean = torch.var_mean(src, dim=dim, keepdim=True, correction=0)
519
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
520
+ dst = align_repeat_tensor_single_dim(dst, src.shape[0], dim=0)
521
+ mean_acc, var_acc = torch.var_mean(dst, dim=dim, keepdim=True, correction=0)
522
+ # mean_acc = sum(mean_acc) / float(len(mean_acc))
523
+ # var_acc = sum(var_acc) / float(len(var_acc))
524
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
525
+ src = (((src - mean) / std) * std_acc) + mean_acc
526
+ return src
527
+
528
+
529
+ def adaptive_instance_normalization_with_ref(
530
+ src: torch.LongTensor,
531
+ dst: torch.LongTensor,
532
+ style_fidelity: float = 0.5,
533
+ do_classifier_free_guidance: bool = True,
534
+ ):
535
+ # logger.debug(
536
+ # f"src={src.shape}, min={src.min()}, max={src.max()}, mean={src.mean()}, \n"
537
+ # f"dst={src.shape}, min={dst.min()}, max={dst.max()}, mean={dst.mean()}"
538
+ # )
539
+ batch_size = src.shape[0] // 2
540
+ uc_mask = torch.Tensor([1] * batch_size + [0] * batch_size).type_as(src).bool()
541
+ src_uc = adaptive_instance_normalization(src, dst)
542
+ src_c = src_uc.clone()
543
+ # TODO: 该部分默认 do_classifier_free_guidance and style_fidelity > 0 = True
544
+ if do_classifier_free_guidance and style_fidelity > 0:
545
+ src_c[uc_mask] = src[uc_mask]
546
+ src = style_fidelity * src_c + (1.0 - style_fidelity) * src_uc
547
+ return src
548
+
549
+
550
+ def batch_adain_conditioned_tensor(
551
+ tensor: torch.Tensor,
552
+ src_index: torch.LongTensor,
553
+ dst_index: torch.LongTensor,
554
+ keep_dim: bool = True,
555
+ num_frames: int = None,
556
+ dim: int = 2,
557
+ style_fidelity: float = 0.5,
558
+ do_classifier_free_guidance: bool = True,
559
+ need_style_fidelity: bool = False,
560
+ ):
561
+ """_summary_
562
+
563
+ Args:
564
+ tensor (torch.Tensor): b c t h w
565
+ src_index (torch.LongTensor): _description_
566
+ dst_index (torch.LongTensor): _description_
567
+ keep_dim (bool, optional): _description_. Defaults to True.
568
+
569
+ Returns:
570
+ _type_: _description_
571
+ """
572
+ ndim = tensor.ndim
573
+ dtype = tensor.dtype
574
+ if ndim == 4 and num_frames is not None:
575
+ tensor = rearrange(tensor, "(b t) c h w-> b c t h w ", t=num_frames)
576
+ src = batch_index_select(tensor, dim=dim, index=src_index).contiguous()
577
+ dst = batch_index_select(tensor, dim=dim, index=dst_index).contiguous()
578
+ if need_style_fidelity:
579
+ src = adaptive_instance_normalization_with_ref(
580
+ src=src,
581
+ dst=dst,
582
+ style_fidelity=style_fidelity,
583
+ do_classifier_free_guidance=do_classifier_free_guidance,
584
+ need_style_fidelity=need_style_fidelity,
585
+ )
586
+ else:
587
+ src = adaptive_instance_normalization(
588
+ src=src,
589
+ dst=dst,
590
+ )
591
+ if keep_dim:
592
+ src = batch_concat_two_tensor_with_index(
593
+ src.to(dtype=dtype),
594
+ src_index,
595
+ dst.to(dtype=dtype),
596
+ dst_index,
597
+ dim=dim,
598
+ )
599
+
600
+ if ndim == 4 and num_frames is not None:
601
+ src = rearrange(tensor, "b c t h w ->(b t) c h w")
602
+ return src
603
+
604
+
605
+ def align_repeat_tensor_single_dim(
606
+ src: torch.Tensor,
607
+ target_length: int,
608
+ dim: int = 0,
609
+ n_src_base_length: int = 1,
610
+ src_base_index: List[int] = None,
611
+ ) -> torch.Tensor:
612
+ """沿着 dim 纬度, 补齐 src 的长度到目标 target_length。
613
+ 当 src 长度不如 target_length 时, 取其中 前 n_src_base_length 然后 repeat 到 target_length
614
+
615
+ align length of src to target_length along dim
616
+ when src length is less than target_length, take the first n_src_base_length and repeat to target_length
617
+
618
+ Args:
619
+ src (torch.Tensor): 输入 tensor, input tensor
620
+ target_length (int): 目标长度, target_length
621
+ dim (int, optional): 处理纬度, target dim . Defaults to 0.
622
+ n_src_base_length (int, optional): src 的基本单元长度, basic length of src. Defaults to 1.
623
+
624
+ Returns:
625
+ torch.Tensor: _description_
626
+ """
627
+ src_dim_length = src.shape[dim]
628
+ if target_length > src_dim_length:
629
+ if target_length % src_dim_length == 0:
630
+ new = src.repeat_interleave(
631
+ repeats=target_length // src_dim_length, dim=dim
632
+ )
633
+ else:
634
+ if src_base_index is None and n_src_base_length is not None:
635
+ src_base_index = torch.arange(n_src_base_length)
636
+
637
+ new = src.index_select(
638
+ dim=dim,
639
+ index=torch.LongTensor(src_base_index).to(device=src.device),
640
+ )
641
+ new = new.repeat_interleave(
642
+ repeats=target_length // len(src_base_index),
643
+ dim=dim,
644
+ )
645
+ elif target_length < src_dim_length:
646
+ new = src.index_select(
647
+ dim=dim,
648
+ index=torch.LongTensor(torch.arange(target_length)).to(device=src.device),
649
+ )
650
+ else:
651
+ new = src
652
+ return new
653
+
654
+
655
+ def fuse_part_tensor(
656
+ src: torch.Tensor,
657
+ dst: torch.Tensor,
658
+ overlap: int,
659
+ weight: float = 0.5,
660
+ skip_step: int = 0,
661
+ ) -> torch.Tensor:
662
+ """fuse overstep tensor with weight of src into dst
663
+ out = src_fused_part * weight + dst * (1-weight) for overlap
664
+
665
+ Args:
666
+ src (torch.Tensor): b c t h w
667
+ dst (torch.Tensor): b c t h w
668
+ overlap (int): 1
669
+ weight (float, optional): weight of src tensor part. Defaults to 0.5.
670
+
671
+ Returns:
672
+ torch.Tensor: fused tensor
673
+ """
674
+ if overlap == 0:
675
+ return dst
676
+ else:
677
+ dst[:, :, skip_step : skip_step + overlap] = (
678
+ weight * src[:, :, -overlap:]
679
+ + (1 - weight) * dst[:, :, skip_step : skip_step + overlap]
680
+ )
681
+ return dst
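fuse_part_tensor above blends the last overlap frames of src into dst starting at skip_step with weight weight, a pattern typically used to stitch overlapping generation windows together. A hedged usage sketch, assuming musev, torch and einops are installed:

# Hypothetical usage sketch; assumes musev, torch and einops are installed.
import torch
from musev.data.data_util import fuse_part_tensor

src = torch.zeros(1, 3, 4, 8, 8)  # b c t h w, previous window
dst = torch.ones(1, 3, 4, 8, 8)   # current window
out = fuse_part_tensor(src, dst, overlap=2, weight=0.5, skip_step=0)

print(out[0, 0, 0, 0, 0].item())  # 0.5: blended frame, 0.5*src[-2:] + 0.5*dst[:2]
print(out[0, 0, 2, 0, 0].item())  # 1.0: frames outside the overlap keep their dst values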
musev/logging.conf ADDED
@@ -0,0 +1,32 @@
+ [loggers]
+ keys=root,musev
+
+ [handlers]
+ keys=consoleHandler
+
+ [formatters]
+ keys=musevFormatter
+
+ [logger_root]
+ level=INFO
+ handlers=consoleHandler
+
+ # keep the logger level as low as possible
+ [logger_musev]
+ level=DEBUG
+ handlers=consoleHandler
+ qualname=musev
+ propagate=0
+
+ # set the handler level higher than the logger level
+ [handler_consoleHandler]
+ class=StreamHandler
+ level=DEBUG
+ # level=INFO
+
+ formatter=musevFormatter
+ args=(sys.stdout,)
+
+ [formatter_musevFormatter]
+ format=%(asctime)s- %(name)s:%(lineno)d- %(levelname)s- %(message)s
+ datefmt=
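This is a standard logging.config INI file, with a DEBUG-level musev logger that does not propagate to root. A hedged sketch of loading it (the file path is illustrative):

# Hypothetical sketch of consuming the config above; the file path is illustrative.
import logging
import logging.config

logging.config.fileConfig("musev/logging.conf", disable_existing_loggers=False)
logger = logging.getLogger("musev.models.attention")  # child of the "musev" logger
logger.debug("debug messages reach the DEBUG-level console handler")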
musev/models/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from ..utils.register import Register
+
+ Model_Register = Register(registry_name="torch_model")
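Model_Register is used throughout the models package as a class decorator (for example @Model_Register.register on the attention processors later in this diff). The real implementation lives in musev/utils/register.py and is not shown here, so the sketch below is only an assumed, minimal equivalent of that registry pattern:

# Standalone sketch of a decorator-based registry; the real musev.utils.register.Register
# may differ from this assumed minimal version.
class Register(dict):
    def __init__(self, registry_name: str):
        super().__init__()
        self.registry_name = registry_name

    def register(self, obj):
        # used as "@Model_Register.register" above a class definition
        self[obj.__name__] = obj
        return obj

Model_Register = Register(registry_name="torch_model")

@Model_Register.register
class DummyProcessor:
    pass

print("DummyProcessor" in Model_Register)  # True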
musev/models/attention.py ADDED
@@ -0,0 +1,431 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Adapted from https://github.com/huggingface/diffusers/blob/64bf5d33b7ef1b1deac256bed7bd99b55020c4e0/src/diffusers/models/attention.py
16
+ from __future__ import annotations
17
+ from copy import deepcopy
18
+
19
+ from typing import Any, Dict, List, Literal, Optional, Callable, Tuple
20
+ import logging
21
+ from einops import rearrange
22
+
23
+ import torch
24
+ import torch.nn.functional as F
25
+ from torch import nn
26
+
27
+ from diffusers.models.embeddings import CombinedTimestepLabelEmbeddings
28
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
29
+ from diffusers.models.attention_processor import Attention as DiffusersAttention
30
+ from diffusers.models.attention import (
31
+ BasicTransformerBlock as DiffusersBasicTransformerBlock,
32
+ AdaLayerNormZero,
33
+ AdaLayerNorm,
34
+ FeedForward,
35
+ )
36
+ from diffusers.models.attention_processor import AttnProcessor
37
+
38
+ from .attention_processor import IPAttention, BaseIPAttnProcessor
39
+
40
+
41
+ logger = logging.getLogger(__name__)
42
+
43
+
44
+ def not_use_xformers_anyway(
45
+ use_memory_efficient_attention_xformers: bool,
46
+ attention_op: Optional[Callable] = None,
47
+ ):
48
+ return None
49
+
50
+
51
+ @maybe_allow_in_graph
52
+ class BasicTransformerBlock(DiffusersBasicTransformerBlock):
53
+ print_idx = 0
54
+
55
+ def __init__(
56
+ self,
57
+ dim: int,
58
+ num_attention_heads: int,
59
+ attention_head_dim: int,
60
+ dropout=0,
61
+ cross_attention_dim: int | None = None,
62
+ activation_fn: str = "geglu",
63
+ num_embeds_ada_norm: int | None = None,
64
+ attention_bias: bool = False,
65
+ only_cross_attention: bool = False,
66
+ double_self_attention: bool = False,
67
+ upcast_attention: bool = False,
68
+ norm_elementwise_affine: bool = True,
69
+ norm_type: str = "layer_norm",
70
+ final_dropout: bool = False,
71
+ attention_type: str = "default",
72
+ allow_xformers: bool = True,
73
+ cross_attn_temporal_cond: bool = False,
74
+ image_scale: float = 1.0,
75
+ processor: AttnProcessor | None = None,
76
+ ip_adapter_cross_attn: bool = False,
77
+ need_t2i_facein: bool = False,
78
+ need_t2i_ip_adapter_face: bool = False,
79
+ ):
80
+ if not only_cross_attention and double_self_attention:
81
+ cross_attention_dim = None
82
+ super().__init__(
83
+ dim,
84
+ num_attention_heads,
85
+ attention_head_dim,
86
+ dropout,
87
+ cross_attention_dim,
88
+ activation_fn,
89
+ num_embeds_ada_norm,
90
+ attention_bias,
91
+ only_cross_attention,
92
+ double_self_attention,
93
+ upcast_attention,
94
+ norm_elementwise_affine,
95
+ norm_type,
96
+ final_dropout,
97
+ attention_type,
98
+ )
99
+
100
+ self.attn1 = IPAttention(
101
+ query_dim=dim,
102
+ heads=num_attention_heads,
103
+ dim_head=attention_head_dim,
104
+ dropout=dropout,
105
+ bias=attention_bias,
106
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
107
+ upcast_attention=upcast_attention,
108
+ cross_attn_temporal_cond=cross_attn_temporal_cond,
109
+ image_scale=image_scale,
110
+ ip_adapter_dim=cross_attention_dim
111
+ if only_cross_attention
112
+ else attention_head_dim,
113
+ facein_dim=cross_attention_dim
114
+ if only_cross_attention
115
+ else attention_head_dim,
116
+ processor=processor,
117
+ )
118
+ # 2. Cross-Attn
119
+ if cross_attention_dim is not None or double_self_attention:
120
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
121
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
122
+ # the second cross attention block.
123
+ self.norm2 = (
124
+ AdaLayerNorm(dim, num_embeds_ada_norm)
125
+ if self.use_ada_layer_norm
126
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
127
+ )
128
+
129
+ self.attn2 = IPAttention(
130
+ query_dim=dim,
131
+ cross_attention_dim=cross_attention_dim
132
+ if not double_self_attention
133
+ else None,
134
+ heads=num_attention_heads,
135
+ dim_head=attention_head_dim,
136
+ dropout=dropout,
137
+ bias=attention_bias,
138
+ upcast_attention=upcast_attention,
139
+ cross_attn_temporal_cond=ip_adapter_cross_attn,
140
+ need_t2i_facein=need_t2i_facein,
141
+ need_t2i_ip_adapter_face=need_t2i_ip_adapter_face,
142
+ image_scale=image_scale,
143
+ ip_adapter_dim=cross_attention_dim
144
+ if not double_self_attention
145
+ else attention_head_dim,
146
+ facein_dim=cross_attention_dim
147
+ if not double_self_attention
148
+ else attention_head_dim,
149
+ ip_adapter_face_dim=cross_attention_dim
150
+ if not double_self_attention
151
+ else attention_head_dim,
152
+ processor=processor,
153
+ ) # is self-attn if encoder_hidden_states is none
154
+ else:
155
+ self.norm2 = None
156
+ self.attn2 = None
157
+ if self.attn1 is not None:
158
+ if not allow_xformers:
159
+ self.attn1.set_use_memory_efficient_attention_xformers = (
160
+ not_use_xformers_anyway
161
+ )
162
+ if self.attn2 is not None:
163
+ if not allow_xformers:
164
+ self.attn2.set_use_memory_efficient_attention_xformers = (
165
+ not_use_xformers_anyway
166
+ )
167
+ self.double_self_attention = double_self_attention
168
+ self.only_cross_attention = only_cross_attention
169
+ self.cross_attn_temporal_cond = cross_attn_temporal_cond
170
+ self.image_scale = image_scale
171
+
172
+ def forward(
173
+ self,
174
+ hidden_states: torch.FloatTensor,
175
+ attention_mask: Optional[torch.FloatTensor] = None,
176
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
177
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
178
+ timestep: Optional[torch.LongTensor] = None,
179
+ cross_attention_kwargs: Dict[str, Any] = None,
180
+ class_labels: Optional[torch.LongTensor] = None,
181
+ self_attn_block_embs: Optional[Tuple[List[torch.Tensor], List[None]]] = None,
182
+ self_attn_block_embs_mode: Literal["read", "write"] = "write",
183
+ ) -> torch.FloatTensor:
184
+ # Notice that normalization is always applied before the real computation in the following blocks.
185
+ # 0. Self-Attention
186
+ if self.use_ada_layer_norm:
187
+ norm_hidden_states = self.norm1(hidden_states, timestep)
188
+ elif self.use_ada_layer_norm_zero:
189
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
190
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
191
+ )
192
+ else:
193
+ norm_hidden_states = self.norm1(hidden_states)
194
+
195
+ # 1. Retrieve lora scale.
196
+ lora_scale = (
197
+ cross_attention_kwargs.get("scale", 1.0)
198
+ if cross_attention_kwargs is not None
199
+ else 1.0
200
+ )
201
+
202
+ if cross_attention_kwargs is None:
203
+ cross_attention_kwargs = {}
204
+ # 特殊AttnProcessor需要的入参 在 cross_attention_kwargs 准备
205
+ # special AttnProcessor needs input parameters in cross_attention_kwargs
206
+ original_cross_attention_kwargs = {
207
+ k: v
208
+ for k, v in cross_attention_kwargs.items()
209
+ if k
210
+ not in [
211
+ "num_frames",
212
+ "sample_index",
213
+ "vision_conditon_frames_sample_index",
214
+ "vision_cond",
215
+ "vision_clip_emb",
216
+ "ip_adapter_scale",
217
+ "face_emb",
218
+ "facein_scale",
219
+ "ip_adapter_face_emb",
220
+ "ip_adapter_face_scale",
221
+ "do_classifier_free_guidance",
222
+ ]
223
+ }
224
+
225
+ if "do_classifier_free_guidance" in cross_attention_kwargs:
226
+ do_classifier_free_guidance = cross_attention_kwargs[
227
+ "do_classifier_free_guidance"
228
+ ]
229
+ else:
230
+ do_classifier_free_guidance = False
231
+
232
+ # 2. Prepare GLIGEN inputs
233
+ original_cross_attention_kwargs = (
234
+ original_cross_attention_kwargs.copy()
235
+ if original_cross_attention_kwargs is not None
236
+ else {}
237
+ )
238
+ gligen_kwargs = original_cross_attention_kwargs.pop("gligen", None)
239
+
240
+ # 返回self_attn的结果,适用于referencenet的输出给其他Unet来使用
241
+ # return the result of self_attn, which is suitable for the output of referencenet to be used by other Unet
242
+ if (
243
+ self_attn_block_embs is not None
244
+ and self_attn_block_embs_mode.lower() == "write"
245
+ ):
246
+ # self_attn_block_emb = self.attn1.head_to_batch_dim(attn_output, out_dim=4)
247
+ self_attn_block_emb = norm_hidden_states
248
+ if not hasattr(self, "spatial_self_attn_idx"):
249
+ raise ValueError(
250
+ "must call unet.insert_spatial_self_attn_idx to generate spatial attn index"
251
+ )
252
+ basick_transformer_idx = self.spatial_self_attn_idx
253
+ if self.print_idx == 0:
254
+ logger.debug(
255
+ f"self_attn_block_embs, self_attn_block_embs_mode={self_attn_block_embs_mode}, "
256
+ f"basick_transformer_idx={basick_transformer_idx}, length={len(self_attn_block_embs)}, shape={self_attn_block_emb.shape}, "
257
+ # f"attn1 processor, {type(self.attn1.processor)}"
258
+ )
259
+ self_attn_block_embs[basick_transformer_idx] = self_attn_block_emb
260
+
261
+ # read and put referencenet emb into cross_attention_kwargs, which would be fused into attn_processor
262
+ if (
263
+ self_attn_block_embs is not None
264
+ and self_attn_block_embs_mode.lower() == "read"
265
+ ):
266
+ basick_transformer_idx = self.spatial_self_attn_idx
267
+ if not hasattr(self, "spatial_self_attn_idx"):
268
+ raise ValueError(
269
+ "must call unet.insert_spatial_self_attn_idx to generate spatial attn index"
270
+ )
271
+ if self.print_idx == 0:
272
+ logger.debug(
273
+ f"refer_self_attn_emb: , self_attn_block_embs_mode={self_attn_block_embs_mode}, "
274
+ f"length={len(self_attn_block_embs)}, idx={basick_transformer_idx}, "
275
+ # f"attn1 processor, {type(self.attn1.processor)}, "
276
+ )
277
+ ref_emb = self_attn_block_embs[basick_transformer_idx]
278
+ cross_attention_kwargs["refer_emb"] = ref_emb
279
+ if self.print_idx == 0:
280
+ logger.debug(
281
+ f"unet attention read, {self.spatial_self_attn_idx}",
282
+ )
283
+ # ------------------------------warning-----------------------
284
+ # 这两行由于使用了ref_emb会导致和checkpoint_train相关的训练错误,具体未知,留在这里作为警示
285
+ # bellow annoated code will cause training error, keep it here as a warning
286
+ # logger.debug(f"ref_emb shape,{ref_emb.shape}, {ref_emb.mean()}")
287
+ # logger.debug(
288
+ # f"norm_hidden_states shape, {norm_hidden_states.shape}, {norm_hidden_states.mean()}",
289
+ # )
290
+ if self.attn1 is None:
291
+ self.print_idx += 1
292
+ return norm_hidden_states
293
+ attn_output = self.attn1(
294
+ norm_hidden_states,
295
+ encoder_hidden_states=encoder_hidden_states
296
+ if self.only_cross_attention
297
+ else None,
298
+ attention_mask=attention_mask,
299
+ **(
300
+ cross_attention_kwargs
301
+ if isinstance(self.attn1.processor, BaseIPAttnProcessor)
302
+ else original_cross_attention_kwargs
303
+ ),
304
+ )
305
+
306
+ if self.use_ada_layer_norm_zero:
307
+ attn_output = gate_msa.unsqueeze(1) * attn_output
308
+ hidden_states = attn_output + hidden_states
309
+
310
+ # 推断的时候,对于uncondition_部分独立生成,排除掉 refer_emb,
311
+ # 首帧等的影响,避免生成参考了refer_emb、首帧等,又在uncond上去除了
312
+ # in inference stage, eliminate influence of refer_emb, vis_cond on unconditionpart
313
+ # to avoid use that, and then eliminate in pipeline
314
+ # refer to moore-animate anyone
315
+
316
+ # do_classifier_free_guidance = False
317
+ if self.print_idx == 0:
318
+ logger.debug(f"do_classifier_free_guidance={do_classifier_free_guidance},")
319
+ if do_classifier_free_guidance:
320
+ hidden_states_c = attn_output.clone()
321
+ _uc_mask = (
322
+ torch.Tensor(
323
+ [1] * (norm_hidden_states.shape[0] // 2)
324
+ + [0] * (norm_hidden_states.shape[0] // 2)
325
+ )
326
+ .to(norm_hidden_states.device)
327
+ .bool()
328
+ )
329
+ hidden_states_c[_uc_mask] = self.attn1(
330
+ norm_hidden_states[_uc_mask],
331
+ encoder_hidden_states=norm_hidden_states[_uc_mask],
332
+ attention_mask=attention_mask,
333
+ )
334
+ attn_output = hidden_states_c.clone()
335
+
336
+ if "refer_emb" in cross_attention_kwargs:
337
+ del cross_attention_kwargs["refer_emb"]
338
+
339
+ # 2.5 GLIGEN Control
340
+ if gligen_kwargs is not None:
341
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
342
+ # 2.5 ends
343
+
344
+ # 3. Cross-Attention
345
+ if self.attn2 is not None:
346
+ norm_hidden_states = (
347
+ self.norm2(hidden_states, timestep)
348
+ if self.use_ada_layer_norm
349
+ else self.norm2(hidden_states)
350
+ )
351
+
352
+ # 特殊AttnProcessor需要的入参 在 cross_attention_kwargs 准备
353
+ # special AttnProcessor needs input parameters in cross_attention_kwargs
354
+ attn_output = self.attn2(
355
+ norm_hidden_states,
356
+ encoder_hidden_states=encoder_hidden_states
357
+ if not self.double_self_attention
358
+ else None,
359
+ attention_mask=encoder_attention_mask,
360
+ **(
361
+ original_cross_attention_kwargs
362
+ if not isinstance(self.attn2.processor, BaseIPAttnProcessor)
363
+ else cross_attention_kwargs
364
+ ),
365
+ )
366
+ if self.print_idx == 0:
367
+ logger.debug(
368
+ f"encoder_hidden_states, type={type(encoder_hidden_states)}"
369
+ )
370
+ if encoder_hidden_states is not None:
371
+ logger.debug(
372
+ f"encoder_hidden_states, ={encoder_hidden_states.shape}"
373
+ )
374
+
375
+ # encoder_hidden_states_tmp = (
376
+ # encoder_hidden_states
377
+ # if not self.double_self_attention
378
+ # else norm_hidden_states
379
+ # )
380
+ # if do_classifier_free_guidance:
381
+ # hidden_states_c = attn_output.clone()
382
+ # _uc_mask = (
383
+ # torch.Tensor(
384
+ # [1] * (norm_hidden_states.shape[0] // 2)
385
+ # + [0] * (norm_hidden_states.shape[0] // 2)
386
+ # )
387
+ # .to(norm_hidden_states.device)
388
+ # .bool()
389
+ # )
390
+ # hidden_states_c[_uc_mask] = self.attn2(
391
+ # norm_hidden_states[_uc_mask],
392
+ # encoder_hidden_states=encoder_hidden_states_tmp[_uc_mask],
393
+ # attention_mask=attention_mask,
394
+ # )
395
+ # attn_output = hidden_states_c.clone()
396
+ hidden_states = attn_output + hidden_states
397
+ # 4. Feed-forward
398
+ if self.norm3 is not None and self.ff is not None:
399
+ norm_hidden_states = self.norm3(hidden_states)
400
+ if self.use_ada_layer_norm_zero:
401
+ norm_hidden_states = (
402
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
403
+ )
404
+ if self._chunk_size is not None:
405
+ # "feed_forward_chunk_size" can be used to save memory
406
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
407
+ raise ValueError(
408
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
409
+ )
410
+
411
+ num_chunks = (
412
+ norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
413
+ )
414
+ ff_output = torch.cat(
415
+ [
416
+ self.ff(hid_slice, scale=lora_scale)
417
+ for hid_slice in norm_hidden_states.chunk(
418
+ num_chunks, dim=self._chunk_dim
419
+ )
420
+ ],
421
+ dim=self._chunk_dim,
422
+ )
423
+ else:
424
+ ff_output = self.ff(norm_hidden_states, scale=lora_scale)
425
+
426
+ if self.use_ada_layer_norm_zero:
427
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
428
+
429
+ hidden_states = ff_output + hidden_states
430
+ self.print_idx += 1
431
+ return hidden_states
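The self_attn_block_embs write/read mechanism above is the bridge between the referencenet pass and the denoising UNet: in "write" mode each BasicTransformerBlock stores its normalized hidden states into a shared list indexed by spatial_self_attn_idx, and in "read" mode the same slot is handed to the attention processor as refer_emb through cross_attention_kwargs. A standalone sketch of that producer/consumer contract (this mirrors the pattern only; it is not the MuseV API):

# Standalone sketch of the write/read contract used by self_attn_block_embs; not the actual MuseV API.
from typing import Dict, List, Optional
import torch

def run_block(
    hidden_states: torch.Tensor,
    block_idx: int,
    embs: Optional[List[Optional[torch.Tensor]]],
    mode: str = "write",
) -> Dict[str, torch.Tensor]:
    cross_attention_kwargs: Dict[str, torch.Tensor] = {}
    if embs is not None and mode == "write":
        embs[block_idx] = hidden_states  # referencenet pass stores its features
    if embs is not None and mode == "read":
        cross_attention_kwargs["refer_emb"] = embs[block_idx]  # denoising pass consumes them
    return cross_attention_kwargs

embs: List[Optional[torch.Tensor]] = [None] * 2
run_block(torch.randn(2, 16, 8), block_idx=0, embs=embs, mode="write")
kwargs = run_block(torch.randn(2, 16, 8), block_idx=0, embs=embs, mode="read")
print(kwargs["refer_emb"].shape)  # torch.Size([2, 16, 8])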
musev/models/attention_processor.py ADDED
@@ -0,0 +1,750 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """该模型是自定义的attn_processor,实现特殊功能的 Attn功能。
16
+ 相对而言,开源代码经常会重新定义Attention 类,
17
+
18
+ This module implements special AttnProcessor function with custom attn_processor class.
19
+ While other open source code always modify Attention class.
20
+ """
21
+ # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
22
+ from __future__ import annotations
23
+
24
+ import time
25
+ from typing import Any, Callable, Optional
26
+ import logging
27
+
28
+ from einops import rearrange, repeat
29
+ import torch
30
+ import torch.nn as nn
31
+ import torch.nn.functional as F
32
+ import xformers
33
+ from diffusers.models.lora import LoRACompatibleLinear
34
+
35
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
36
+ from diffusers.models.attention_processor import (
37
+ Attention as DiffusersAttention,
38
+ AttnProcessor,
39
+ AttnProcessor2_0,
40
+ )
41
+ from ..data.data_util import (
42
+ batch_concat_two_tensor_with_index,
43
+ batch_index_select,
44
+ align_repeat_tensor_single_dim,
45
+ batch_adain_conditioned_tensor,
46
+ )
47
+
48
+ from . import Model_Register
49
+
50
+ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
51
+
52
+
53
+ @maybe_allow_in_graph
54
+ class IPAttention(DiffusersAttention):
55
+ r"""
56
+ Modified Attention class which has special layer, like ip_apadapter_to_k, ip_apadapter_to_v,
57
+ """
58
+
59
+ def __init__(
60
+ self,
61
+ query_dim: int,
62
+ cross_attention_dim: int | None = None,
63
+ heads: int = 8,
64
+ dim_head: int = 64,
65
+ dropout: float = 0,
66
+ bias=False,
67
+ upcast_attention: bool = False,
68
+ upcast_softmax: bool = False,
69
+ cross_attention_norm: str | None = None,
70
+ cross_attention_norm_num_groups: int = 32,
71
+ added_kv_proj_dim: int | None = None,
72
+ norm_num_groups: int | None = None,
73
+ spatial_norm_dim: int | None = None,
74
+ out_bias: bool = True,
75
+ scale_qk: bool = True,
76
+ only_cross_attention: bool = False,
77
+ eps: float = 0.00001,
78
+ rescale_output_factor: float = 1,
79
+ residual_connection: bool = False,
80
+ _from_deprecated_attn_block=False,
81
+ processor: AttnProcessor | None = None,
82
+ cross_attn_temporal_cond: bool = False,
83
+ image_scale: float = 1.0,
84
+ ip_adapter_dim: int = None,
85
+ need_t2i_facein: bool = False,
86
+ facein_dim: int = None,
87
+ need_t2i_ip_adapter_face: bool = False,
88
+ ip_adapter_face_dim: int = None,
89
+ ):
90
+ super().__init__(
91
+ query_dim,
92
+ cross_attention_dim,
93
+ heads,
94
+ dim_head,
95
+ dropout,
96
+ bias,
97
+ upcast_attention,
98
+ upcast_softmax,
99
+ cross_attention_norm,
100
+ cross_attention_norm_num_groups,
101
+ added_kv_proj_dim,
102
+ norm_num_groups,
103
+ spatial_norm_dim,
104
+ out_bias,
105
+ scale_qk,
106
+ only_cross_attention,
107
+ eps,
108
+ rescale_output_factor,
109
+ residual_connection,
110
+ _from_deprecated_attn_block,
111
+ processor,
112
+ )
113
+ self.cross_attn_temporal_cond = cross_attn_temporal_cond
114
+ self.image_scale = image_scale
115
+ # 面向首帧的 ip_adapter
116
+ # ip_apdater
117
+ if cross_attn_temporal_cond:
118
+ self.to_k_ip = LoRACompatibleLinear(ip_adapter_dim, query_dim, bias=False)
119
+ self.to_v_ip = LoRACompatibleLinear(ip_adapter_dim, query_dim, bias=False)
120
+ # facein
121
+ self.need_t2i_facein = need_t2i_facein
122
+ self.facein_dim = facein_dim
123
+ if need_t2i_facein:
124
+ raise NotImplementedError("facein")
125
+
126
+ # ip_adapter_face
127
+ self.need_t2i_ip_adapter_face = need_t2i_ip_adapter_face
128
+ self.ip_adapter_face_dim = ip_adapter_face_dim
129
+ if need_t2i_ip_adapter_face:
130
+ self.ip_adapter_face_to_k_ip = LoRACompatibleLinear(
131
+ ip_adapter_face_dim, query_dim, bias=False
132
+ )
133
+ self.ip_adapter_face_to_v_ip = LoRACompatibleLinear(
134
+ ip_adapter_face_dim, query_dim, bias=False
135
+ )
136
+
137
+ def set_use_memory_efficient_attention_xformers(
138
+ self,
139
+ use_memory_efficient_attention_xformers: bool,
140
+ attention_op: Callable[..., Any] | None = None,
141
+ ):
142
+ if (
143
+ "XFormers" in self.processor.__class__.__name__
144
+ or "IP" in self.processor.__class__.__name__
145
+ ):
146
+ pass
147
+ else:
148
+ return super().set_use_memory_efficient_attention_xformers(
149
+ use_memory_efficient_attention_xformers, attention_op
150
+ )
151
+
152
+
153
+ @Model_Register.register
154
+ class BaseIPAttnProcessor(nn.Module):
155
+ print_idx = 0
156
+
157
+ def __init__(self, *args, **kwargs) -> None:
158
+ super().__init__(*args, **kwargs)
159
+
160
+
161
+ @Model_Register.register
162
+ class T2IReferencenetIPAdapterXFormersAttnProcessor(BaseIPAttnProcessor):
163
+ r"""
164
+ 面向 ref_image的 self_attn的 IPAdapter
165
+ """
166
+ print_idx = 0
167
+
168
+ def __init__(
169
+ self,
170
+ attention_op: Optional[Callable] = None,
171
+ ):
172
+ super().__init__()
173
+
174
+ self.attention_op = attention_op
175
+
176
+ def __call__(
177
+ self,
178
+ attn: IPAttention,
179
+ hidden_states: torch.FloatTensor,
180
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
181
+ attention_mask: Optional[torch.FloatTensor] = None,
182
+ temb: Optional[torch.FloatTensor] = None,
183
+ scale: float = 1.0,
184
+ num_frames: int = None,
185
+ sample_index: torch.LongTensor = None,
186
+ vision_conditon_frames_sample_index: torch.LongTensor = None,
187
+ refer_emb: torch.Tensor = None,
188
+ vision_clip_emb: torch.Tensor = None,
189
+ ip_adapter_scale: float = 1.0,
190
+ face_emb: torch.Tensor = None,
191
+ facein_scale: float = 1.0,
192
+ ip_adapter_face_emb: torch.Tensor = None,
193
+ ip_adapter_face_scale: float = 1.0,
194
+ do_classifier_free_guidance: bool = False,
195
+ ):
196
+ residual = hidden_states
197
+
198
+ if attn.spatial_norm is not None:
199
+ hidden_states = attn.spatial_norm(hidden_states, temb)
200
+
201
+ input_ndim = hidden_states.ndim
202
+
203
+ if input_ndim == 4:
204
+ batch_size, channel, height, width = hidden_states.shape
205
+ hidden_states = hidden_states.view(
206
+ batch_size, channel, height * width
207
+ ).transpose(1, 2)
208
+
209
+ batch_size, key_tokens, _ = (
210
+ hidden_states.shape
211
+ if encoder_hidden_states is None
212
+ else encoder_hidden_states.shape
213
+ )
214
+
215
+ attention_mask = attn.prepare_attention_mask(
216
+ attention_mask, key_tokens, batch_size
217
+ )
218
+ if attention_mask is not None:
219
+ # expand our mask's singleton query_tokens dimension:
220
+ # [batch*heads, 1, key_tokens] ->
221
+ # [batch*heads, query_tokens, key_tokens]
222
+ # so that it can be added as a bias onto the attention scores that xformers computes:
223
+ # [batch*heads, query_tokens, key_tokens]
224
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
225
+ _, query_tokens, _ = hidden_states.shape
226
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
227
+
228
+ if attn.group_norm is not None:
229
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
230
+ 1, 2
231
+ )
232
+
233
+ query = attn.to_q(hidden_states, scale=scale)
234
+
235
+ if encoder_hidden_states is None:
236
+ encoder_hidden_states = hidden_states
237
+ elif attn.norm_cross:
238
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
239
+ encoder_hidden_states
240
+ )
241
+ encoder_hidden_states = align_repeat_tensor_single_dim(
242
+ encoder_hidden_states, target_length=hidden_states.shape[0], dim=0
243
+ )
244
+ key = attn.to_k(encoder_hidden_states, scale=scale)
245
+ value = attn.to_v(encoder_hidden_states, scale=scale)
246
+
247
+ # for facein
248
+ if self.print_idx == 0:
249
+ logger.debug(
250
+ f"T2IReferencenetIPAdapterXFormersAttnProcessor,type(face_emb)={type(face_emb)}, facein_scale={facein_scale}"
251
+ )
252
+ if facein_scale > 0 and face_emb is not None:
253
+ raise NotImplementedError("facein")
254
+
255
+ query = attn.head_to_batch_dim(query).contiguous()
256
+ key = attn.head_to_batch_dim(key).contiguous()
257
+ value = attn.head_to_batch_dim(value).contiguous()
258
+ hidden_states = xformers.ops.memory_efficient_attention(
259
+ query,
260
+ key,
261
+ value,
262
+ attn_bias=attention_mask,
263
+ op=self.attention_op,
264
+ scale=attn.scale,
265
+ )
266
+
267
+ # ip-adapter start
268
+ if self.print_idx == 0:
269
+ logger.debug(
270
+ f"T2IReferencenetIPAdapterXFormersAttnProcessor,type(vision_clip_emb)={type(vision_clip_emb)}"
271
+ )
272
+ if ip_adapter_scale > 0 and vision_clip_emb is not None:
273
+ if self.print_idx == 0:
274
+ logger.debug(
275
+ f"T2I cross_attn, ipadapter, vision_clip_emb={vision_clip_emb.shape}, hidden_states={hidden_states.shape}, batch_size={batch_size}"
276
+ )
277
+ ip_key = attn.to_k_ip(vision_clip_emb)
278
+ ip_value = attn.to_v_ip(vision_clip_emb)
279
+ ip_key = align_repeat_tensor_single_dim(
280
+ ip_key, target_length=batch_size, dim=0
281
+ )
282
+ ip_value = align_repeat_tensor_single_dim(
283
+ ip_value, target_length=batch_size, dim=0
284
+ )
285
+ ip_key = attn.head_to_batch_dim(ip_key).contiguous()
286
+ ip_value = attn.head_to_batch_dim(ip_value).contiguous()
287
+ if self.print_idx == 0:
288
+ logger.debug(
289
+ f"query={query.shape}, ip_key={ip_key.shape}, ip_value={ip_value.shape}"
290
+ )
291
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
292
+ hidden_states_from_ip = xformers.ops.memory_efficient_attention(
293
+ query,
294
+ ip_key,
295
+ ip_value,
296
+ attn_bias=attention_mask,
297
+ op=self.attention_op,
298
+ scale=attn.scale,
299
+ )
300
+ hidden_states = hidden_states + ip_adapter_scale * hidden_states_from_ip
301
+ # ip-adapter end
302
+
303
+ # ip-adapter face start
304
+ if self.print_idx == 0:
305
+ logger.debug(
306
+ f"T2IReferencenetIPAdapterXFormersAttnProcessor,type(ip_adapter_face_emb)={type(ip_adapter_face_emb)}"
307
+ )
308
+ if ip_adapter_face_scale > 0 and ip_adapter_face_emb is not None:
309
+ if self.print_idx == 0:
310
+ logger.debug(
311
+ f"T2I cross_attn, ipadapter face, ip_adapter_face_emb={vision_clip_emb.shape}, hidden_states={hidden_states.shape}, batch_size={batch_size}"
312
+ )
313
+ ip_key = attn.ip_adapter_face_to_k_ip(ip_adapter_face_emb)
314
+ ip_value = attn.ip_adapter_face_to_v_ip(ip_adapter_face_emb)
315
+ ip_key = align_repeat_tensor_single_dim(
316
+ ip_key, target_length=batch_size, dim=0
317
+ )
318
+ ip_value = align_repeat_tensor_single_dim(
319
+ ip_value, target_length=batch_size, dim=0
320
+ )
321
+ ip_key = attn.head_to_batch_dim(ip_key).contiguous()
322
+ ip_value = attn.head_to_batch_dim(ip_value).contiguous()
323
+ if self.print_idx == 0:
324
+ logger.debug(
325
+ f"query={query.shape}, ip_key={ip_key.shape}, ip_value={ip_value.shape}"
326
+ )
327
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
328
+ hidden_states_from_ip = xformers.ops.memory_efficient_attention(
329
+ query,
330
+ ip_key,
331
+ ip_value,
332
+ attn_bias=attention_mask,
333
+ op=self.attention_op,
334
+ scale=attn.scale,
335
+ )
336
+ hidden_states = (
337
+ hidden_states + ip_adapter_face_scale * hidden_states_from_ip
338
+ )
339
+ # ip-adapter face end
340
+
341
+ hidden_states = hidden_states.to(query.dtype)
342
+ hidden_states = attn.batch_to_head_dim(hidden_states)
343
+
344
+ # linear proj
345
+ hidden_states = attn.to_out[0](hidden_states, scale=scale)
346
+ # dropout
347
+ hidden_states = attn.to_out[1](hidden_states)
348
+
349
+ if input_ndim == 4:
350
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
351
+ batch_size, channel, height, width
352
+ )
353
+
354
+ if attn.residual_connection:
355
+ hidden_states = hidden_states + residual
356
+
357
+ hidden_states = hidden_states / attn.rescale_output_factor
358
+ self.print_idx += 1
359
+ return hidden_states
360
+
361
+
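As an illustrative aside (not part of the commit), the sketch below reproduces the core of the ip-adapter branch that the processor above implements: the same query attends once over text tokens and once over projected image tokens (vision_clip_emb), and the two results are summed with ip_adapter_scale. The sizes, the 768/1024 embedding dims, and the plain-PyTorch attention are assumptions for the demo; the real processor uses xformers and the attn module's own projection layers.

import torch
import torch.nn.functional as F

# hypothetical sizes for the demo
dim, heads = 320, 8
b, q_tokens, text_tokens, ip_tokens = 2, 64, 77, 4
ip_adapter_scale = 1.0

to_q = torch.nn.Linear(dim, dim, bias=False)
to_k = torch.nn.Linear(768, dim, bias=False)       # text-encoder dim -> attention dim (assumed)
to_v = torch.nn.Linear(768, dim, bias=False)
to_k_ip = torch.nn.Linear(1024, dim, bias=False)   # image-emb dim -> attention dim (assumed)
to_v_ip = torch.nn.Linear(1024, dim, bias=False)

hidden_states = torch.randn(b, q_tokens, dim)
text_emb = torch.randn(b, text_tokens, 768)
vision_clip_emb = torch.randn(b, ip_tokens, 1024)

def mha(q, k, v):
    # split heads, run scaled dot-product attention, merge heads back
    q, k, v = (x.reshape(b, -1, heads, dim // heads).transpose(1, 2) for x in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v)
    return out.transpose(1, 2).reshape(b, -1, dim)

query = to_q(hidden_states)
text_out = mha(query, to_k(text_emb), to_v(text_emb))                     # regular cross-attention
ip_out = mha(query, to_k_ip(vision_clip_emb), to_v_ip(vision_clip_emb))   # ip-adapter branch
out = text_out + ip_adapter_scale * ip_out                                # same fusion as above
print(out.shape)  # torch.Size([2, 64, 320])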
362
+ @Model_Register.register
363
+ class NonParamT2ISelfReferenceXFormersAttnProcessor(BaseIPAttnProcessor):
364
+ r"""
365
+ Reference-only attention over the first (condition) frames, applied in the T2I self_attn.
366
+ referenceonly with vis_cond as key, value, in t2i self_attn.
367
+ """
368
+ print_idx = 0
369
+
370
+ def __init__(
371
+ self,
372
+ attention_op: Optional[Callable] = None,
373
+ ):
374
+ super().__init__()
375
+
376
+ self.attention_op = attention_op
377
+
378
+ def __call__(
379
+ self,
380
+ attn: IPAttention,
381
+ hidden_states: torch.FloatTensor,
382
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
383
+ attention_mask: Optional[torch.FloatTensor] = None,
384
+ temb: Optional[torch.FloatTensor] = None,
385
+ scale: float = 1.0,
386
+ num_frames: int = None,
387
+ sample_index: torch.LongTensor = None,
388
+ vision_conditon_frames_sample_index: torch.LongTensor = None,
389
+ refer_emb: torch.Tensor = None,
390
+ face_emb: torch.Tensor = None,
391
+ vision_clip_emb: torch.Tensor = None,
392
+ ip_adapter_scale: float = 1.0,
393
+ facein_scale: float = 1.0,
394
+ ip_adapter_face_emb: torch.Tensor = None,
395
+ ip_adapter_face_scale: float = 1.0,
396
+ do_classifier_free_guidance: bool = False,
397
+ ):
398
+ residual = hidden_states
399
+
400
+ if attn.spatial_norm is not None:
401
+ hidden_states = attn.spatial_norm(hidden_states, temb)
402
+
403
+ input_ndim = hidden_states.ndim
404
+
405
+ if input_ndim == 4:
406
+ batch_size, channel, height, width = hidden_states.shape
407
+ hidden_states = hidden_states.view(
408
+ batch_size, channel, height * width
409
+ ).transpose(1, 2)
410
+
411
+ batch_size, key_tokens, _ = (
412
+ hidden_states.shape
413
+ if encoder_hidden_states is None
414
+ else encoder_hidden_states.shape
415
+ )
416
+
417
+ attention_mask = attn.prepare_attention_mask(
418
+ attention_mask, key_tokens, batch_size
419
+ )
420
+ if attention_mask is not None:
421
+ # expand our mask's singleton query_tokens dimension:
422
+ # [batch*heads, 1, key_tokens] ->
423
+ # [batch*heads, query_tokens, key_tokens]
424
+ # so that it can be added as a bias onto the attention scores that xformers computes:
425
+ # [batch*heads, query_tokens, key_tokens]
426
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
427
+ _, query_tokens, _ = hidden_states.shape
428
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
429
+
430
+ # vision_cond in same unet attn start
431
+ if (
432
+ vision_conditon_frames_sample_index is not None and num_frames > 1
433
+ ) or refer_emb is not None:
434
+ batchsize_timesize = hidden_states.shape[0]
435
+ if self.print_idx == 0:
436
+ logger.debug(
437
+ f"NonParamT2ISelfReferenceXFormersAttnProcessor 0, hidden_states={hidden_states.shape}, vision_conditon_frames_sample_index={vision_conditon_frames_sample_index}"
438
+ )
439
+ encoder_hidden_states = rearrange(
440
+ hidden_states, "(b t) hw c -> b t hw c", t=num_frames
441
+ )
442
+ # if False:
443
+ if vision_conditon_frames_sample_index is not None and num_frames > 1:
444
+ ip_hidden_states = batch_index_select(
445
+ encoder_hidden_states,
446
+ dim=1,
447
+ index=vision_conditon_frames_sample_index,
448
+ ).contiguous()
449
+ if self.print_idx == 0:
450
+ logger.debug(
451
+ f"NonParamT2ISelfReferenceXFormersAttnProcessor 1, vis_cond referenceonly, encoder_hidden_states={encoder_hidden_states.shape}, ip_hidden_states={ip_hidden_states.shape}"
452
+ )
453
+ #
454
+ ip_hidden_states = rearrange(
455
+ ip_hidden_states, "b t hw c -> b 1 (t hw) c"
456
+ )
457
+ ip_hidden_states = align_repeat_tensor_single_dim(
458
+ ip_hidden_states,
459
+ dim=1,
460
+ target_length=num_frames,
461
+ )
462
+ # b t hw c -> b t hw + hw c
463
+ if self.print_idx == 0:
464
+ logger.debug(
465
+ f"NonParamT2ISelfReferenceXFormersAttnProcessor 2, vis_cond referenceonly, encoder_hidden_states={encoder_hidden_states.shape}, ip_hidden_states={ip_hidden_states.shape}"
466
+ )
467
+ encoder_hidden_states = torch.concat(
468
+ [encoder_hidden_states, ip_hidden_states], dim=2
469
+ )
470
+ if self.print_idx == 0:
471
+ logger.debug(
472
+ f"NonParamT2ISelfReferenceXFormersAttnProcessor 3, hidden_states={hidden_states.shape}, ip_hidden_states={ip_hidden_states.shape}"
473
+ )
474
+ # if False:
475
+ if refer_emb is not None: # and num_frames > 1:
476
+ refer_emb = rearrange(refer_emb, "b c t h w->b 1 (t h w) c")
477
+ refer_emb = align_repeat_tensor_single_dim(
478
+ refer_emb, target_length=num_frames, dim=1
479
+ )
480
+ if self.print_idx == 0:
481
+ logger.debug(
482
+ f"NonParamT2ISelfReferenceXFormersAttnProcessor4, referencenet, encoder_hidden_states={encoder_hidden_states.shape}, refer_emb={refer_emb.shape}"
483
+ )
484
+ encoder_hidden_states = torch.concat(
485
+ [encoder_hidden_states, refer_emb], dim=2
486
+ )
487
+ if self.print_idx == 0:
488
+ logger.debug(
489
+ f"NonParamT2ISelfReferenceXFormersAttnProcessor5, referencenet, encoder_hidden_states={encoder_hidden_states.shape}, refer_emb={refer_emb.shape}"
490
+ )
491
+ encoder_hidden_states = rearrange(
492
+ encoder_hidden_states, "b t hw c -> (b t) hw c"
493
+ )
494
+ # vision_cond in same unet attn end
495
+
496
+ if attn.group_norm is not None:
497
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
498
+ 1, 2
499
+ )
500
+
501
+ query = attn.to_q(hidden_states, scale=scale)
502
+
503
+ if encoder_hidden_states is None:
504
+ encoder_hidden_states = hidden_states
505
+ elif attn.norm_cross:
506
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
507
+ encoder_hidden_states
508
+ )
509
+ encoder_hidden_states = align_repeat_tensor_single_dim(
510
+ encoder_hidden_states, target_length=hidden_states.shape[0], dim=0
511
+ )
512
+ key = attn.to_k(encoder_hidden_states, scale=scale)
513
+ value = attn.to_v(encoder_hidden_states, scale=scale)
514
+
515
+ query = attn.head_to_batch_dim(query).contiguous()
516
+ key = attn.head_to_batch_dim(key).contiguous()
517
+ value = attn.head_to_batch_dim(value).contiguous()
518
+
519
+ hidden_states = xformers.ops.memory_efficient_attention(
520
+ query,
521
+ key,
522
+ value,
523
+ attn_bias=attention_mask,
524
+ op=self.attention_op,
525
+ scale=attn.scale,
526
+ )
527
+ hidden_states = hidden_states.to(query.dtype)
528
+ hidden_states = attn.batch_to_head_dim(hidden_states)
529
+
530
+ # linear proj
531
+ hidden_states = attn.to_out[0](hidden_states, scale=scale)
532
+ # dropout
533
+ hidden_states = attn.to_out[1](hidden_states)
534
+
535
+ if input_ndim == 4:
536
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
537
+ batch_size, channel, height, width
538
+ )
539
+
540
+ if attn.residual_connection:
541
+ hidden_states = hidden_states + residual
542
+
543
+ hidden_states = hidden_states / attn.rescale_output_factor
544
+ self.print_idx += 1
545
+
546
+ return hidden_states
547
+
548
+
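To make the parameter-free reference trick above concrete, here is a shape-only sketch of how the key/value sequence is extended: the hidden states of the condition frame(s) are gathered, flattened over (t hw), broadcast to every generated frame, and concatenated onto that frame's own tokens before self-attention. Frame counts and channel sizes below are invented for illustration; this is not repository code.

import torch
from einops import rearrange

b, t, hw, c = 2, 8, 64, 320                               # batch, frames, tokens per frame, channels
vision_conditon_frames_sample_index = torch.tensor([0])   # use frame 0 as the visual condition

hidden_states = torch.randn(b * t, hw, c)                 # (b t) hw c, the layout inside the unet
x = rearrange(hidden_states, "(b t) hw c -> b t hw c", t=t)

ref = x.index_select(1, vision_conditon_frames_sample_index)    # b t_ref hw c
ref = rearrange(ref, "b t_ref hw c -> b 1 (t_ref hw) c")
ref = ref.expand(-1, t, -1, -1)                                  # broadcast to every frame

kv = torch.cat([x, ref], dim=2)                    # b t (hw + t_ref*hw) c
kv = rearrange(kv, "b t n c -> (b t) n c")         # back to the layout self-attention expects
print(kv.shape)  # torch.Size([16, 128, 320])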
549
+ @Model_Register.register
550
+ class NonParamReferenceIPXFormersAttnProcessor(
551
+ NonParamT2ISelfReferenceXFormersAttnProcessor
552
+ ):
553
+ def __init__(self, attention_op: Callable[..., Any] | None = None):
554
+ super().__init__(attention_op)
555
+
556
+
557
+ @maybe_allow_in_graph
558
+ class ReferEmbFuseAttention(IPAttention):
559
+ """Use attention to fuse the emb from referencenet into the corresponding unet latents.
560
+ # TODO: currently only supports the bt hw c layout; support for the video layouts bhw t c and b thw c may be added later.
561
+ residual_connection: bool = True by default, so the module starts with no effect and learns from there.
562
+
563
+ use attention to fuse referencenet emb into unet latents
564
+ # TODO: by now, only support bt hw c, later consider to support bhw t c, b thw c
565
+ residual_connection: bool = True, default, start from no effect
566
+
567
+ Args:
568
+ IPAttention (_type_): _description_
569
+ """
570
+
571
+ print_idx = 0
572
+
573
+ def __init__(
574
+ self,
575
+ query_dim: int,
576
+ cross_attention_dim: int | None = None,
577
+ heads: int = 8,
578
+ dim_head: int = 64,
579
+ dropout: float = 0,
580
+ bias=False,
581
+ upcast_attention: bool = False,
582
+ upcast_softmax: bool = False,
583
+ cross_attention_norm: str | None = None,
584
+ cross_attention_norm_num_groups: int = 32,
585
+ added_kv_proj_dim: int | None = None,
586
+ norm_num_groups: int | None = None,
587
+ spatial_norm_dim: int | None = None,
588
+ out_bias: bool = True,
589
+ scale_qk: bool = True,
590
+ only_cross_attention: bool = False,
591
+ eps: float = 0.00001,
592
+ rescale_output_factor: float = 1,
593
+ residual_connection: bool = True,
594
+ _from_deprecated_attn_block=False,
595
+ processor: AttnProcessor | None = None,
596
+ cross_attn_temporal_cond: bool = False,
597
+ image_scale: float = 1,
598
+ ):
599
+ super().__init__(
600
+ query_dim,
601
+ cross_attention_dim,
602
+ heads,
603
+ dim_head,
604
+ dropout,
605
+ bias,
606
+ upcast_attention,
607
+ upcast_softmax,
608
+ cross_attention_norm,
609
+ cross_attention_norm_num_groups,
610
+ added_kv_proj_dim,
611
+ norm_num_groups,
612
+ spatial_norm_dim,
613
+ out_bias,
614
+ scale_qk,
615
+ only_cross_attention,
616
+ eps,
617
+ rescale_output_factor,
618
+ residual_connection,
619
+ _from_deprecated_attn_block,
620
+ processor,
621
+ cross_attn_temporal_cond,
622
+ image_scale,
623
+ )
624
+ self.processor = None
625
+ # together with the residual connection, ensures the module initially has no effect on the previous result
626
+ nn.init.zeros_(self.to_out[0].weight)
627
+ nn.init.zeros_(self.to_out[0].bias)
628
+
629
+ def forward(
630
+ self,
631
+ hidden_states: torch.FloatTensor,
632
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
633
+ attention_mask: Optional[torch.FloatTensor] = None,
634
+ temb: Optional[torch.FloatTensor] = None,
635
+ scale: float = 1.0,
636
+ num_frames: int = None,
637
+ ) -> torch.Tensor:
638
+ """fuse referencenet emb b c t2 h2 w2 into unet latents b c t1 h1 w1 with attn
639
+ refer to musev/models/attention_processor.py::NonParamT2ISelfReferenceXFormersAttnProcessor
640
+
641
+ Args:
642
+ hidden_states (torch.FloatTensor): unet latents, (b t1) c h1 w1
643
+ encoder_hidden_states (Optional[torch.FloatTensor], optional): referencenet emb b c2 t2 h2 w2. Defaults to None.
644
+ attention_mask (Optional[torch.FloatTensor], optional): _description_. Defaults to None.
645
+ temb (Optional[torch.FloatTensor], optional): _description_. Defaults to None.
646
+ scale (float, optional): _description_. Defaults to 1.0.
647
+ num_frames (int, optional): _description_. Defaults to None.
648
+
649
+ Returns:
650
+ torch.Tensor: _description_
651
+ """
652
+ residual = hidden_states
653
+ # start
654
+ hidden_states = rearrange(
655
+ hidden_states, "(b t) c h w -> b c t h w", t=num_frames
656
+ )
657
+ batch_size, channel, t1, height, width = hidden_states.shape
658
+ if self.print_idx == 0:
659
+ logger.debug(
660
+ f"hidden_states={hidden_states.shape},encoder_hidden_states={encoder_hidden_states.shape}"
661
+ )
662
+ # concat with hidden_states b c t1 h1 w1 in hw channel into bt (t2 + 1)hw c
663
+ encoder_hidden_states = rearrange(
664
+ encoder_hidden_states, " b c t2 h w-> b (t2 h w) c"
665
+ )
666
+ encoder_hidden_states = repeat(
667
+ encoder_hidden_states, " b t2hw c -> (b t) t2hw c", t=t1
668
+ )
669
+ hidden_states = rearrange(hidden_states, " b c t h w-> (b t) (h w) c")
670
+ # bt (t2+1)hw d
671
+ encoder_hidden_states = torch.concat(
672
+ [encoder_hidden_states, hidden_states], dim=1
673
+ )
674
+ # encoder_hidden_states = align_repeat_tensor_single_dim(
675
+ # encoder_hidden_states, target_length=hidden_states.shape[0], dim=0
676
+ # )
677
+ # end
678
+
679
+ if self.spatial_norm is not None:
680
+ hidden_states = self.spatial_norm(hidden_states, temb)
681
+
682
+ _, key_tokens, _ = (
683
+ hidden_states.shape
684
+ if encoder_hidden_states is None
685
+ else encoder_hidden_states.shape
686
+ )
687
+
688
+ attention_mask = self.prepare_attention_mask(
689
+ attention_mask, key_tokens, batch_size
690
+ )
691
+ if attention_mask is not None:
692
+ # expand our mask's singleton query_tokens dimension:
693
+ # [batch*heads, 1, key_tokens] ->
694
+ # [batch*heads, query_tokens, key_tokens]
695
+ # so that it can be added as a bias onto the attention scores that xformers computes:
696
+ # [batch*heads, query_tokens, key_tokens]
697
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
698
+ _, query_tokens, _ = hidden_states.shape
699
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
700
+
701
+ if self.group_norm is not None:
702
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(
703
+ 1, 2
704
+ )
705
+
706
+ query = self.to_q(hidden_states, scale=scale)
707
+
708
+ if encoder_hidden_states is None:
709
+ encoder_hidden_states = hidden_states
710
+ elif self.norm_cross:
711
+ encoder_hidden_states = self.norm_encoder_hidden_states(
712
+ encoder_hidden_states
713
+ )
714
+
715
+ key = self.to_k(encoder_hidden_states, scale=scale)
716
+ value = self.to_v(encoder_hidden_states, scale=scale)
717
+
718
+ query = self.head_to_batch_dim(query).contiguous()
719
+ key = self.head_to_batch_dim(key).contiguous()
720
+ value = self.head_to_batch_dim(value).contiguous()
721
+
722
+ # query: b t hw d
723
+ # key/value: bt (t1+1)hw d
724
+ hidden_states = xformers.ops.memory_efficient_attention(
725
+ query,
726
+ key,
727
+ value,
728
+ attn_bias=attention_mask,
729
+ scale=self.scale,
730
+ )
731
+ hidden_states = hidden_states.to(query.dtype)
732
+ hidden_states = self.batch_to_head_dim(hidden_states)
733
+
734
+ # linear proj
735
+ hidden_states = self.to_out[0](hidden_states, scale=scale)
736
+ # dropout
737
+ hidden_states = self.to_out[1](hidden_states)
738
+
739
+ hidden_states = rearrange(
740
+ hidden_states,
741
+ "bt (h w) c-> bt c h w",
742
+ h=height,
743
+ w=width,
744
+ )
745
+ if self.residual_connection:
746
+ hidden_states = hidden_states + residual
747
+
748
+ hidden_states = hidden_states / self.rescale_output_factor
749
+ self.print_idx += 1
750
+ return hidden_states
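For orientation, the fusion performed by ReferEmbFuseAttention.forward above boils down to the tensor plumbing sketched here: the referencenet emb (b c t2 h2 w2) is flattened into one token sequence, repeated for each of the t1 unet frames, and prepended to that frame's own (h1 w1) tokens so a single attention call can mix them. The sizes are made up and the snippet is illustrative only, not part of the commit.

import torch
from einops import rearrange, repeat

b, c, t1, h1, w1 = 2, 320, 4, 8, 8      # unet latents:      (b t1) c h1 w1
t2, h2, w2 = 1, 8, 8                    # referencenet emb:   b  c t2 h2 w2

latents = torch.randn(b * t1, c, h1, w1)
refer_emb = torch.randn(b, c, t2, h2, w2)

x = rearrange(latents, "(b t) c h w -> (b t) (h w) c", t=t1)    # per-frame tokens (query side)
ref = rearrange(refer_emb, "b c t2 h w -> b (t2 h w) c")
ref = repeat(ref, "b n c -> (b t) n c", t=t1)                   # one copy per generated frame

encoder_hidden_states = torch.cat([ref, x], dim=1)   # (b t1) (t2*h2*w2 + h1*w1) c, key/value side
print(x.shape, encoder_hidden_states.shape)
# torch.Size([8, 64, 320]) torch.Size([8, 128, 320])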
musev/models/controlnet.py ADDED
@@ -0,0 +1,399 @@
1
+ from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
2
+ import warnings
3
+ import os
4
+
5
+
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from diffusers.models.modeling_utils import ModelMixin
9
+ import PIL
10
+ from einops import rearrange, repeat
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn.init as init
14
+ from diffusers.models.controlnet import ControlNetModel
15
+ from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
16
+ from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers
17
+ from diffusers.utils.torch_utils import is_compiled_module
18
+
19
+
20
+ class ControlnetPredictor(object):
21
+ def __init__(self, controlnet_model_path: str, *args, **kwargs):
22
+ """Controlnet inference helper used to extract the controlnet backbone emb so it does not have to be re-extracted repeatedly during training.
23
+ Controlnet inference predictor, used to extract the emb of the controlnet backbone to avoid repeated extraction during training
24
+ Args:
25
+ controlnet_model_path (str): controlnet model path.
26
+ """
27
+ super(ControlnetPredictor, self).__init__(*args, **kwargs)
28
+ self.controlnet = ControlNetModel.from_pretrained(
29
+ controlnet_model_path,
30
+ )
31
+
32
+ def prepare_image(
33
+ self,
34
+ image, # b c t h w
35
+ width,
36
+ height,
37
+ batch_size,
38
+ num_images_per_prompt,
39
+ device,
40
+ dtype,
41
+ do_classifier_free_guidance=False,
42
+ guess_mode=False,
43
+ ):
44
+ if height is None:
45
+ height = image.shape[-2]
46
+ if width is None:
47
+ width = image.shape[-1]
48
+ width, height = (
49
+ x - x % self.control_image_processor.vae_scale_factor
50
+ for x in (width, height)
51
+ )
52
+ image = rearrange(image, "b c t h w-> (b t) c h w")
53
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 255.0
54
+
55
+ image = (
56
+ torch.nn.functional.interpolate(
57
+ image,
58
+ size=(height, width),
59
+ mode="bilinear",
60
+ )
61
+ )
62
+
63
+ do_normalize = self.control_image_processor.config.do_normalize
64
+ if image.min() < 0:
65
+ warnings.warn(
66
+ "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
67
+ f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
68
+ FutureWarning,
69
+ )
70
+ do_normalize = False
71
+
72
+ if do_normalize:
73
+ image = self.control_image_processor.normalize(image)
74
+
75
+ image_batch_size = image.shape[0]
76
+
77
+ if image_batch_size == 1:
78
+ repeat_by = batch_size
79
+ else:
80
+ # image batch size is the same as prompt batch size
81
+ repeat_by = num_images_per_prompt
82
+
83
+ image = image.repeat_interleave(repeat_by, dim=0)
84
+
85
+ image = image.to(device=device, dtype=dtype)
86
+
87
+ if do_classifier_free_guidance and not guess_mode:
88
+ image = torch.cat([image] * 2)
89
+
90
+ return image
91
+
92
+ @torch.no_grad()
93
+ def __call__(
94
+ self,
95
+ batch_size: int,
96
+ device: str,
97
+ dtype: torch.dtype,
98
+ timesteps: List[float],
99
+ i: int,
100
+ scheduler: KarrasDiffusionSchedulers,
101
+ prompt_embeds: torch.Tensor,
102
+ do_classifier_free_guidance: bool = False,
103
+ # 2b co t ho wo
104
+ latent_model_input: torch.Tensor = None,
105
+ # b co t ho wo
106
+ latents: torch.Tensor = None,
107
+ # b c t h w
108
+ image: Union[
109
+ torch.FloatTensor,
110
+ PIL.Image.Image,
111
+ np.ndarray,
112
+ List[torch.FloatTensor],
113
+ List[PIL.Image.Image],
114
+ List[np.ndarray],
115
+ ] = None,
116
+ # b c t(1) hi wi
117
+ controlnet_condition_frames: Optional[torch.FloatTensor] = None,
118
+ # b c t ho wo
119
+ controlnet_latents: Union[torch.FloatTensor, np.ndarray] = None,
120
+ # b c t(1) ho wo
121
+ controlnet_condition_latents: Optional[torch.FloatTensor] = None,
122
+ height: Optional[int] = None,
123
+ width: Optional[int] = None,
124
+ num_videos_per_prompt: Optional[int] = 1,
125
+ return_dict: bool = True,
126
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
127
+ guess_mode: bool = False,
128
+ control_guidance_start: Union[float, List[float]] = 0.0,
129
+ control_guidance_end: Union[float, List[float]] = 1.0,
130
+ latent_index: torch.LongTensor = None,
131
+ vision_condition_latent_index: torch.LongTensor = None,
132
+ **kwargs,
133
+ ):
134
+ assert (
135
+ (image is None) != (controlnet_latents is None)
136
+ ), "should set one of image and controlnet_latents"
137
+
138
+ controlnet = (
139
+ self.controlnet._orig_mod
140
+ if is_compiled_module(self.controlnet)
141
+ else self.controlnet
142
+ )
143
+
144
+ # align format for control guidance
145
+ if not isinstance(control_guidance_start, list) and isinstance(
146
+ control_guidance_end, list
147
+ ):
148
+ control_guidance_start = len(control_guidance_end) * [
149
+ control_guidance_start
150
+ ]
151
+ elif not isinstance(control_guidance_end, list) and isinstance(
152
+ control_guidance_start, list
153
+ ):
154
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
155
+ elif not isinstance(control_guidance_start, list) and not isinstance(
156
+ control_guidance_end, list
157
+ ):
158
+ mult = (
159
+ len(controlnet.nets)
160
+ if isinstance(controlnet, MultiControlNetModel)
161
+ else 1
162
+ )
163
+ control_guidance_start, control_guidance_end = mult * [
164
+ control_guidance_start
165
+ ], mult * [control_guidance_end]
166
+
167
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(
168
+ controlnet_conditioning_scale, float
169
+ ):
170
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(
171
+ controlnet.nets
172
+ )
173
+
174
+ global_pool_conditions = (
175
+ controlnet.config.global_pool_conditions
176
+ if isinstance(controlnet, ControlNetModel)
177
+ else controlnet.nets[0].config.global_pool_conditions
178
+ )
179
+ guess_mode = guess_mode or global_pool_conditions
180
+
181
+ # 4. Prepare image
182
+ if isinstance(controlnet, ControlNetModel):
183
+ if (
184
+ controlnet_latents is not None
185
+ and controlnet_condition_latents is not None
186
+ ):
187
+ if isinstance(controlnet_latents, np.ndarray):
188
+ controlnet_latents = torch.from_numpy(controlnet_latents)
189
+ if isinstance(controlnet_condition_latents, np.ndarray):
190
+ controlnet_condition_latents = torch.from_numpy(
191
+ controlnet_condition_latents
192
+ )
193
+ # TODO: concat using index
194
+ controlnet_latents = torch.concat(
195
+ [controlnet_condition_latents, controlnet_latents], dim=2
196
+ )
197
+ if not guess_mode and do_classifier_free_guidance:
198
+ controlnet_latents = torch.concat([controlnet_latents] * 2, dim=0)
199
+ controlnet_latents = rearrange(
200
+ controlnet_latents, "b c t h w->(b t) c h w"
201
+ )
202
+ controlnet_latents = controlnet_latents.to(device=device, dtype=dtype)
203
+ else:
204
+ # TODO: concat using index
205
+ # TODO: concat with index
206
+ if controlnet_condition_frames is not None:
207
+ if isinstance(controlnet_condition_frames, np.ndarray):
208
+ image = np.concatenate(
209
+ [controlnet_condition_frames, image], axis=2
210
+ )
211
+ image = self.prepare_image(
212
+ image=image,
213
+ width=width,
214
+ height=height,
215
+ batch_size=batch_size * num_videos_per_prompt,
216
+ num_images_per_prompt=num_videos_per_prompt,
217
+ device=device,
218
+ dtype=controlnet.dtype,
219
+ do_classifier_free_guidance=do_classifier_free_guidance,
220
+ guess_mode=guess_mode,
221
+ )
222
+ height, width = image.shape[-2:]
223
+ elif isinstance(controlnet, MultiControlNetModel):
224
+ images = []
225
+ # TODO: support passing controlnet_latent directly instead of frames
226
+ # TODO: support using controlnet_latent directly instead of frames
227
+ if controlnet_latents is not None:
228
+ raise NotImplementedError
229
+ else:
230
+ for frame_idx, image_ in enumerate(image):
231
+ if controlnet_condition_frames is not None and isinstance(
232
+ controlnet_condition_frames, list
233
+ ):
234
+ if isinstance(controlnet_condition_frames[frame_idx], np.ndarray):
235
+ image_ = np.concatenate(
236
+ [controlnet_condition_frames[frame_idx], image_], axis=2
237
+ )
238
+ image_ = self.prepare_image(
239
+ image=image_,
240
+ width=width,
241
+ height=height,
242
+ batch_size=batch_size * num_videos_per_prompt,
243
+ num_images_per_prompt=num_videos_per_prompt,
244
+ device=device,
245
+ dtype=controlnet.dtype,
246
+ do_classifier_free_guidance=do_classifier_free_guidance,
247
+ guess_mode=guess_mode,
248
+ )
249
+
250
+ images.append(image_)
251
+
252
+ image = images
253
+ height, width = image[0].shape[-2:]
254
+ else:
255
+ assert False
256
+
257
+ # 7.1 Create tensor stating which controlnets to keep
258
+ controlnet_keep = []
259
+ for keep_idx in range(len(timesteps)):
260
+ keeps = [
261
+ 1.0 - float(keep_idx / len(timesteps) < s or (keep_idx + 1) / len(timesteps) > e)
262
+ for s, e in zip(control_guidance_start, control_guidance_end)
263
+ ]
264
+ controlnet_keep.append(
265
+ keeps[0] if isinstance(controlnet, ControlNetModel) else keeps
266
+ )
267
+
268
+ t = timesteps[i]
269
+
270
+ # controlnet(s) inference
271
+ if guess_mode and do_classifier_free_guidance:
272
+ # Infer ControlNet only for the conditional batch.
273
+ control_model_input = latents
274
+ control_model_input = scheduler.scale_model_input(control_model_input, t)
275
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
276
+ else:
277
+ control_model_input = latent_model_input
278
+ controlnet_prompt_embeds = prompt_embeds
279
+ if isinstance(controlnet_keep[i], list):
280
+ cond_scale = [
281
+ c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])
282
+ ]
283
+ else:
284
+ cond_scale = controlnet_conditioning_scale * controlnet_keep[i]
285
+ control_model_input_reshape = rearrange(
286
+ control_model_input, "b c t h w -> (b t) c h w"
287
+ )
288
+ encoder_hidden_states_repeat = repeat(
289
+ controlnet_prompt_embeds,
290
+ "b n q->(b t) n q",
291
+ t=control_model_input.shape[2],
292
+ )
293
+
294
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
295
+ control_model_input_reshape,
296
+ t,
297
+ encoder_hidden_states_repeat,
298
+ controlnet_cond=image,
299
+ controlnet_cond_latents=controlnet_latents,
300
+ conditioning_scale=cond_scale,
301
+ guess_mode=guess_mode,
302
+ return_dict=False,
303
+ )
304
+
305
+ return down_block_res_samples, mid_block_res_sample
306
+
307
+
308
+ class InflatedConv3d(nn.Conv2d):
309
+ def forward(self, x):
310
+ video_length = x.shape[2]
311
+
312
+ x = rearrange(x, "b c f h w -> (b f) c h w")
313
+ x = super().forward(x)
314
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
315
+
316
+ return x
317
+
318
+
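A quick usage sketch for InflatedConv3d above, assuming the class is importable: it folds the frame axis into the batch axis so an ordinary Conv2d runs on every frame independently. The tensor sizes are arbitrary.

import torch

conv = InflatedConv3d(3, 16, kernel_size=3, padding=1)   # behaves like a per-frame Conv2d
video = torch.randn(2, 3, 8, 64, 64)                     # b c f h w
out = conv(video)
print(out.shape)                                         # torch.Size([2, 16, 8, 64, 64])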
319
+ def zero_module(module):
320
+ # Zero out the parameters of a module and return it.
321
+ for p in module.parameters():
322
+ p.detach().zero_()
323
+ return module
324
+
325
+
326
+ class PoseGuider(ModelMixin):
327
+ def __init__(
328
+ self,
329
+ conditioning_embedding_channels: int,
330
+ conditioning_channels: int = 3,
331
+ block_out_channels: Tuple[int] = (16, 32, 64, 128),
332
+ ):
333
+ super().__init__()
334
+ self.conv_in = InflatedConv3d(
335
+ conditioning_channels, block_out_channels[0], kernel_size=3, padding=1
336
+ )
337
+
338
+ self.blocks = nn.ModuleList([])
339
+
340
+ for i in range(len(block_out_channels) - 1):
341
+ channel_in = block_out_channels[i]
342
+ channel_out = block_out_channels[i + 1]
343
+ self.blocks.append(
344
+ InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)
345
+ )
346
+ self.blocks.append(
347
+ InflatedConv3d(
348
+ channel_in, channel_out, kernel_size=3, padding=1, stride=2
349
+ )
350
+ )
351
+
352
+ self.conv_out = zero_module(
353
+ InflatedConv3d(
354
+ block_out_channels[-1],
355
+ conditioning_embedding_channels,
356
+ kernel_size=3,
357
+ padding=1,
358
+ )
359
+ )
360
+
361
+ def forward(self, conditioning):
362
+ embedding = self.conv_in(conditioning)
363
+ embedding = F.silu(embedding)
364
+
365
+ for block in self.blocks:
366
+ embedding = block(embedding)
367
+ embedding = F.silu(embedding)
368
+
369
+ embedding = self.conv_out(embedding)
370
+
371
+ return embedding
372
+
373
+ @classmethod
374
+ def from_pretrained(
375
+ cls,
376
+ pretrained_model_path,
377
+ conditioning_embedding_channels: int,
378
+ conditioning_channels: int = 3,
379
+ block_out_channels: Tuple[int] = (16, 32, 64, 128),
380
+ ):
381
+ if not os.path.exists(pretrained_model_path):
382
+ print(f"There is no model file in {pretrained_model_path}")
383
+ print(
384
+ f"loaded PoseGuider's pretrained weights from {pretrained_model_path} ..."
385
+ )
386
+
387
+ state_dict = torch.load(pretrained_model_path, map_location="cpu")
388
+ model = PoseGuider(
389
+ conditioning_embedding_channels=conditioning_embedding_channels,
390
+ conditioning_channels=conditioning_channels,
391
+ block_out_channels=block_out_channels,
392
+ )
393
+
394
+ m, u = model.load_state_dict(state_dict, strict=False)
395
+ # print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
396
+ params = [p.numel() for n, p in model.named_parameters()]
397
+ print(f"### PoseGuider's Parameters: {sum(params) / 1e6} M")
398
+
399
+ return model
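A small usage sketch for the PoseGuider defined above (illustrative sizes only): with the default block_out_channels it applies three stride-2 convolutions, so a 64x64 pose map comes out at 8x8, and because conv_out is zero-initialized the first embedding it produces is all zeros.

import torch

pose_guider = PoseGuider(conditioning_embedding_channels=320, conditioning_channels=3)
pose = torch.randn(1, 3, 8, 64, 64)       # b c t h w pose/condition frames
emb = pose_guider(pose)
print(emb.shape)                          # torch.Size([1, 320, 8, 8, 8])
assert emb.abs().sum() == 0               # conv_out is zero-initialized, so the output starts at zero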
musev/models/embeddings.py ADDED
@@ -0,0 +1,87 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from einops import rearrange
16
+ import torch
17
+ from torch.nn import functional as F
18
+ import numpy as np
19
+
20
+ from diffusers.models.embeddings import get_2d_sincos_pos_embed_from_grid
21
+
22
+
23
+ # ref diffusers.models.embeddings.get_2d_sincos_pos_embed
24
+ def get_2d_sincos_pos_embed(
25
+ embed_dim,
26
+ grid_size_w,
27
+ grid_size_h,
28
+ cls_token=False,
29
+ extra_tokens=0,
30
+ norm_length: bool = False,
31
+ max_length: float = 2048,
32
+ ):
33
+ """
34
+ grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
35
+ [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
36
+ """
37
+ if norm_length and grid_size_h <= max_length and grid_size_w <= max_length:
38
+ grid_h = np.linspace(0, max_length, grid_size_h)
39
+ grid_w = np.linspace(0, max_length, grid_size_w)
40
+ else:
41
+ grid_h = np.arange(grid_size_h, dtype=np.float32)
42
+ grid_w = np.arange(grid_size_w, dtype=np.float32)
43
+ grid = np.meshgrid(grid_h, grid_w) # here h goes first
44
+ grid = np.stack(grid, axis=0)
45
+
46
+ grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
47
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
48
+ if cls_token and extra_tokens > 0:
49
+ pos_embed = np.concatenate(
50
+ [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
51
+ )
52
+ return pos_embed
53
+
54
+
55
+ def resize_spatial_position_emb(
56
+ emb: torch.Tensor,
57
+ height: int,
58
+ width: int,
59
+ scale: float = None,
60
+ target_height: int = None,
61
+ target_width: int = None,
62
+ ) -> torch.Tensor:
63
+ """_summary_
64
+
65
+ Args:
66
+ emb (torch.Tensor): b ( h w) d
67
+ height (int): _description_
68
+ width (int): _description_
69
+ scale (float, optional): _description_. Defaults to None.
70
+ target_height (int, optional): _description_. Defaults to None.
71
+ target_width (int, optional): _description_. Defaults to None.
72
+
73
+ Returns:
74
+ torch.Tensor: b (target_height target_width) d
75
+ """
76
+ if scale is not None:
77
+ target_height = int(height * scale)
78
+ target_width = int(width * scale)
79
+ emb = rearrange(emb, "(h w) (b d) ->b d h w", h=height, b=1)
80
+ emb = F.interpolate(
81
+ emb,
82
+ size=(target_height, target_width),
83
+ mode="bicubic",
84
+ align_corners=False,
85
+ )
86
+ emb = rearrange(emb, "b d h w-> (h w) (b d)")
87
+ return emb
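The two helpers above are typically used together, roughly as sketched below: build a sin-cos positional grid at one token resolution, then bicubically resize it for another. The sizes are arbitrary assumptions; resize_spatial_position_emb is fed the (h*w, d) layout with an implicit batch of 1, matching its rearrange pattern.

import torch

embed_dim, h, w = 64, 16, 16
pos = get_2d_sincos_pos_embed(embed_dim, grid_size_w=w, grid_size_h=h)   # (h*w, embed_dim) numpy array
pos = torch.from_numpy(pos).float()

# resize the positional grid from 16x16 tokens to 24x24 tokens, e.g. for a larger latent resolution
pos_resized = resize_spatial_position_emb(
    pos, height=h, width=w, target_height=24, target_width=24
)
print(pos.shape, pos_resized.shape)   # torch.Size([256, 64]) torch.Size([576, 64])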
musev/models/facein_loader.py ADDED
@@ -0,0 +1,120 @@
1
+ import copy
2
+ from typing import Any, Callable, Dict, Iterable, Union
3
+ import PIL
4
+ import cv2
5
+ import torch
6
+ import argparse
7
+ import datetime
8
+ import logging
9
+ import inspect
10
+ import math
11
+ import os
12
+ import shutil
13
+ from typing import Dict, List, Optional, Tuple
14
+ from pprint import pprint
15
+ from collections import OrderedDict
16
+ from dataclasses import dataclass
17
+ import gc
18
+ import time
19
+
20
+ import numpy as np
21
+ from omegaconf import OmegaConf
22
+ from omegaconf import SCMode
23
+ import torch
24
+ from torch import nn
25
+ import torch.nn.functional as F
26
+ import torch.utils.checkpoint
27
+ from einops import rearrange, repeat
28
+ import pandas as pd
29
+ import h5py
30
+ from diffusers.models.modeling_utils import load_state_dict
31
+ from diffusers.utils import (
32
+ logging,
33
+ )
34
+ from diffusers.utils.import_utils import is_xformers_available
35
+
36
+ from mmcm.vision.feature_extractor.clip_vision_extractor import (
37
+ ImageClipVisionFeatureExtractor,
38
+ ImageClipVisionFeatureExtractorV2,
39
+ )
40
+ from mmcm.vision.feature_extractor.insight_face_extractor import InsightFaceExtractor
41
+
42
+ from ip_adapter.resampler import Resampler
43
+ from ip_adapter.ip_adapter import ImageProjModel
44
+
45
+ from .unet_loader import update_unet_with_sd
46
+ from .unet_3d_condition import UNet3DConditionModel
47
+ from .ip_adapter_loader import ip_adapter_keys_list
48
+
49
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
50
+
51
+
52
+ # refer https://github.com/tencent-ailab/IP-Adapter/issues/168#issuecomment-1846771651
53
+ unet_keys_list = [
54
+ "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
55
+ "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
56
+ "down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
57
+ "down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
58
+ "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
59
+ "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
60
+ "down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
61
+ "down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
62
+ "down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
63
+ "down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
64
+ "down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
65
+ "down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
66
+ "up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
67
+ "up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
68
+ "up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
69
+ "up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
70
+ "up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
71
+ "up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
72
+ "up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
73
+ "up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
74
+ "up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
75
+ "up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
76
+ "up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
77
+ "up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
78
+ "up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
79
+ "up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
80
+ "up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
81
+ "up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
82
+ "up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
83
+ "up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
84
+ "mid_block.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
85
+ "mid_block.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
86
+ ]
87
+
88
+
89
+ UNET2IPAadapter_Keys_MAPIING = {
90
+ k: v for k, v in zip(unet_keys_list, ip_adapter_keys_list)
91
+ }
92
+
93
+
94
+ def load_facein_extractor_and_proj_by_name(
95
+ model_name: str,
96
+ ip_ckpt: Tuple[str, nn.Module],
97
+ ip_image_encoder: Tuple[str, nn.Module] = None,
98
+ cross_attention_dim: int = 768,
99
+ clip_embeddings_dim: int = 512,
100
+ clip_extra_context_tokens: int = 1,
101
+ ip_scale: float = 0.0,
102
+ dtype: torch.dtype = torch.float16,
103
+ device: str = "cuda",
104
+ unet: nn.Module = None,
105
+ ) -> nn.Module:
106
+ pass
107
+
108
+
109
+ def update_unet_facein_cross_attn_param(
110
+ unet: UNet3DConditionModel, ip_adapter_state_dict: Dict
111
+ ) -> None:
112
+ """Use the to_k, to_v weights of the independent ip_adapter attention in the unet.
113
+ ip_adapter: a dict with keys like ['1.to_k_ip.weight', '1.to_v_ip.weight', '3.to_k_ip.weight']
114
+
115
+
116
+ Args:
117
+ unet (UNet3DConditionModel): _description_
118
+ ip_adapter_state_dict (Dict): _description_
119
+ """
120
+ pass
musev/models/ip_adapter_face_loader.py ADDED
@@ -0,0 +1,179 @@
1
+ import copy
2
+ from typing import Any, Callable, Dict, Iterable, Union
3
+ import PIL
4
+ import cv2
5
+ import torch
6
+ import argparse
7
+ import datetime
8
+ import logging
9
+ import inspect
10
+ import math
11
+ import os
12
+ import shutil
13
+ from typing import Dict, List, Optional, Tuple
14
+ from pprint import pprint
15
+ from collections import OrderedDict
16
+ from dataclasses import dataclass
17
+ import gc
18
+ import time
19
+
20
+ import numpy as np
21
+ from omegaconf import OmegaConf
22
+ from omegaconf import SCMode
23
+ import torch
24
+ from torch import nn
25
+ import torch.nn.functional as F
26
+ import torch.utils.checkpoint
27
+ from einops import rearrange, repeat
28
+ import pandas as pd
29
+ import h5py
30
+ from diffusers.models.modeling_utils import load_state_dict
31
+ from diffusers.utils import (
32
+ logging,
33
+ )
34
+ from diffusers.utils.import_utils import is_xformers_available
35
+
36
+ from ip_adapter.resampler import Resampler
37
+ from ip_adapter.ip_adapter import ImageProjModel
38
+ from ip_adapter.ip_adapter_faceid import ProjPlusModel, MLPProjModel
39
+
40
+ from mmcm.vision.feature_extractor.clip_vision_extractor import (
41
+ ImageClipVisionFeatureExtractor,
42
+ ImageClipVisionFeatureExtractorV2,
43
+ )
44
+ from mmcm.vision.feature_extractor.insight_face_extractor import (
45
+ InsightFaceExtractorNormEmb,
46
+ )
47
+
48
+
49
+ from .unet_loader import update_unet_with_sd
50
+ from .unet_3d_condition import UNet3DConditionModel
51
+ from .ip_adapter_loader import ip_adapter_keys_list
52
+
53
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
54
+
55
+
56
+ # refer https://github.com/tencent-ailab/IP-Adapter/issues/168#issuecomment-1846771651
57
+ unet_keys_list = [
58
+ "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
59
+ "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
60
+ "down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
61
+ "down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
62
+ "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
63
+ "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
64
+ "down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
65
+ "down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
66
+ "down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
67
+ "down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
68
+ "down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
69
+ "down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
70
+ "up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
71
+ "up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
72
+ "up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
73
+ "up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
74
+ "up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
75
+ "up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
76
+ "up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
77
+ "up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
78
+ "up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
79
+ "up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
80
+ "up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
81
+ "up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
82
+ "up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
83
+ "up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
84
+ "up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
85
+ "up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
86
+ "up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
87
+ "up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
88
+ "mid_block.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
89
+ "mid_block.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
90
+ ]
91
+
92
+
93
+ UNET2IPAadapter_Keys_MAPIING = {
94
+ k: v for k, v in zip(unet_keys_list, ip_adapter_keys_list)
95
+ }
96
+
97
+
98
+ def load_ip_adapter_face_extractor_and_proj_by_name(
99
+ model_name: str,
100
+ ip_ckpt: Tuple[str, nn.Module],
101
+ ip_image_encoder: Tuple[str, nn.Module] = None,
102
+ cross_attention_dim: int = 768,
103
+ clip_embeddings_dim: int = 1024,
104
+ clip_extra_context_tokens: int = 4,
105
+ ip_scale: float = 0.0,
106
+ dtype: torch.dtype = torch.float16,
107
+ device: str = "cuda",
108
+ unet: nn.Module = None,
109
+ ) -> nn.Module:
110
+ if model_name == "IPAdapterFaceID":
111
+ if ip_image_encoder is not None:
112
+ ip_adapter_face_emb_extractor = InsightFaceExtractorNormEmb(
113
+ pretrained_model_name_or_path=ip_image_encoder,
114
+ dtype=dtype,
115
+ device=device,
116
+ )
117
+ else:
118
+ ip_adapter_face_emb_extractor = None
119
+ ip_adapter_image_proj = MLPProjModel(
120
+ cross_attention_dim=cross_attention_dim,
121
+ id_embeddings_dim=clip_embeddings_dim,
122
+ num_tokens=clip_extra_context_tokens,
123
+ ).to(device, dtype=dtype)
124
+ else:
125
+ raise ValueError(
126
+ f"unsupport model_name={model_name}, only support IPAdapter, IPAdapterPlus, IPAdapterFaceID"
127
+ )
128
+ ip_adapter_state_dict = torch.load(
129
+ ip_ckpt,
130
+ map_location="cpu",
131
+ )
132
+ ip_adapter_image_proj.load_state_dict(ip_adapter_state_dict["image_proj"])
133
+ if unet is not None and "ip_adapter" in ip_adapter_state_dict:
134
+ update_unet_ip_adapter_cross_attn_param(
135
+ unet,
136
+ ip_adapter_state_dict["ip_adapter"],
137
+ )
138
+ logger.info(
139
+ f"update unet.spatial_cross_attn_ip_adapter parameter with {ip_ckpt}"
140
+ )
141
+ return (
142
+ ip_adapter_face_emb_extractor,
143
+ ip_adapter_image_proj,
144
+ )
145
+
146
+
147
+ def update_unet_ip_adapter_cross_attn_param(
148
+ unet: UNet3DConditionModel, ip_adapter_state_dict: Dict
149
+ ) -> None:
150
+ """Use the to_k, to_v weights of the independent ip_adapter attention in the unet.
151
+ ip_adapter: like ['1.to_k_ip.weight', '1.to_v_ip.weight', '3.to_k_ip.weight']
152
+
153
+
154
+ Args:
155
+ unet (UNet3DConditionModel): _description_
156
+ ip_adapter_state_dict (Dict): _description_
157
+ """
158
+ unet_spatial_cross_atnns = unet.spatial_cross_attns[0]
159
+ unet_spatial_cross_atnns_dct = {k: v for k, v in unet_spatial_cross_atnns}
160
+ for i, (unet_key_more, ip_adapter_key) in enumerate(
161
+ UNET2IPAadapter_Keys_MAPIING.items()
162
+ ):
163
+ ip_adapter_value = ip_adapter_state_dict[ip_adapter_key]
164
+ unet_key_more_spit = unet_key_more.split(".")
165
+ unet_key = ".".join(unet_key_more_spit[:-3])
166
+ suffix = ".".join(unet_key_more_spit[-3:])
167
+ logger.debug(
168
+ f"{i}: unet_key_more = {unet_key_more}, {unet_key}=unet_key, suffix={suffix}",
169
+ )
170
+ if ".ip_adapter_face_to_k" in suffix:
171
+ with torch.no_grad():
172
+ unet_spatial_cross_atnns_dct[
173
+ unet_key
174
+ ].ip_adapter_face_to_k_ip.weight.copy_(ip_adapter_value.data)
175
+ else:
176
+ with torch.no_grad():
177
+ unet_spatial_cross_atnns_dct[
178
+ unet_key
179
+ ].ip_adapter_face_to_v_ip.weight.copy_(ip_adapter_value.data)
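The parameter copy above relies on splitting each flat unet key into the owning attention module's path and the processor attribute that receives the checkpoint weight; the toy snippet below walks through that split on one hypothetical key, mirroring the logic of update_unet_ip_adapter_cross_attn_param.

# toy illustration of the unet-key -> (module path, weight suffix) split used above
unet_key_more = (
    "down_blocks.0.attentions.0.transformer_blocks.0.attn2"
    ".processor.ip_adapter_face_to_k_ip.weight"
)
parts = unet_key_more.split(".")
unet_key = ".".join(parts[:-3])   # path of the attention module that owns the processor
suffix = ".".join(parts[-3:])     # processor attribute + parameter name
print(unet_key)  # down_blocks.0.attentions.0.transformer_blocks.0.attn2
print(suffix)    # processor.ip_adapter_face_to_k_ip.weight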
musev/models/ip_adapter_loader.py ADDED
@@ -0,0 +1,340 @@
1
+ import copy
2
+ from typing import Any, Callable, Dict, Iterable, Union
3
+ import PIL
4
+ import cv2
5
+ import torch
6
+ import argparse
7
+ import datetime
8
+ import logging
9
+ import inspect
10
+ import math
11
+ import os
12
+ import shutil
13
+ from typing import Dict, List, Optional, Tuple
14
+ from pprint import pprint
15
+ from collections import OrderedDict
16
+ from dataclasses import dataclass
17
+ import gc
18
+ import time
19
+
20
+ import numpy as np
21
+ from omegaconf import OmegaConf
22
+ from omegaconf import SCMode
23
+ import torch
24
+ from torch import nn
25
+ import torch.nn.functional as F
26
+ import torch.utils.checkpoint
27
+ from einops import rearrange, repeat
28
+ import pandas as pd
29
+ import h5py
30
+ from diffusers.models.modeling_utils import load_state_dict
31
+ from diffusers.utils import (
32
+ logging,
33
+ )
34
+ from diffusers.utils.import_utils import is_xformers_available
35
+
36
+ from mmcm.vision.feature_extractor import clip_vision_extractor
37
+ from mmcm.vision.feature_extractor.clip_vision_extractor import (
38
+ ImageClipVisionFeatureExtractor,
39
+ ImageClipVisionFeatureExtractorV2,
40
+ VerstailSDLastHiddenState2ImageEmb,
41
+ )
42
+
43
+ from ip_adapter.resampler import Resampler
44
+ from ip_adapter.ip_adapter import ImageProjModel
45
+
46
+ from .unet_loader import update_unet_with_sd
47
+ from .unet_3d_condition import UNet3DConditionModel
48
+
49
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
50
+
51
+
52
+ def load_vision_clip_encoder_by_name(
53
+ ip_image_encoder: Tuple[str, nn.Module] = None,
54
+ dtype: torch.dtype = torch.float16,
55
+ device: str = "cuda",
56
+ vision_clip_extractor_class_name: str = None,
57
+ ) -> nn.Module:
58
+ if vision_clip_extractor_class_name is not None:
59
+ vision_clip_extractor = getattr(
60
+ clip_vision_extractor, vision_clip_extractor_class_name
61
+ )(
62
+ pretrained_model_name_or_path=ip_image_encoder,
63
+ dtype=dtype,
64
+ device=device,
65
+ )
66
+ else:
67
+ vision_clip_extractor = None
68
+ return vision_clip_extractor
69
+
70
+
71
+ def load_ip_adapter_image_proj_by_name(
72
+ model_name: str,
73
+ ip_ckpt: Tuple[str, nn.Module] = None,
74
+ cross_attention_dim: int = 768,
75
+ clip_embeddings_dim: int = 1024,
76
+ clip_extra_context_tokens: int = 4,
77
+ ip_scale: float = 0.0,
78
+ dtype: torch.dtype = torch.float16,
79
+ device: str = "cuda",
80
+ unet: nn.Module = None,
81
+ vision_clip_extractor_class_name: str = None,
82
+ ip_image_encoder: Tuple[str, nn.Module] = None,
83
+ ) -> nn.Module:
84
+ if model_name in [
85
+ "IPAdapter",
86
+ "musev_referencenet",
87
+ "musev_referencenet_pose",
88
+ ]:
89
+ ip_adapter_image_proj = ImageProjModel(
90
+ cross_attention_dim=cross_attention_dim,
91
+ clip_embeddings_dim=clip_embeddings_dim,
92
+ clip_extra_context_tokens=clip_extra_context_tokens,
93
+ )
94
+
95
+ elif model_name == "IPAdapterPlus":
96
+ vision_clip_extractor = ImageClipVisionFeatureExtractorV2(
97
+ pretrained_model_name_or_path=ip_image_encoder,
98
+ dtype=dtype,
99
+ device=device,
100
+ )
101
+ ip_adapter_image_proj = Resampler(
102
+ dim=cross_attention_dim,
103
+ depth=4,
104
+ dim_head=64,
105
+ heads=12,
106
+ num_queries=clip_extra_context_tokens,
107
+ embedding_dim=vision_clip_extractor.image_encoder.config.hidden_size,
108
+ output_dim=cross_attention_dim,
109
+ ff_mult=4,
110
+ )
111
+ elif model_name in [
112
+ "VerstailSDLastHiddenState2ImageEmb",
113
+ "OriginLastHiddenState2ImageEmbd",
114
+ "OriginLastHiddenState2Poolout",
115
+ ]:
116
+ ip_adapter_image_proj = getattr(
117
+ clip_vision_extractor, model_name
118
+ ).from_pretrained(ip_image_encoder)
119
+ else:
120
+ raise ValueError(
121
+ f"unsupported model_name={model_name}; only IPAdapter, IPAdapterPlus and VerstailSDLastHiddenState2ImageEmb are supported"
122
+ )
123
+ if ip_ckpt is not None:
124
+ ip_adapter_state_dict = torch.load(
125
+ ip_ckpt,
126
+ map_location="cpu",
127
+ )
128
+ ip_adapter_image_proj.load_state_dict(ip_adapter_state_dict["image_proj"])
129
+ if (
130
+ unet is not None
131
+ and unet.ip_adapter_cross_attn
132
+ and "ip_adapter" in ip_adapter_state_dict
133
+ ):
134
+ update_unet_ip_adapter_cross_attn_param(
135
+ unet, ip_adapter_state_dict["ip_adapter"]
136
+ )
137
+ logger.info(
138
+ f"update unet.spatial_cross_attn_ip_adapter parameter with {ip_ckpt}"
139
+ )
140
+ return ip_adapter_image_proj
141
+
142
+
143
+ def load_ip_adapter_vision_clip_encoder_by_name(
144
+ model_name: str,
145
+ ip_ckpt: Tuple[str, nn.Module],
146
+ ip_image_encoder: Tuple[str, nn.Module] = None,
147
+ cross_attention_dim: int = 768,
148
+ clip_embeddings_dim: int = 1024,
149
+ clip_extra_context_tokens: int = 4,
150
+ ip_scale: float = 0.0,
151
+ dtype: torch.dtype = torch.float16,
152
+ device: str = "cuda",
153
+ unet: nn.Module = None,
154
+ vision_clip_extractor_class_name: str = None,
155
+ ) -> nn.Module:
156
+ if vision_clip_extractor_class_name is not None:
157
+ vision_clip_extractor = getattr(
158
+ clip_vision_extractor, vision_clip_extractor_class_name
159
+ )(
160
+ pretrained_model_name_or_path=ip_image_encoder,
161
+ dtype=dtype,
162
+ device=device,
163
+ )
164
+ else:
165
+ vision_clip_extractor = None
166
+ if model_name in [
167
+ "IPAdapter",
168
+ "musev_referencenet",
169
+ ]:
170
+ if ip_image_encoder is not None:
171
+ if vision_clip_extractor_class_name is None:
172
+ vision_clip_extractor = ImageClipVisionFeatureExtractor(
173
+ pretrained_model_name_or_path=ip_image_encoder,
174
+ dtype=dtype,
175
+ device=device,
176
+ )
177
+ else:
178
+ vision_clip_extractor = None
179
+ ip_adapter_image_proj = ImageProjModel(
180
+ cross_attention_dim=cross_attention_dim,
181
+ clip_embeddings_dim=clip_embeddings_dim,
182
+ clip_extra_context_tokens=clip_extra_context_tokens,
183
+ )
184
+
185
+ elif model_name == "IPAdapterPlus":
186
+ if ip_image_encoder is not None:
187
+ if vision_clip_extractor_class_name is None:
188
+ vision_clip_extractor = ImageClipVisionFeatureExtractorV2(
189
+ pretrained_model_name_or_path=ip_image_encoder,
190
+ dtype=dtype,
191
+ device=device,
192
+ )
193
+ else:
194
+ vision_clip_extractor = None
195
+ ip_adapter_image_proj = Resampler(
196
+ dim=cross_attention_dim,
197
+ depth=4,
198
+ dim_head=64,
199
+ heads=12,
200
+ num_queries=clip_extra_context_tokens,
201
+ embedding_dim=vision_clip_extractor.image_encoder.config.hidden_size,
202
+ output_dim=cross_attention_dim,
203
+ ff_mult=4,
204
+ ).to(dtype=torch.float16)
205
+ else:
206
+ raise ValueError(
207
+ f"unsupport model_name={model_name}, only support IPAdapter, IPAdapterPlus"
208
+ )
209
+ ip_adapter_state_dict = torch.load(
210
+ ip_ckpt,
211
+ map_location="cpu",
212
+ )
213
+ ip_adapter_image_proj.load_state_dict(ip_adapter_state_dict["image_proj"])
214
+ if (
215
+ unet is not None
216
+ and unet.ip_adapter_cross_attn
217
+ and "ip_adapter" in ip_adapter_state_dict
218
+ ):
219
+ update_unet_ip_adapter_cross_attn_param(
220
+ unet, ip_adapter_state_dict["ip_adapter"]
221
+ )
222
+ logger.info(
223
+ f"update unet.spatial_cross_attn_ip_adapter parameter with {ip_ckpt}"
224
+ )
225
+ return (
226
+ vision_clip_extractor,
227
+ ip_adapter_image_proj,
228
+ )
229
+
230
+
231
+ # refer https://github.com/tencent-ailab/IP-Adapter/issues/168#issuecomment-1846771651
232
+ unet_keys_list = [
233
+ "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight",
234
+ "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight",
235
+ "down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight",
236
+ "down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight",
237
+ "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight",
238
+ "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight",
239
+ "down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight",
240
+ "down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight",
241
+ "down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight",
242
+ "down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight",
243
+ "down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight",
244
+ "down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight",
245
+ "up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight",
246
+ "up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight",
247
+ "up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight",
248
+ "up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight",
249
+ "up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.to_k_ip.weight",
250
+ "up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.to_v_ip.weight",
251
+ "up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight",
252
+ "up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight",
253
+ "up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight",
254
+ "up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight",
255
+ "up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.to_k_ip.weight",
256
+ "up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.to_v_ip.weight",
257
+ "up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight",
258
+ "up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight",
259
+ "up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight",
260
+ "up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight",
261
+ "up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.to_k_ip.weight",
262
+ "up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.to_v_ip.weight",
263
+ "mid_block.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight",
264
+ "mid_block.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight",
265
+ ]
266
+
267
+
268
+ ip_adapter_keys_list = [
269
+ "1.to_k_ip.weight",
270
+ "1.to_v_ip.weight",
271
+ "3.to_k_ip.weight",
272
+ "3.to_v_ip.weight",
273
+ "5.to_k_ip.weight",
274
+ "5.to_v_ip.weight",
275
+ "7.to_k_ip.weight",
276
+ "7.to_v_ip.weight",
277
+ "9.to_k_ip.weight",
278
+ "9.to_v_ip.weight",
279
+ "11.to_k_ip.weight",
280
+ "11.to_v_ip.weight",
281
+ "13.to_k_ip.weight",
282
+ "13.to_v_ip.weight",
283
+ "15.to_k_ip.weight",
284
+ "15.to_v_ip.weight",
285
+ "17.to_k_ip.weight",
286
+ "17.to_v_ip.weight",
287
+ "19.to_k_ip.weight",
288
+ "19.to_v_ip.weight",
289
+ "21.to_k_ip.weight",
290
+ "21.to_v_ip.weight",
291
+ "23.to_k_ip.weight",
292
+ "23.to_v_ip.weight",
293
+ "25.to_k_ip.weight",
294
+ "25.to_v_ip.weight",
295
+ "27.to_k_ip.weight",
296
+ "27.to_v_ip.weight",
297
+ "29.to_k_ip.weight",
298
+ "29.to_v_ip.weight",
299
+ "31.to_k_ip.weight",
300
+ "31.to_v_ip.weight",
301
+ ]
302
+
303
+ UNET2IPAadapter_Keys_MAPIING = {
304
+ k: v for k, v in zip(unet_keys_list, ip_adapter_keys_list)
305
+ }
306
+
307
+
308
+ def update_unet_ip_adapter_cross_attn_param(
309
+ unet: UNet3DConditionModel, ip_adapter_state_dict: Dict
310
+ ) -> None:
311
+ """use independent ip_adapter attn 中的 to_k, to_v in unet
312
+ ip_adapter: dict whose keys are ['1.to_k_ip.weight', '1.to_v_ip.weight', '3.to_k_ip.weight']
313
+
314
+
315
+ Args:
316
+ unet (UNet3DConditionModel): the unet whose spatial cross-attention IP-Adapter weights are updated in place
317
+ ip_adapter_state_dict (Dict): IP-Adapter state dict providing the to_k_ip / to_v_ip weights
318
+ """
319
+ unet_spatial_cross_atnns = unet.spatial_cross_attns[0]
320
+ unet_spatial_cross_atnns_dct = {k: v for k, v in unet_spatial_cross_atnns}
321
+ for i, (unet_key_more, ip_adapter_key) in enumerate(
322
+ UNET2IPAadapter_Keys_MAPIING.items()
323
+ ):
324
+ ip_adapter_value = ip_adapter_state_dict[ip_adapter_key]
325
+ unet_key_more_spit = unet_key_more.split(".")
326
+ unet_key = ".".join(unet_key_more_spit[:-3])
327
+ suffix = ".".join(unet_key_more_spit[-3:])
328
+ logger.debug(
329
+ f"{i}: unet_key_more = {unet_key_more}, {unet_key}=unet_key, suffix={suffix}",
330
+ )
331
+ if "to_k" in suffix:
332
+ with torch.no_grad():
333
+ unet_spatial_cross_atnns_dct[unet_key].to_k_ip.weight.copy_(
334
+ ip_adapter_value.data
335
+ )
336
+ else:
337
+ with torch.no_grad():
338
+ unet_spatial_cross_atnns_dct[unet_key].to_v_ip.weight.copy_(
339
+ ip_adapter_value.data
340
+ )
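For orientation, here is a minimal usage sketch of the loader above. It is not part of the commit; the checkpoint and image-encoder paths are placeholders, and only arguments defined in the signatures shown above are used.

# Hedged sketch: load the vision CLIP extractor plus the image-projection module,
# optionally pushing IP-Adapter cross-attention weights into a unet.
import torch

vision_clip_extractor, ip_adapter_image_proj = load_ip_adapter_vision_clip_encoder_by_name(
    model_name="IPAdapter",                          # or "IPAdapterPlus"
    ip_ckpt="./checkpoints/ip_adapter.bin",          # placeholder path
    ip_image_encoder="./checkpoints/image_encoder",  # placeholder path
    cross_attention_dim=768,
    clip_embeddings_dim=1024,
    clip_extra_context_tokens=4,
    dtype=torch.float16,
    device="cuda",
    unet=None,  # pass a UNet3DConditionModel to also update its to_k_ip / to_v_ip weights
)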
musev/models/referencenet.py ADDED
@@ -0,0 +1,1216 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ from typing import Any, Dict, List, Optional, Tuple, Union
18
+ import logging
19
+
20
+ import torch
21
+ from diffusers.models.attention_processor import Attention, AttnProcessor
22
+ from einops import rearrange, repeat
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+ import xformers
26
+ from diffusers.models.lora import LoRACompatibleLinear
27
+ from diffusers.models.unet_2d_condition import (
28
+ UNet2DConditionModel,
29
+ UNet2DConditionOutput,
30
+ )
31
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
32
+ from diffusers.utils.constants import USE_PEFT_BACKEND
33
+ from diffusers.utils.deprecation_utils import deprecate
34
+ from diffusers.utils.peft_utils import scale_lora_layers, unscale_lora_layers
35
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
36
+ from diffusers.models.modeling_utils import ModelMixin, load_state_dict
37
+ from diffusers.loaders import UNet2DConditionLoadersMixin
38
+ from diffusers.utils import (
39
+ USE_PEFT_BACKEND,
40
+ BaseOutput,
41
+ deprecate,
42
+ scale_lora_layers,
43
+ unscale_lora_layers,
44
+ )
45
+ from diffusers.models.activations import get_activation
46
+ from diffusers.models.attention_processor import (
47
+ ADDED_KV_ATTENTION_PROCESSORS,
48
+ CROSS_ATTENTION_PROCESSORS,
49
+ AttentionProcessor,
50
+ AttnAddedKVProcessor,
51
+ AttnProcessor,
52
+ )
53
+ from diffusers.models.embeddings import (
54
+ GaussianFourierProjection,
55
+ ImageHintTimeEmbedding,
56
+ ImageProjection,
57
+ ImageTimeEmbedding,
58
+ PositionNet,
59
+ TextImageProjection,
60
+ TextImageTimeEmbedding,
61
+ TextTimeEmbedding,
62
+ TimestepEmbedding,
63
+ Timesteps,
64
+ )
65
+ from diffusers.models.modeling_utils import ModelMixin
66
+
67
+
68
+ from ..data.data_util import align_repeat_tensor_single_dim
69
+ from .unet_3d_condition import UNet3DConditionModel
70
+ from .attention import BasicTransformerBlock, IPAttention
71
+ from .unet_2d_blocks import (
72
+ UNetMidBlock2D,
73
+ UNetMidBlock2DCrossAttn,
74
+ UNetMidBlock2DSimpleCrossAttn,
75
+ get_down_block,
76
+ get_up_block,
77
+ )
78
+
79
+ from . import Model_Register
80
+
81
+
82
+ logger = logging.getLogger(__name__)
83
+
84
+
85
+ @Model_Register.register
86
+ class ReferenceNet2D(UNet2DConditionModel, nn.Module):
87
+ """继承 UNet2DConditionModel. 新增功能,类似controlnet 返回模型中间特征,用于后续作用
88
+ Inherit Unet2DConditionModel. Add new functions, similar to controlnet, return the intermediate features of the model for subsequent effects
89
+ Args:
90
+ UNet2DConditionModel (_type_): _description_
91
+ """
92
+
93
+ _supports_gradient_checkpointing = True
94
+ print_idx = 0
95
+
96
+ @register_to_config
97
+ def __init__(
98
+ self,
99
+ sample_size: int | None = None,
100
+ in_channels: int = 4,
101
+ out_channels: int = 4,
102
+ center_input_sample: bool = False,
103
+ flip_sin_to_cos: bool = True,
104
+ freq_shift: int = 0,
105
+ down_block_types: Tuple[str] = (
106
+ "CrossAttnDownBlock2D",
107
+ "CrossAttnDownBlock2D",
108
+ "CrossAttnDownBlock2D",
109
+ "DownBlock2D",
110
+ ),
111
+ mid_block_type: str | None = "UNetMidBlock2DCrossAttn",
112
+ up_block_types: Tuple[str] = (
113
+ "UpBlock2D",
114
+ "CrossAttnUpBlock2D",
115
+ "CrossAttnUpBlock2D",
116
+ "CrossAttnUpBlock2D",
117
+ ),
118
+ only_cross_attention: bool | Tuple[bool] = False,
119
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
120
+ layers_per_block: int | Tuple[int] = 2,
121
+ downsample_padding: int = 1,
122
+ mid_block_scale_factor: float = 1,
123
+ dropout: float = 0,
124
+ act_fn: str = "silu",
125
+ norm_num_groups: int | None = 32,
126
+ norm_eps: float = 0.00001,
127
+ cross_attention_dim: int | Tuple[int] = 1280,
128
+ transformer_layers_per_block: int | Tuple[int] | Tuple[Tuple] = 1,
129
+ reverse_transformer_layers_per_block: Tuple[Tuple[int]] | None = None,
130
+ encoder_hid_dim: int | None = None,
131
+ encoder_hid_dim_type: str | None = None,
132
+ attention_head_dim: int | Tuple[int] = 8,
133
+ num_attention_heads: int | Tuple[int] | None = None,
134
+ dual_cross_attention: bool = False,
135
+ use_linear_projection: bool = False,
136
+ class_embed_type: str | None = None,
137
+ addition_embed_type: str | None = None,
138
+ addition_time_embed_dim: int | None = None,
139
+ num_class_embeds: int | None = None,
140
+ upcast_attention: bool = False,
141
+ resnet_time_scale_shift: str = "default",
142
+ resnet_skip_time_act: bool = False,
143
+ resnet_out_scale_factor: int = 1,
144
+ time_embedding_type: str = "positional",
145
+ time_embedding_dim: int | None = None,
146
+ time_embedding_act_fn: str | None = None,
147
+ timestep_post_act: str | None = None,
148
+ time_cond_proj_dim: int | None = None,
149
+ conv_in_kernel: int = 3,
150
+ conv_out_kernel: int = 3,
151
+ projection_class_embeddings_input_dim: int | None = None,
152
+ attention_type: str = "default",
153
+ class_embeddings_concat: bool = False,
154
+ mid_block_only_cross_attention: bool | None = None,
155
+ cross_attention_norm: str | None = None,
156
+ addition_embed_type_num_heads=64,
157
+ need_self_attn_block_embs: bool = False,
158
+ need_block_embs: bool = False,
159
+ ):
160
+ super().__init__()
161
+
162
+ self.sample_size = sample_size
163
+
164
+ if num_attention_heads is not None:
165
+ raise ValueError(
166
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
167
+ )
168
+
169
+ # If `num_attention_heads` is not defined (which is the case for most models)
170
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
171
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
172
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
173
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
174
+ # which is why we correct for the naming here.
175
+ num_attention_heads = num_attention_heads or attention_head_dim
176
+
177
+ # Check inputs
178
+ if len(down_block_types) != len(up_block_types):
179
+ raise ValueError(
180
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
181
+ )
182
+
183
+ if len(block_out_channels) != len(down_block_types):
184
+ raise ValueError(
185
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
186
+ )
187
+
188
+ if not isinstance(only_cross_attention, bool) and len(
189
+ only_cross_attention
190
+ ) != len(down_block_types):
191
+ raise ValueError(
192
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
193
+ )
194
+
195
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(
196
+ down_block_types
197
+ ):
198
+ raise ValueError(
199
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
200
+ )
201
+
202
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(
203
+ down_block_types
204
+ ):
205
+ raise ValueError(
206
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
207
+ )
208
+
209
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(
210
+ down_block_types
211
+ ):
212
+ raise ValueError(
213
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
214
+ )
215
+
216
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(
217
+ down_block_types
218
+ ):
219
+ raise ValueError(
220
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
221
+ )
222
+ if (
223
+ isinstance(transformer_layers_per_block, list)
224
+ and reverse_transformer_layers_per_block is None
225
+ ):
226
+ for layer_number_per_block in transformer_layers_per_block:
227
+ if isinstance(layer_number_per_block, list):
228
+ raise ValueError(
229
+ "Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet."
230
+ )
231
+
232
+ # input
233
+ conv_in_padding = (conv_in_kernel - 1) // 2
234
+ self.conv_in = nn.Conv2d(
235
+ in_channels,
236
+ block_out_channels[0],
237
+ kernel_size=conv_in_kernel,
238
+ padding=conv_in_padding,
239
+ )
240
+
241
+ # time
242
+ if time_embedding_type == "fourier":
243
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
244
+ if time_embed_dim % 2 != 0:
245
+ raise ValueError(
246
+ f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}."
247
+ )
248
+ self.time_proj = GaussianFourierProjection(
249
+ time_embed_dim // 2,
250
+ set_W_to_weight=False,
251
+ log=False,
252
+ flip_sin_to_cos=flip_sin_to_cos,
253
+ )
254
+ timestep_input_dim = time_embed_dim
255
+ elif time_embedding_type == "positional":
256
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
257
+
258
+ self.time_proj = Timesteps(
259
+ block_out_channels[0], flip_sin_to_cos, freq_shift
260
+ )
261
+ timestep_input_dim = block_out_channels[0]
262
+ else:
263
+ raise ValueError(
264
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
265
+ )
266
+
267
+ self.time_embedding = TimestepEmbedding(
268
+ timestep_input_dim,
269
+ time_embed_dim,
270
+ act_fn=act_fn,
271
+ post_act_fn=timestep_post_act,
272
+ cond_proj_dim=time_cond_proj_dim,
273
+ )
274
+
275
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
276
+ encoder_hid_dim_type = "text_proj"
277
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
278
+ logger.info(
279
+ "encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined."
280
+ )
281
+
282
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
283
+ raise ValueError(
284
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
285
+ )
286
+
287
+ if encoder_hid_dim_type == "text_proj":
288
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
289
+ elif encoder_hid_dim_type == "text_image_proj":
290
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
291
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
292
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1).
293
+ self.encoder_hid_proj = TextImageProjection(
294
+ text_embed_dim=encoder_hid_dim,
295
+ image_embed_dim=cross_attention_dim,
296
+ cross_attention_dim=cross_attention_dim,
297
+ )
298
+ elif encoder_hid_dim_type == "image_proj":
299
+ # Kandinsky 2.2
300
+ self.encoder_hid_proj = ImageProjection(
301
+ image_embed_dim=encoder_hid_dim,
302
+ cross_attention_dim=cross_attention_dim,
303
+ )
304
+ elif encoder_hid_dim_type is not None:
305
+ raise ValueError(
306
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
307
+ )
308
+ else:
309
+ self.encoder_hid_proj = None
310
+
311
+ # class embedding
312
+ if class_embed_type is None and num_class_embeds is not None:
313
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
314
+ elif class_embed_type == "timestep":
315
+ self.class_embedding = TimestepEmbedding(
316
+ timestep_input_dim, time_embed_dim, act_fn=act_fn
317
+ )
318
+ elif class_embed_type == "identity":
319
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
320
+ elif class_embed_type == "projection":
321
+ if projection_class_embeddings_input_dim is None:
322
+ raise ValueError(
323
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
324
+ )
325
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
326
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
327
+ # 2. it projects from an arbitrary input dimension.
328
+ #
329
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
330
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
331
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
332
+ self.class_embedding = TimestepEmbedding(
333
+ projection_class_embeddings_input_dim, time_embed_dim
334
+ )
335
+ elif class_embed_type == "simple_projection":
336
+ if projection_class_embeddings_input_dim is None:
337
+ raise ValueError(
338
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
339
+ )
340
+ self.class_embedding = nn.Linear(
341
+ projection_class_embeddings_input_dim, time_embed_dim
342
+ )
343
+ else:
344
+ self.class_embedding = None
345
+
346
+ if addition_embed_type == "text":
347
+ if encoder_hid_dim is not None:
348
+ text_time_embedding_from_dim = encoder_hid_dim
349
+ else:
350
+ text_time_embedding_from_dim = cross_attention_dim
351
+
352
+ self.add_embedding = TextTimeEmbedding(
353
+ text_time_embedding_from_dim,
354
+ time_embed_dim,
355
+ num_heads=addition_embed_type_num_heads,
356
+ )
357
+ elif addition_embed_type == "text_image":
358
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
359
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
360
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1).
361
+ self.add_embedding = TextImageTimeEmbedding(
362
+ text_embed_dim=cross_attention_dim,
363
+ image_embed_dim=cross_attention_dim,
364
+ time_embed_dim=time_embed_dim,
365
+ )
366
+ elif addition_embed_type == "text_time":
367
+ self.add_time_proj = Timesteps(
368
+ addition_time_embed_dim, flip_sin_to_cos, freq_shift
369
+ )
370
+ self.add_embedding = TimestepEmbedding(
371
+ projection_class_embeddings_input_dim, time_embed_dim
372
+ )
373
+ elif addition_embed_type == "image":
374
+ # Kandinsky 2.2
375
+ self.add_embedding = ImageTimeEmbedding(
376
+ image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim
377
+ )
378
+ elif addition_embed_type == "image_hint":
379
+ # Kandinsky 2.2 ControlNet
380
+ self.add_embedding = ImageHintTimeEmbedding(
381
+ image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim
382
+ )
383
+ elif addition_embed_type is not None:
384
+ raise ValueError(
385
+ f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'."
386
+ )
387
+
388
+ if time_embedding_act_fn is None:
389
+ self.time_embed_act = None
390
+ else:
391
+ self.time_embed_act = get_activation(time_embedding_act_fn)
392
+
393
+ self.down_blocks = nn.ModuleList([])
394
+ self.up_blocks = nn.ModuleList([])
395
+
396
+ if isinstance(only_cross_attention, bool):
397
+ if mid_block_only_cross_attention is None:
398
+ mid_block_only_cross_attention = only_cross_attention
399
+
400
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
401
+
402
+ if mid_block_only_cross_attention is None:
403
+ mid_block_only_cross_attention = False
404
+
405
+ if isinstance(num_attention_heads, int):
406
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
407
+
408
+ if isinstance(attention_head_dim, int):
409
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
410
+
411
+ if isinstance(cross_attention_dim, int):
412
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
413
+
414
+ if isinstance(layers_per_block, int):
415
+ layers_per_block = [layers_per_block] * len(down_block_types)
416
+
417
+ if isinstance(transformer_layers_per_block, int):
418
+ transformer_layers_per_block = [transformer_layers_per_block] * len(
419
+ down_block_types
420
+ )
421
+
422
+ if class_embeddings_concat:
423
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
424
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
425
+ # regular time embeddings
426
+ blocks_time_embed_dim = time_embed_dim * 2
427
+ else:
428
+ blocks_time_embed_dim = time_embed_dim
429
+
430
+ # down
431
+ output_channel = block_out_channels[0]
432
+ for i, down_block_type in enumerate(down_block_types):
433
+ input_channel = output_channel
434
+ output_channel = block_out_channels[i]
435
+ is_final_block = i == len(block_out_channels) - 1
436
+
437
+ down_block = get_down_block(
438
+ down_block_type,
439
+ num_layers=layers_per_block[i],
440
+ transformer_layers_per_block=transformer_layers_per_block[i],
441
+ in_channels=input_channel,
442
+ out_channels=output_channel,
443
+ temb_channels=blocks_time_embed_dim,
444
+ add_downsample=not is_final_block,
445
+ resnet_eps=norm_eps,
446
+ resnet_act_fn=act_fn,
447
+ resnet_groups=norm_num_groups,
448
+ cross_attention_dim=cross_attention_dim[i],
449
+ num_attention_heads=num_attention_heads[i],
450
+ downsample_padding=downsample_padding,
451
+ dual_cross_attention=dual_cross_attention,
452
+ use_linear_projection=use_linear_projection,
453
+ only_cross_attention=only_cross_attention[i],
454
+ upcast_attention=upcast_attention,
455
+ resnet_time_scale_shift=resnet_time_scale_shift,
456
+ attention_type=attention_type,
457
+ resnet_skip_time_act=resnet_skip_time_act,
458
+ resnet_out_scale_factor=resnet_out_scale_factor,
459
+ cross_attention_norm=cross_attention_norm,
460
+ attention_head_dim=attention_head_dim[i]
461
+ if attention_head_dim[i] is not None
462
+ else output_channel,
463
+ dropout=dropout,
464
+ )
465
+ self.down_blocks.append(down_block)
466
+
467
+ # mid
468
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
469
+ self.mid_block = UNetMidBlock2DCrossAttn(
470
+ transformer_layers_per_block=transformer_layers_per_block[-1],
471
+ in_channels=block_out_channels[-1],
472
+ temb_channels=blocks_time_embed_dim,
473
+ dropout=dropout,
474
+ resnet_eps=norm_eps,
475
+ resnet_act_fn=act_fn,
476
+ output_scale_factor=mid_block_scale_factor,
477
+ resnet_time_scale_shift=resnet_time_scale_shift,
478
+ cross_attention_dim=cross_attention_dim[-1],
479
+ num_attention_heads=num_attention_heads[-1],
480
+ resnet_groups=norm_num_groups,
481
+ dual_cross_attention=dual_cross_attention,
482
+ use_linear_projection=use_linear_projection,
483
+ upcast_attention=upcast_attention,
484
+ attention_type=attention_type,
485
+ )
486
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
487
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
488
+ in_channels=block_out_channels[-1],
489
+ temb_channels=blocks_time_embed_dim,
490
+ dropout=dropout,
491
+ resnet_eps=norm_eps,
492
+ resnet_act_fn=act_fn,
493
+ output_scale_factor=mid_block_scale_factor,
494
+ cross_attention_dim=cross_attention_dim[-1],
495
+ attention_head_dim=attention_head_dim[-1],
496
+ resnet_groups=norm_num_groups,
497
+ resnet_time_scale_shift=resnet_time_scale_shift,
498
+ skip_time_act=resnet_skip_time_act,
499
+ only_cross_attention=mid_block_only_cross_attention,
500
+ cross_attention_norm=cross_attention_norm,
501
+ )
502
+ elif mid_block_type == "UNetMidBlock2D":
503
+ self.mid_block = UNetMidBlock2D(
504
+ in_channels=block_out_channels[-1],
505
+ temb_channels=blocks_time_embed_dim,
506
+ dropout=dropout,
507
+ num_layers=0,
508
+ resnet_eps=norm_eps,
509
+ resnet_act_fn=act_fn,
510
+ output_scale_factor=mid_block_scale_factor,
511
+ resnet_groups=norm_num_groups,
512
+ resnet_time_scale_shift=resnet_time_scale_shift,
513
+ add_attention=False,
514
+ )
515
+ elif mid_block_type is None:
516
+ self.mid_block = None
517
+ else:
518
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
519
+
520
+ # count how many layers upsample the images
521
+ self.num_upsamplers = 0
522
+
523
+ # up
524
+ reversed_block_out_channels = list(reversed(block_out_channels))
525
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
526
+ reversed_layers_per_block = list(reversed(layers_per_block))
527
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
528
+ reversed_transformer_layers_per_block = (
529
+ list(reversed(transformer_layers_per_block))
530
+ if reverse_transformer_layers_per_block is None
531
+ else reverse_transformer_layers_per_block
532
+ )
533
+ only_cross_attention = list(reversed(only_cross_attention))
534
+
535
+ output_channel = reversed_block_out_channels[0]
536
+ for i, up_block_type in enumerate(up_block_types):
537
+ is_final_block = i == len(block_out_channels) - 1
538
+
539
+ prev_output_channel = output_channel
540
+ output_channel = reversed_block_out_channels[i]
541
+ input_channel = reversed_block_out_channels[
542
+ min(i + 1, len(block_out_channels) - 1)
543
+ ]
544
+
545
+ # add upsample block for all BUT final layer
546
+ if not is_final_block:
547
+ add_upsample = True
548
+ self.num_upsamplers += 1
549
+ else:
550
+ add_upsample = False
551
+
552
+ up_block = get_up_block(
553
+ up_block_type,
554
+ num_layers=reversed_layers_per_block[i] + 1,
555
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
556
+ in_channels=input_channel,
557
+ out_channels=output_channel,
558
+ prev_output_channel=prev_output_channel,
559
+ temb_channels=blocks_time_embed_dim,
560
+ add_upsample=add_upsample,
561
+ resnet_eps=norm_eps,
562
+ resnet_act_fn=act_fn,
563
+ resolution_idx=i,
564
+ resnet_groups=norm_num_groups,
565
+ cross_attention_dim=reversed_cross_attention_dim[i],
566
+ num_attention_heads=reversed_num_attention_heads[i],
567
+ dual_cross_attention=dual_cross_attention,
568
+ use_linear_projection=use_linear_projection,
569
+ only_cross_attention=only_cross_attention[i],
570
+ upcast_attention=upcast_attention,
571
+ resnet_time_scale_shift=resnet_time_scale_shift,
572
+ attention_type=attention_type,
573
+ resnet_skip_time_act=resnet_skip_time_act,
574
+ resnet_out_scale_factor=resnet_out_scale_factor,
575
+ cross_attention_norm=cross_attention_norm,
576
+ attention_head_dim=attention_head_dim[i]
577
+ if attention_head_dim[i] is not None
578
+ else output_channel,
579
+ dropout=dropout,
580
+ )
581
+ self.up_blocks.append(up_block)
582
+ prev_output_channel = output_channel
583
+
584
+ # out
585
+ if norm_num_groups is not None:
586
+ self.conv_norm_out = nn.GroupNorm(
587
+ num_channels=block_out_channels[0],
588
+ num_groups=norm_num_groups,
589
+ eps=norm_eps,
590
+ )
591
+
592
+ self.conv_act = get_activation(act_fn)
593
+
594
+ else:
595
+ self.conv_norm_out = None
596
+ self.conv_act = None
597
+
598
+ conv_out_padding = (conv_out_kernel - 1) // 2
599
+ self.conv_out = nn.Conv2d(
600
+ block_out_channels[0],
601
+ out_channels,
602
+ kernel_size=conv_out_kernel,
603
+ padding=conv_out_padding,
604
+ )
605
+
606
+ if attention_type in ["gated", "gated-text-image"]:
607
+ positive_len = 768
608
+ if isinstance(cross_attention_dim, int):
609
+ positive_len = cross_attention_dim
610
+ elif isinstance(cross_attention_dim, tuple) or isinstance(
611
+ cross_attention_dim, list
612
+ ):
613
+ positive_len = cross_attention_dim[0]
614
+
615
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
616
+ self.position_net = PositionNet(
617
+ positive_len=positive_len,
618
+ out_dim=cross_attention_dim,
619
+ feature_type=feature_type,
620
+ )
621
+ self.need_block_embs = need_block_embs
622
+ self.need_self_attn_block_embs = need_self_attn_block_embs
623
+
624
+ # only use some referencenet layers; set the other layers to None
625
+ self.conv_norm_out = None
626
+ self.conv_act = None
627
+ self.conv_out = None
628
+
629
+ self.up_blocks[-1].attentions[-1].proj_out = None
630
+ self.up_blocks[-1].attentions[-1].transformer_blocks[-1].attn1 = None
631
+ self.up_blocks[-1].attentions[-1].transformer_blocks[-1].attn2 = None
632
+ self.up_blocks[-1].attentions[-1].transformer_blocks[-1].norm2 = None
633
+ self.up_blocks[-1].attentions[-1].transformer_blocks[-1].ff = None
634
+ self.up_blocks[-1].attentions[-1].transformer_blocks[-1].norm3 = None
635
+ if not self.need_self_attn_block_embs:
636
+ self.up_blocks = None
637
+
638
+ self.insert_spatial_self_attn_idx()
639
+
640
+ def forward(
641
+ self,
642
+ sample: torch.FloatTensor,
643
+ timestep: Union[torch.Tensor, float, int],
644
+ encoder_hidden_states: torch.Tensor,
645
+ class_labels: Optional[torch.Tensor] = None,
646
+ timestep_cond: Optional[torch.Tensor] = None,
647
+ attention_mask: Optional[torch.Tensor] = None,
648
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
649
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
650
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
651
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
652
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
653
+ encoder_attention_mask: Optional[torch.Tensor] = None,
654
+ return_dict: bool = True,
655
+ # updated: new parameters start
656
+ num_frames: int = None,
657
+ return_ndim: int = 5,
658
+ # updated: new parameters end
659
+ ) -> Union[UNet2DConditionOutput, Tuple]:
660
+ r"""
661
+ The [`UNet2DConditionModel`] forward method.
662
+
663
+ Args:
664
+ sample (`torch.FloatTensor`):
665
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
666
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
667
+ encoder_hidden_states (`torch.FloatTensor`):
668
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
669
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
670
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
671
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
672
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
673
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
674
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
675
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
676
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
677
+ negative values to the attention scores corresponding to "discard" tokens.
678
+ cross_attention_kwargs (`dict`, *optional*):
679
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
680
+ `self.processor` in
681
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
682
+ added_cond_kwargs: (`dict`, *optional*):
683
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
684
+ are passed along to the UNet blocks.
685
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
686
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
687
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
688
+ A tensor that if specified is added to the residual of the middle unet block.
689
+ encoder_attention_mask (`torch.Tensor`):
690
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
691
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
692
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
693
+ return_dict (`bool`, *optional*, defaults to `True`):
694
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
695
+ tuple.
696
+ cross_attention_kwargs (`dict`, *optional*):
697
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
698
+ added_cond_kwargs: (`dict`, *optional*):
699
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
700
+ are passed along to the UNet blocks.
701
+ down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
702
+ additional residuals to be added to UNet long skip connections from down blocks to up blocks for
703
+ example from ControlNet side model(s)
704
+ mid_block_additional_residual (`torch.Tensor`, *optional*):
705
+ additional residual to be added to UNet mid block output, for example from ControlNet side model
706
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
707
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
708
+
709
+ Returns:
710
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
711
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
712
+ a `tuple` is returned where the first element is the sample tensor.
713
+ """
714
+
715
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
716
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
717
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
718
+ # on the fly if necessary.
719
+ default_overall_up_factor = 2**self.num_upsamplers
720
+
721
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
722
+ forward_upsample_size = False
723
+ upsample_size = None
724
+
725
+ for dim in sample.shape[-2:]:
726
+ if dim % default_overall_up_factor != 0:
727
+ # Forward upsample size to force interpolation output size.
728
+ forward_upsample_size = True
729
+ break
730
+
731
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
732
+ # expects mask of shape:
733
+ # [batch, key_tokens]
734
+ # adds singleton query_tokens dimension:
735
+ # [batch, 1, key_tokens]
736
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
737
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
738
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
739
+ if attention_mask is not None:
740
+ # assume that mask is expressed as:
741
+ # (1 = keep, 0 = discard)
742
+ # convert mask into a bias that can be added to attention scores:
743
+ # (keep = +0, discard = -10000.0)
744
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
745
+ attention_mask = attention_mask.unsqueeze(1)
746
+
747
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
748
+ if encoder_attention_mask is not None:
749
+ encoder_attention_mask = (
750
+ 1 - encoder_attention_mask.to(sample.dtype)
751
+ ) * -10000.0
752
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
753
+
754
+ # 0. center input if necessary
755
+ if self.config.center_input_sample:
756
+ sample = 2 * sample - 1.0
757
+
758
+ # 1. time
759
+ timesteps = timestep
760
+ if not torch.is_tensor(timesteps):
761
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
762
+ # This would be a good case for the `match` statement (Python 3.10+)
763
+ is_mps = sample.device.type == "mps"
764
+ if isinstance(timestep, float):
765
+ dtype = torch.float32 if is_mps else torch.float64
766
+ else:
767
+ dtype = torch.int32 if is_mps else torch.int64
768
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
769
+ elif len(timesteps.shape) == 0:
770
+ timesteps = timesteps[None].to(sample.device)
771
+
772
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
773
+ timesteps = timesteps.expand(sample.shape[0])
774
+
775
+ t_emb = self.time_proj(timesteps)
776
+
777
+ # `Timesteps` does not contain any weights and will always return f32 tensors
778
+ # but time_embedding might actually be running in fp16. so we need to cast here.
779
+ # there might be better ways to encapsulate this.
780
+ t_emb = t_emb.to(dtype=sample.dtype)
781
+
782
+ emb = self.time_embedding(t_emb, timestep_cond)
783
+ aug_emb = None
784
+
785
+ if self.class_embedding is not None:
786
+ if class_labels is None:
787
+ raise ValueError(
788
+ "class_labels should be provided when num_class_embeds > 0"
789
+ )
790
+
791
+ if self.config.class_embed_type == "timestep":
792
+ class_labels = self.time_proj(class_labels)
793
+
794
+ # `Timesteps` does not contain any weights and will always return f32 tensors
795
+ # there might be better ways to encapsulate this.
796
+ class_labels = class_labels.to(dtype=sample.dtype)
797
+
798
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
799
+
800
+ if self.config.class_embeddings_concat:
801
+ emb = torch.cat([emb, class_emb], dim=-1)
802
+ else:
803
+ emb = emb + class_emb
804
+
805
+ if self.config.addition_embed_type == "text":
806
+ aug_emb = self.add_embedding(encoder_hidden_states)
807
+ elif self.config.addition_embed_type == "text_image":
808
+ # Kandinsky 2.1 - style
809
+ if "image_embeds" not in added_cond_kwargs:
810
+ raise ValueError(
811
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
812
+ )
813
+
814
+ image_embs = added_cond_kwargs.get("image_embeds")
815
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
816
+ aug_emb = self.add_embedding(text_embs, image_embs)
817
+ elif self.config.addition_embed_type == "text_time":
818
+ # SDXL - style
819
+ if "text_embeds" not in added_cond_kwargs:
820
+ raise ValueError(
821
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
822
+ )
823
+ text_embeds = added_cond_kwargs.get("text_embeds")
824
+ if "time_ids" not in added_cond_kwargs:
825
+ raise ValueError(
826
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
827
+ )
828
+ time_ids = added_cond_kwargs.get("time_ids")
829
+ time_embeds = self.add_time_proj(time_ids.flatten())
830
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
831
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
832
+ add_embeds = add_embeds.to(emb.dtype)
833
+ aug_emb = self.add_embedding(add_embeds)
834
+ elif self.config.addition_embed_type == "image":
835
+ # Kandinsky 2.2 - style
836
+ if "image_embeds" not in added_cond_kwargs:
837
+ raise ValueError(
838
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
839
+ )
840
+ image_embs = added_cond_kwargs.get("image_embeds")
841
+ aug_emb = self.add_embedding(image_embs)
842
+ elif self.config.addition_embed_type == "image_hint":
843
+ # Kandinsky 2.2 - style
844
+ if (
845
+ "image_embeds" not in added_cond_kwargs
846
+ or "hint" not in added_cond_kwargs
847
+ ):
848
+ raise ValueError(
849
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
850
+ )
851
+ image_embs = added_cond_kwargs.get("image_embeds")
852
+ hint = added_cond_kwargs.get("hint")
853
+ aug_emb, hint = self.add_embedding(image_embs, hint)
854
+ sample = torch.cat([sample, hint], dim=1)
855
+
856
+ emb = emb + aug_emb if aug_emb is not None else emb
857
+
858
+ if self.time_embed_act is not None:
859
+ emb = self.time_embed_act(emb)
860
+
861
+ if (
862
+ self.encoder_hid_proj is not None
863
+ and self.config.encoder_hid_dim_type == "text_proj"
864
+ ):
865
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
866
+ elif (
867
+ self.encoder_hid_proj is not None
868
+ and self.config.encoder_hid_dim_type == "text_image_proj"
869
+ ):
870
+ # Kandinsky 2.1 - style
871
+ if "image_embeds" not in added_cond_kwargs:
872
+ raise ValueError(
873
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
874
+ )
875
+
876
+ image_embeds = added_cond_kwargs.get("image_embeds")
877
+ encoder_hidden_states = self.encoder_hid_proj(
878
+ encoder_hidden_states, image_embeds
879
+ )
880
+ elif (
881
+ self.encoder_hid_proj is not None
882
+ and self.config.encoder_hid_dim_type == "image_proj"
883
+ ):
884
+ # Kandinsky 2.2 - style
885
+ if "image_embeds" not in added_cond_kwargs:
886
+ raise ValueError(
887
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
888
+ )
889
+ image_embeds = added_cond_kwargs.get("image_embeds")
890
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
891
+ elif (
892
+ self.encoder_hid_proj is not None
893
+ and self.config.encoder_hid_dim_type == "ip_image_proj"
894
+ ):
895
+ if "image_embeds" not in added_cond_kwargs:
896
+ raise ValueError(
897
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
898
+ )
899
+ image_embeds = added_cond_kwargs.get("image_embeds")
900
+ image_embeds = self.encoder_hid_proj(image_embeds).to(
901
+ encoder_hidden_states.dtype
902
+ )
903
+ encoder_hidden_states = torch.cat(
904
+ [encoder_hidden_states, image_embeds], dim=1
905
+ )
906
+
907
+ # need_self_attn_block_embs
908
+ # initialization
909
+ # during the unet forward pass, self_attn_block_embs is appended to repeatedly and must be cleared after use
910
+ if self.need_self_attn_block_embs:
911
+ self_attn_block_embs = [None] * self.self_attn_num
912
+ else:
913
+ self_attn_block_embs = None
914
+ # 2. pre-process
915
+ sample = self.conv_in(sample)
916
+ if self.print_idx == 0:
917
+ logger.debug(f"after conv in sample={sample.mean()}")
918
+ # 2.5 GLIGEN position net
919
+ if (
920
+ cross_attention_kwargs is not None
921
+ and cross_attention_kwargs.get("gligen", None) is not None
922
+ ):
923
+ cross_attention_kwargs = cross_attention_kwargs.copy()
924
+ gligen_args = cross_attention_kwargs.pop("gligen")
925
+ cross_attention_kwargs["gligen"] = {
926
+ "objs": self.position_net(**gligen_args)
927
+ }
928
+
929
+ # 3. down
930
+ lora_scale = (
931
+ cross_attention_kwargs.get("scale", 1.0)
932
+ if cross_attention_kwargs is not None
933
+ else 1.0
934
+ )
935
+ if USE_PEFT_BACKEND:
936
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
937
+ scale_lora_layers(self, lora_scale)
938
+
939
+ is_controlnet = (
940
+ mid_block_additional_residual is not None
941
+ and down_block_additional_residuals is not None
942
+ )
943
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
944
+ is_adapter = down_intrablock_additional_residuals is not None
945
+ # maintain backward compatibility for legacy usage, where
946
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
947
+ # but can only use one or the other
948
+ if (
949
+ not is_adapter
950
+ and mid_block_additional_residual is None
951
+ and down_block_additional_residuals is not None
952
+ ):
953
+ deprecate(
954
+ "T2I should not use down_block_additional_residuals",
955
+ "1.3.0",
956
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
957
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
958
+ for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
959
+ standard_warn=False,
960
+ )
961
+ down_intrablock_additional_residuals = down_block_additional_residuals
962
+ is_adapter = True
963
+
964
+ down_block_res_samples = (sample,)
965
+ for i_downsample_block, downsample_block in enumerate(self.down_blocks):
966
+ if (
967
+ hasattr(downsample_block, "has_cross_attention")
968
+ and downsample_block.has_cross_attention
969
+ ):
970
+ # For t2i-adapter CrossAttnDownBlock2D
971
+ additional_residuals = {}
972
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
973
+ additional_residuals[
974
+ "additional_residuals"
975
+ ] = down_intrablock_additional_residuals.pop(0)
976
+ if self.print_idx == 0:
977
+ logger.debug(
978
+ f"downsample_block {i_downsample_block} sample={sample.mean()}"
979
+ )
980
+ sample, res_samples = downsample_block(
981
+ hidden_states=sample,
982
+ temb=emb,
983
+ encoder_hidden_states=encoder_hidden_states,
984
+ attention_mask=attention_mask,
985
+ cross_attention_kwargs=cross_attention_kwargs,
986
+ encoder_attention_mask=encoder_attention_mask,
987
+ **additional_residuals,
988
+ self_attn_block_embs=self_attn_block_embs,
989
+ )
990
+ else:
991
+ sample, res_samples = downsample_block(
992
+ hidden_states=sample,
993
+ temb=emb,
994
+ scale=lora_scale,
995
+ self_attn_block_embs=self_attn_block_embs,
996
+ )
997
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
998
+ sample += down_intrablock_additional_residuals.pop(0)
999
+
1000
+ down_block_res_samples += res_samples
1001
+
1002
+ if is_controlnet:
1003
+ new_down_block_res_samples = ()
1004
+
1005
+ for down_block_res_sample, down_block_additional_residual in zip(
1006
+ down_block_res_samples, down_block_additional_residuals
1007
+ ):
1008
+ down_block_res_sample = (
1009
+ down_block_res_sample + down_block_additional_residual
1010
+ )
1011
+ new_down_block_res_samples = new_down_block_res_samples + (
1012
+ down_block_res_sample,
1013
+ )
1014
+
1015
+ down_block_res_samples = new_down_block_res_samples
1016
+
1017
+ # update code start
1018
+ def reshape_return_emb(tmp_emb):
1019
+ if return_ndim == 4:
1020
+ return tmp_emb
1021
+ elif return_ndim == 5:
1022
+ return rearrange(tmp_emb, "(b t) c h w-> b c t h w", t=num_frames)
1023
+ else:
1024
+ raise ValueError(
1025
+ f"reshape_emb only support 4, 5 but given {return_ndim}"
1026
+ )
1027
+
1028
+ if self.need_block_embs:
1029
+ return_down_block_res_samples = [
1030
+ reshape_return_emb(tmp_emb) for tmp_emb in down_block_res_samples
1031
+ ]
1032
+ else:
1033
+ return_down_block_res_samples = None
1034
+ # update code end
1035
+
1036
+ # 4. mid
1037
+ if self.mid_block is not None:
1038
+ if (
1039
+ hasattr(self.mid_block, "has_cross_attention")
1040
+ and self.mid_block.has_cross_attention
1041
+ ):
1042
+ sample = self.mid_block(
1043
+ sample,
1044
+ emb,
1045
+ encoder_hidden_states=encoder_hidden_states,
1046
+ attention_mask=attention_mask,
1047
+ cross_attention_kwargs=cross_attention_kwargs,
1048
+ encoder_attention_mask=encoder_attention_mask,
1049
+ self_attn_block_embs=self_attn_block_embs,
1050
+ )
1051
+ else:
1052
+ sample = self.mid_block(sample, emb)
1053
+
1054
+ # To support T2I-Adapter-XL
1055
+ if (
1056
+ is_adapter
1057
+ and len(down_intrablock_additional_residuals) > 0
1058
+ and sample.shape == down_intrablock_additional_residuals[0].shape
1059
+ ):
1060
+ sample += down_intrablock_additional_residuals.pop(0)
1061
+
1062
+ if is_controlnet:
1063
+ sample = sample + mid_block_additional_residual
1064
+
1065
+ if self.need_block_embs:
1066
+ return_mid_block_res_samples = reshape_return_emb(sample)
1067
+ logger.debug(
1068
+ f"return_mid_block_res_samples, is_leaf={return_mid_block_res_samples.is_leaf}, requires_grad={return_mid_block_res_samples.requires_grad}"
1069
+ )
1070
+ else:
1071
+ return_mid_block_res_samples = None
1072
+
1073
+ if self.up_blocks is not None:
1074
+ # update code end
1075
+
1076
+ # 5. up
1077
+ for i, upsample_block in enumerate(self.up_blocks):
1078
+ is_final_block = i == len(self.up_blocks) - 1
1079
+
1080
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1081
+ down_block_res_samples = down_block_res_samples[
1082
+ : -len(upsample_block.resnets)
1083
+ ]
1084
+
1085
+ # if we have not reached the final block and need to forward the
1086
+ # upsample size, we do it here
1087
+ if not is_final_block and forward_upsample_size:
1088
+ upsample_size = down_block_res_samples[-1].shape[2:]
1089
+
1090
+ if (
1091
+ hasattr(upsample_block, "has_cross_attention")
1092
+ and upsample_block.has_cross_attention
1093
+ ):
1094
+ sample = upsample_block(
1095
+ hidden_states=sample,
1096
+ temb=emb,
1097
+ res_hidden_states_tuple=res_samples,
1098
+ encoder_hidden_states=encoder_hidden_states,
1099
+ cross_attention_kwargs=cross_attention_kwargs,
1100
+ upsample_size=upsample_size,
1101
+ attention_mask=attention_mask,
1102
+ encoder_attention_mask=encoder_attention_mask,
1103
+ self_attn_block_embs=self_attn_block_embs,
1104
+ )
1105
+ else:
1106
+ sample = upsample_block(
1107
+ hidden_states=sample,
1108
+ temb=emb,
1109
+ res_hidden_states_tuple=res_samples,
1110
+ upsample_size=upsample_size,
1111
+ scale=lora_scale,
1112
+ self_attn_block_embs=self_attn_block_embs,
1113
+ )
1114
+
1115
+ # update code start
1116
+ if self.need_block_embs or self.need_self_attn_block_embs:
1117
+ if self_attn_block_embs is not None:
1118
+ self_attn_block_embs = [
1119
+ reshape_return_emb(tmp_emb=tmp_emb)
1120
+ for tmp_emb in self_attn_block_embs
1121
+ ]
1122
+ self.print_idx += 1
1123
+ return (
1124
+ return_down_block_res_samples,
1125
+ return_mid_block_res_samples,
1126
+ self_attn_block_embs,
1127
+ )
1128
+
1129
+ if not self.need_block_embs and not self.need_self_attn_block_embs:
1130
+ # 6. post-process
1131
+ if self.conv_norm_out:
1132
+ sample = self.conv_norm_out(sample)
1133
+ sample = self.conv_act(sample)
1134
+ sample = self.conv_out(sample)
1135
+
1136
+ if USE_PEFT_BACKEND:
1137
+ # remove `lora_scale` from each PEFT layer
1138
+ unscale_lora_layers(self, lora_scale)
1139
+ self.print_idx += 1
1140
+ if not return_dict:
1141
+ return (sample,)
1142
+
1143
+ return UNet2DConditionOutput(sample=sample)
1144
+
1145
+ def insert_spatial_self_attn_idx(self):
1146
+ attns, basic_transformers = self.spatial_self_attns
1147
+ self.self_attn_num = len(attns)
1148
+ for i, (name, layer) in enumerate(attns):
1149
+ logger.debug(f"{self.__class__.__name__}, {i}, {name}, {type(layer)}")
1150
+ if layer is not None:
1151
+ layer.spatial_self_attn_idx = i
1152
+ for i, (name, layer) in enumerate(basic_transformers):
1153
+ logger.debug(f"{self.__class__.__name__}, {i}, {name}, {type(layer)}")
1154
+ if layer is not None:
1155
+ layer.spatial_self_attn_idx = i
1156
+
1157
+ @property
1158
+ def spatial_self_attns(
1159
+ self,
1160
+ ) -> List[Tuple[str, Attention]]:
1161
+ attns, spatial_transformers = self.get_self_attns(
1162
+ include="attentions", exclude="temp_attentions"
1163
+ )
1164
+ attns = sorted(attns)
1165
+ spatial_transformers = sorted(spatial_transformers)
1166
+ return attns, spatial_transformers
1167
+
1168
+ def get_self_attns(
1169
+ self, include: str = None, exclude: str = None
1170
+ ) -> Tuple[List[Tuple[str, Attention]], List[Tuple[str, BasicTransformerBlock]]]:
1171
+ r"""
1172
+ Returns:
1173
+ two lists of (name, module) tuples: the spatial self-attention (attn1) layers and
+ their parent BasicTransformerBlock modules, indexed by weight name.
1175
+ """
1176
+ # set recursively
1177
+ attns = []
1178
+ spatial_transformers = []
1179
+
1180
+ def fn_recursive_add_attns(
1181
+ name: str,
1182
+ module: torch.nn.Module,
1183
+ attns: List[Tuple[str, Attention]],
1184
+ spatial_transformers: List[Tuple[str, BasicTransformerBlock]],
1185
+ ):
1186
+ is_target = False
1187
+ if isinstance(module, BasicTransformerBlock) and hasattr(module, "attn1"):
1188
+ is_target = True
1189
+ if include is not None:
1190
+ is_target = include in name
1191
+ if exclude is not None:
1192
+ is_target = exclude not in name
1193
+ if is_target:
1194
+ attns.append([f"{name}.attn1", module.attn1])
1195
+ spatial_transformers.append([f"{name}", module])
1196
+ for sub_name, child in module.named_children():
1197
+ fn_recursive_add_attns(
1198
+ f"{name}.{sub_name}", child, attns, spatial_transformers
1199
+ )
1200
+
1201
+ return attns
1202
+
1203
+ for name, module in self.named_children():
1204
+ fn_recursive_add_attns(name, module, attns, spatial_transformers)
1205
+
1206
+ return attns, spatial_transformers
1207
+
1208
+
1209
+ class ReferenceNet3D(UNet3DConditionModel):
1210
+ """继承 UNet3DConditionModel, 用于提取中间emb用于后续作用。
1211
+ Inherit Unet3DConditionModel, used to extract the middle emb for subsequent actions.
1212
+ Args:
1213
+ UNet3DConditionModel (_type_): _description_
1214
+ """
1215
+
1216
+ pass
musev/models/referencenet_loader.py ADDED
@@ -0,0 +1,124 @@
1
+ import copy
2
+ from typing import Any, Callable, Dict, Iterable, Union
3
+ import PIL
4
+ import cv2
5
+ import torch
6
+ import argparse
7
+ import datetime
8
+ import logging
9
+ import inspect
10
+ import math
11
+ import os
12
+ import shutil
13
+ from typing import Dict, List, Optional, Tuple
14
+ from pprint import pprint
15
+ from collections import OrderedDict
16
+ from dataclasses import dataclass
17
+ import gc
18
+ import time
19
+
20
+ import numpy as np
21
+ from omegaconf import OmegaConf
22
+ from omegaconf import SCMode
23
+ import torch
24
+ from torch import nn
25
+ import torch.nn.functional as F
26
+ import torch.utils.checkpoint
27
+ from einops import rearrange, repeat
28
+ import pandas as pd
29
+ import h5py
30
+ from diffusers.models.modeling_utils import load_state_dict
31
+ from diffusers.utils import (
32
+ logging,
33
+ )
34
+ from diffusers.utils.import_utils import is_xformers_available
35
+
36
+ from .referencenet import ReferenceNet2D
37
+ from .unet_loader import update_unet_with_sd
38
+
39
+
40
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
+
42
+
43
+ def load_referencenet(
44
+ sd_referencenet_model: Union[str, nn.Module],
45
+ sd_model: nn.Module = None,
46
+ need_self_attn_block_embs: bool = False,
47
+ need_block_embs: bool = False,
48
+ dtype: torch.dtype = torch.float16,
49
+ cross_attention_dim: int = 768,
50
+ subfolder: str = "unet",
51
+ ):
52
+ """
53
+ Loads the ReferenceNet model.
54
+
55
+ Args:
56
+ sd_referencenet_model (Union[str, nn.Module]): The pretrained ReferenceNet model or the path to the model.
57
+ sd_model (nn.Module, optional): The sd_model to update the ReferenceNet with. Defaults to None.
58
+ need_self_attn_block_embs (bool, optional): Whether to compute self-attention block embeddings. Defaults to False.
59
+ need_block_embs (bool, optional): Whether to compute block embeddings. Defaults to False.
60
+ dtype (torch.dtype, optional): The data type of the tensors. Defaults to torch.float16.
61
+ cross_attention_dim (int, optional): The dimension of the cross-attention. Defaults to 768.
62
+ subfolder (str, optional): The subfolder of the model. Defaults to "unet".
63
+
64
+ Returns:
65
+ nn.Module: The loaded ReferenceNet model.
66
+ """
67
+
68
+ if isinstance(sd_referencenet_model, str):
69
+ referencenet = ReferenceNet2D.from_pretrained(
70
+ sd_referencenet_model,
71
+ subfolder=subfolder,
72
+ need_self_attn_block_embs=need_self_attn_block_embs,
73
+ need_block_embs=need_block_embs,
74
+ torch_dtype=dtype,
75
+ cross_attention_dim=cross_attention_dim,
76
+ )
77
+ elif isinstance(sd_referencenet_model, nn.Module):
78
+ referencenet = sd_referencenet_model
79
+ if sd_model is not None:
80
+ referencenet = update_unet_with_sd(referencenet, sd_model)
81
+ return referencenet
82
+
83
+
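+ # Hypothetical usage sketch (not part of the original file): assumes a local
+ # checkpoint directory "./checkpoints/musev_referencenet" containing a
+ # "referencenet" subfolder; adjust the path, dtype and subfolder to your setup.
+ #
+ # referencenet = load_referencenet(
+ #     sd_referencenet_model="./checkpoints/musev_referencenet",
+ #     subfolder="referencenet",
+ #     need_block_embs=True,
+ #     need_self_attn_block_embs=False,
+ #     dtype=torch.float16,
+ # ).to("cuda")
+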
84
+ def load_referencenet_by_name(
85
+ model_name: str,
86
+ sd_referencenet_model: Union[str, nn.Module],
87
+ sd_model: nn.Module = None,
88
+ cross_attention_dim: int = 768,
89
+ dtype: torch.dtype = torch.float16,
90
+ ) -> nn.Module:
91
+ """通过模型名字 初始化 referencenet,载入预训练参数,
92
+ 如希望后续通过简单名字就可以使用预训练模型,需要在这里完成定义
93
+ init referencenet with model_name.
94
+ if you want to use pretrained model with simple name, you need to define it here.
95
+ Args:
96
+ model_name (str): _description_
97
+ sd_referencenet_model (Union[str, nn.Module]): _description_
98
+ sd_model (Tuple[str, nn.Module]): _description_
99
+ cross_attention_dim (int, optional): _description_. Defaults to 768.
100
+ dtype (torch.dtype, optional): _description_. Defaults to torch.float16.
101
+
102
+ Raises:
103
+ ValueError: _description_
104
+
105
+ Returns:
106
+ nn.Module: _description_
107
+ """
108
+ if model_name in [
109
+ "musev_referencenet",
110
+ ]:
111
+ unet = load_referencenet(
112
+ sd_referencenet_model=sd_referencenet_model,
113
+ sd_model=sd_model,
114
+ cross_attention_dim=cross_attention_dim,
115
+ dtype=dtype,
116
+ need_self_attn_block_embs=False,
117
+ need_block_embs=True,
118
+ subfolder="referencenet",
119
+ )
120
+ else:
121
+ raise ValueError(
122
+ f"unsupport model_name={model_name}, only support ReferenceNet_V0_block13, ReferenceNet_V1_block13, ReferenceNet_V2_block13, ReferenceNet_V0_sefattn16"
123
+ )
124
+ return unet
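+
+ # Hypothetical usage sketch (the checkpoint path below is a placeholder); only
+ # "musev_referencenet" is registered above, so any other name raises ValueError.
+ #
+ # referencenet = load_referencenet_by_name(
+ #     model_name="musev_referencenet",
+ #     sd_referencenet_model="./checkpoints/musev_referencenet",
+ #     cross_attention_dim=768,
+ #     dtype=torch.float16,
+ # )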
musev/models/resnet.py ADDED
@@ -0,0 +1,135 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ # `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Adapted from https://github.com/huggingface/diffusers/blob/v0.16.1/src/diffusers/models/resnet.py
17
+ from __future__ import annotations
18
+
19
+ from functools import partial
20
+ from typing import Optional
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+ from einops import rearrange, repeat
26
+
27
+ from diffusers.models.resnet import TemporalConvLayer as DiffusersTemporalConvLayer
28
+ from ..data.data_util import batch_index_fill, batch_index_select
29
+ from . import Model_Register
30
+
31
+
32
+ @Model_Register.register
33
+ class TemporalConvLayer(nn.Module):
34
+ """
35
+ Temporal convolutional layer that can be used for video (sequence of images) input. Code mostly copied from:
36
+ https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
37
+ """
38
+
39
+ def __init__(
40
+ self,
41
+ in_dim,
42
+ out_dim=None,
43
+ dropout=0.0,
44
+ keep_content_condition: bool = False,
45
+ femb_channels: Optional[int] = None,
46
+ need_temporal_weight: bool = True,
47
+ ):
48
+ super().__init__()
49
+ out_dim = out_dim or in_dim
50
+ self.in_dim = in_dim
51
+ self.out_dim = out_dim
52
+ self.keep_content_condition = keep_content_condition
53
+ self.femb_channels = femb_channels
54
+ self.need_temporal_weight = need_temporal_weight
55
+ # conv layers
56
+ self.conv1 = nn.Sequential(
57
+ nn.GroupNorm(32, in_dim),
58
+ nn.SiLU(),
59
+ nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)),
60
+ )
61
+ self.conv2 = nn.Sequential(
62
+ nn.GroupNorm(32, out_dim),
63
+ nn.SiLU(),
64
+ nn.Dropout(dropout),
65
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
66
+ )
67
+ self.conv3 = nn.Sequential(
68
+ nn.GroupNorm(32, out_dim),
69
+ nn.SiLU(),
70
+ nn.Dropout(dropout),
71
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
72
+ )
73
+ self.conv4 = nn.Sequential(
74
+ nn.GroupNorm(32, out_dim),
75
+ nn.SiLU(),
76
+ nn.Dropout(dropout),
77
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
78
+ )
79
+
80
+ # zero out the last layer params,so the conv block is identity
81
+ # nn.init.zeros_(self.conv4[-1].weight)
82
+ # nn.init.zeros_(self.conv4[-1].bias)
83
+ self.temporal_weight = nn.Parameter(
84
+ torch.tensor(
85
+ [
86
+ 1e-5,
87
+ ]
88
+ )
89
+ ) # initialize the temporal weight with a small value (1e-5)
90
+ # zero out the last layer params,so the conv block is identity
91
+ nn.init.zeros_(self.conv4[-1].weight)
92
+ nn.init.zeros_(self.conv4[-1].bias)
93
+ self.skip_temporal_layers = False # Whether to skip temporal layer
94
+
95
+ def forward(
96
+ self,
97
+ hidden_states,
98
+ num_frames=1,
99
+ sample_index: torch.LongTensor = None,
100
+ vision_conditon_frames_sample_index: torch.LongTensor = None,
101
+ femb: torch.Tensor = None,
102
+ ):
103
+ if self.skip_temporal_layers is True:
104
+ return hidden_states
105
+ hidden_states_dtype = hidden_states.dtype
106
+ hidden_states = rearrange(
107
+ hidden_states, "(b t) c h w -> b c t h w", t=num_frames
108
+ )
109
+ identity = hidden_states
110
+ hidden_states = self.conv1(hidden_states)
111
+ hidden_states = self.conv2(hidden_states)
112
+ hidden_states = self.conv3(hidden_states)
113
+ hidden_states = self.conv4(hidden_states)
114
+ # keep the frames corresponding to the condition, to preserve preceding content frames and improve consistency
115
+ if self.keep_content_condition:
116
+ mask = torch.ones_like(hidden_states, device=hidden_states.device)
117
+ mask = batch_index_fill(
118
+ mask, dim=2, index=vision_conditon_frames_sample_index, value=0
119
+ )
120
+ if self.need_temporal_weight:
121
+ hidden_states = (
122
+ identity + torch.abs(self.temporal_weight) * mask * hidden_states
123
+ )
124
+ else:
125
+ hidden_states = identity + mask * hidden_states
126
+ else:
127
+ if self.need_temporal_weight:
128
+ hidden_states = (
129
+ identity + torch.abs(self.temporal_weight) * hidden_states
130
+ )
131
+ else:
132
+ hidden_states = identity + hidden_states
133
+ hidden_states = rearrange(hidden_states, " b c t h w -> (b t) c h w")
134
+ hidden_states = hidden_states.to(dtype=hidden_states_dtype)
135
+ return hidden_states
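+
+ # Minimal shape-flow sketch (hypothetical sizes, not from the original file): the
+ # layer expects frames stacked into the batch dimension, i.e. (b * t, c, h, w),
+ # and returns a tensor of the same shape.
+ #
+ # layer = TemporalConvLayer(in_dim=320)
+ # x = torch.randn(2 * 12, 320, 32, 32)  # batch=2, num_frames=12
+ # y = layer(x, num_frames=12)
+ # assert y.shape == x.shape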
musev/models/super_model.py ADDED
@@ -0,0 +1,253 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+
5
+ from typing import Any, Dict, Tuple, Union, Optional
6
+ from einops import rearrange, repeat
7
+ from torch import nn
8
+ import torch
9
+
10
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
11
+ from diffusers.models.modeling_utils import ModelMixin, load_state_dict
12
+
13
+ from ..data.data_util import align_repeat_tensor_single_dim
14
+
15
+ from .unet_3d_condition import UNet3DConditionModel
16
+ from .referencenet import ReferenceNet2D
17
+ from ip_adapter.ip_adapter import ImageProjModel
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ class SuperUNet3DConditionModel(nn.Module):
23
+ """封装了各种子模型的超模型,与 diffusers 的 pipeline 很像,只不过这里是模型定义。
24
+ 主要作用
25
+ 1. 将支持controlnet、referencenet等功能的计算封装起来,简洁些;
26
+ 2. 便于 accelerator 的分布式训练;
27
+
28
+ wrap the sub-models, such as unet, referencenet, controlnet, vae, text_encoder, tokenizer, text_emb_extractor, clip_vision_extractor, ip_adapter_image_proj
29
+ 1. support controlnet, referencenet, etc.
30
+ 2. support accelerator distributed training
31
+ """
32
+
33
+ _supports_gradient_checkpointing = True
34
+ print_idx = 0
35
+
36
+ # @register_to_config
37
+ def __init__(
38
+ self,
39
+ unet: nn.Module,
40
+ referencenet: nn.Module = None,
41
+ controlnet: nn.Module = None,
42
+ vae: nn.Module = None,
43
+ text_encoder: nn.Module = None,
44
+ tokenizer: nn.Module = None,
45
+ text_emb_extractor: nn.Module = None,
46
+ clip_vision_extractor: nn.Module = None,
47
+ ip_adapter_image_proj: nn.Module = None,
48
+ ) -> None:
49
+ """_summary_
50
+
51
+ Args:
52
+ unet (nn.Module): _description_
53
+ referencenet (nn.Module, optional): _description_. Defaults to None.
54
+ controlnet (nn.Module, optional): _description_. Defaults to None.
55
+ vae (nn.Module, optional): _description_. Defaults to None.
56
+ text_encoder (nn.Module, optional): _description_. Defaults to None.
57
+ tokenizer (nn.Module, optional): _description_. Defaults to None.
58
+ text_emb_extractor (nn.Module, optional): wrap text_encoder and tokenizer for str2emb. Defaults to None.
59
+ clip_vision_extractor (nn.Module, optional): _description_. Defaults to None.
60
+ """
61
+ super().__init__()
62
+ self.unet = unet
63
+ self.referencenet = referencenet
64
+ self.controlnet = controlnet
65
+ self.vae = vae
66
+ self.text_encoder = text_encoder
67
+ self.tokenizer = tokenizer
68
+ self.text_emb_extractor = text_emb_extractor
69
+ self.clip_vision_extractor = clip_vision_extractor
70
+ self.ip_adapter_image_proj = ip_adapter_image_proj
71
+
72
+ def forward(
73
+ self,
74
+ unet_params: Dict,
75
+ encoder_hidden_states: torch.Tensor,
76
+ referencenet_params: Dict = None,
77
+ controlnet_params: Dict = None,
78
+ controlnet_scale: float = 1.0,
79
+ vision_clip_emb: Union[torch.Tensor, None] = None,
80
+ prompt_only_use_image_prompt: bool = False,
81
+ ):
82
+ """_summary_
83
+
84
+ Args:
85
+ unet_params (Dict): _description_
86
+ encoder_hidden_states (torch.Tensor): b t n d
87
+ referencenet_params (Dict, optional): _description_. Defaults to None.
88
+ controlnet_params (Dict, optional): _description_. Defaults to None.
89
+ controlnet_scale (float, optional): _description_. Defaults to 1.0.
90
+ vision_clip_emb (Union[torch.Tensor, None], optional): b t d. Defaults to None.
91
+ prompt_only_use_image_prompt (bool, optional): _description_. Defaults to False.
92
+
93
+ Returns:
94
+ _type_: _description_
95
+ """
96
+ batch_size = unet_params["sample"].shape[0]
97
+ time_size = unet_params["sample"].shape[2]
98
+
99
+ # ip_adapter_cross_attn, prepare image prompt
100
+ if vision_clip_emb is not None:
101
+ # ensure the image prompt embedding has shape b t n d
102
+ if self.print_idx == 0:
103
+ logger.debug(
104
+ f"vision_clip_emb, before ip_adapter_image_proj, shape={vision_clip_emb.shape} mean={torch.mean(vision_clip_emb)}"
105
+ )
106
+ if vision_clip_emb.ndim == 3:
107
+ vision_clip_emb = rearrange(vision_clip_emb, "b t d-> b t 1 d")
108
+ if self.ip_adapter_image_proj is not None:
109
+ vision_clip_emb = rearrange(vision_clip_emb, "b t n d ->(b t) n d")
110
+ vision_clip_emb = self.ip_adapter_image_proj(vision_clip_emb)
111
+ if self.print_idx == 0:
112
+ logger.debug(
113
+ f"vision_clip_emb, after ip_adapter_image_proj shape={vision_clip_emb.shape} mean={torch.mean(vision_clip_emb)}"
114
+ )
115
+ if vision_clip_emb.ndim == 2:
116
+ vision_clip_emb = rearrange(vision_clip_emb, "b d-> b 1 d")
117
+ vision_clip_emb = rearrange(
118
+ vision_clip_emb, "(b t) n d -> b t n d", b=batch_size
119
+ )
120
+ vision_clip_emb = align_repeat_tensor_single_dim(
121
+ vision_clip_emb, target_length=time_size, dim=1
122
+ )
123
+ if self.print_idx == 0:
124
+ logger.debug(
125
+ f"vision_clip_emb, after reshape shape={vision_clip_emb.shape} mean={torch.mean(vision_clip_emb)}"
126
+ )
127
+
128
+ if vision_clip_emb is None and encoder_hidden_states is not None:
129
+ vision_clip_emb = encoder_hidden_states
130
+ if vision_clip_emb is not None and encoder_hidden_states is None:
131
+ encoder_hidden_states = vision_clip_emb
132
+ # 当 prompt_only_use_image_prompt 为True时,
133
+ # 1. referencenet 都使用 vision_clip_emb
134
+ # 2. unet 如果没有dual_cross_attn,使用vision_clip_emb,有时不更新
135
+ # 3. controlnet 当前使用 text_prompt
136
+
137
+ # when prompt_only_use_image_prompt True,
138
+ # 1. referencenet use vision_clip_emb
139
+ # 2. unet use vision_clip_emb if no dual_cross_attn, sometimes not update
140
+ # 3. controlnet use text_prompt
141
+
142
+ # extract referencenet emb
143
+ if self.referencenet is not None and referencenet_params is not None:
144
+ referencenet_encoder_hidden_states = align_repeat_tensor_single_dim(
145
+ vision_clip_emb,
146
+ target_length=referencenet_params["num_frames"],
147
+ dim=1,
148
+ )
149
+ referencenet_params["encoder_hidden_states"] = rearrange(
150
+ referencenet_encoder_hidden_states, "b t n d->(b t) n d"
151
+ )
152
+ referencenet_out = self.referencenet(**referencenet_params)
153
+ (
154
+ down_block_refer_embs,
155
+ mid_block_refer_emb,
156
+ refer_self_attn_emb,
157
+ ) = referencenet_out
158
+ if down_block_refer_embs is not None:
159
+ if self.print_idx == 0:
160
+ logger.debug(
161
+ f"len(down_block_refer_embs)={len(down_block_refer_embs)}"
162
+ )
163
+ for i, down_emb in enumerate(down_block_refer_embs):
164
+ if self.print_idx == 0:
165
+ logger.debug(
166
+ f"down_emb, {i}, {down_emb.shape}, mean={down_emb.mean()}"
167
+ )
168
+ else:
169
+ if self.print_idx == 0:
170
+ logger.debug(f"down_block_refer_embs is None")
171
+ if mid_block_refer_emb is not None:
172
+ if self.print_idx == 0:
173
+ logger.debug(
174
+ f"mid_block_refer_emb, {mid_block_refer_emb.shape}, mean={mid_block_refer_emb.mean()}"
175
+ )
176
+ else:
177
+ if self.print_idx == 0:
178
+ logger.debug(f"mid_block_refer_emb is None")
179
+ if refer_self_attn_emb is not None:
180
+ if self.print_idx == 0:
181
+ logger.debug(f"refer_self_attn_emb, num={len(refer_self_attn_emb)}")
182
+ for i, self_attn_emb in enumerate(refer_self_attn_emb):
183
+ if self.print_idx == 0:
184
+ logger.debug(
185
+ f"referencenet, self_attn_emb, {i}th, shape={self_attn_emb.shape}, mean={self_attn_emb.mean()}"
186
+ )
187
+ else:
188
+ if self.print_idx == 0:
189
+ logger.debug(f"refer_self_attn_emb is None")
190
+ else:
191
+ down_block_refer_embs, mid_block_refer_emb, refer_self_attn_emb = (
192
+ None,
193
+ None,
194
+ None,
195
+ )
196
+
197
+ # extract controlnet emb
198
+ if self.controlnet is not None and controlnet_params is not None:
199
+ controlnet_encoder_hidden_states = align_repeat_tensor_single_dim(
200
+ encoder_hidden_states,
201
+ target_length=unet_params["sample"].shape[2],
202
+ dim=1,
203
+ )
204
+ controlnet_params["encoder_hidden_states"] = rearrange(
205
+ controlnet_encoder_hidden_states, " b t n d -> (b t) n d"
206
+ )
207
+ (
208
+ down_block_additional_residuals,
209
+ mid_block_additional_residual,
210
+ ) = self.controlnet(**controlnet_params)
211
+ if controlnet_scale != 1.0:
212
+ down_block_additional_residuals = [
213
+ x * controlnet_scale for x in down_block_additional_residuals
214
+ ]
215
+ mid_block_additional_residual = (
216
+ mid_block_additional_residual * controlnet_scale
217
+ )
218
+ for i, down_block_additional_residual in enumerate(
219
+ down_block_additional_residuals
220
+ ):
221
+ if self.print_idx == 0:
222
+ logger.debug(
223
+ f"{i}, down_block_additional_residual mean={torch.mean(down_block_additional_residual)}"
224
+ )
225
+
226
+ if self.print_idx == 0:
227
+ logger.debug(
228
+ f"mid_block_additional_residual mean={torch.mean(mid_block_additional_residual)}"
229
+ )
230
+ else:
231
+ down_block_additional_residuals = None
232
+ mid_block_additional_residual = None
233
+
234
+ if prompt_only_use_image_prompt and vision_clip_emb is not None:
235
+ encoder_hidden_states = vision_clip_emb
236
+
237
+ # run unet
238
+ out = self.unet(
239
+ **unet_params,
240
+ down_block_refer_embs=down_block_refer_embs,
241
+ mid_block_refer_emb=mid_block_refer_emb,
242
+ refer_self_attn_emb=refer_self_attn_emb,
243
+ down_block_additional_residuals=down_block_additional_residuals,
244
+ mid_block_additional_residual=mid_block_additional_residual,
245
+ encoder_hidden_states=encoder_hidden_states,
246
+ vision_clip_emb=vision_clip_emb,
247
+ )
248
+ self.print_idx += 1
249
+ return out
250
+
251
+ def _set_gradient_checkpointing(self, module, value=False):
252
+ if isinstance(module, (UNet3DConditionModel, ReferenceNet2D)):
253
+ module.gradient_checkpointing = value
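+
+ # Hypothetical wiring sketch (all variable names below are placeholders, not
+ # defined in this file): the super model only holds already-constructed
+ # sub-modules and routes referencenet/controlnet embeddings into the unet forward.
+ #
+ # super_unet = SuperUNet3DConditionModel(
+ #     unet=unet,                         # UNet3DConditionModel
+ #     referencenet=referencenet,         # optional ReferenceNet2D
+ #     controlnet=controlnet,             # optional controlnet
+ #     ip_adapter_image_proj=image_proj,  # optional ImageProjModel
+ # )
+ # out = super_unet(
+ #     unet_params={"sample": latents, "timestep": t},  # latents: b c t h w
+ #     encoder_hidden_states=text_emb,                  # b t n d
+ #     referencenet_params=refer_params,
+ #     controlnet_params=ctrl_params,
+ # )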
musev/models/temporal_transformer.py ADDED
@@ -0,0 +1,308 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Adapted from https://github.com/huggingface/diffusers/blob/v0.16.1/src/diffusers/models/transformer_temporal.py
16
+ from __future__ import annotations
17
+ from copy import deepcopy
18
+ from dataclasses import dataclass
19
+ from typing import List, Literal, Optional
20
+ import logging
21
+
22
+ import torch
23
+ from torch import nn
24
+ from einops import rearrange, repeat
25
+
26
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
27
+ from diffusers.utils import BaseOutput
28
+ from diffusers.models.modeling_utils import ModelMixin
29
+ from diffusers.models.transformer_temporal import (
30
+ TransformerTemporalModelOutput,
31
+ TransformerTemporalModel as DiffusersTransformerTemporalModel,
32
+ )
33
+ from diffusers.models.attention_processor import AttnProcessor
34
+
35
+ from mmcm.utils.gpu_util import get_gpu_status
36
+ from ..data.data_util import (
37
+ batch_concat_two_tensor_with_index,
38
+ batch_index_fill,
39
+ batch_index_select,
40
+ concat_two_tensor,
41
+ align_repeat_tensor_single_dim,
42
+ )
43
+ from ..utils.attention_util import generate_sparse_causcal_attn_mask
44
+ from .attention import BasicTransformerBlock
45
+ from .attention_processor import (
46
+ BaseIPAttnProcessor,
47
+ )
48
+ from . import Model_Register
49
+
50
+ # https://github.com/facebookresearch/xformers/issues/845
51
+ # 输入bs*n_frames*w*h太高,xformers报错。因此将transformer_temporal的allow_xformers均关掉
52
+ # if bs*n_frames*w*h is too large, xformers raises an error, so allow_xformers is disabled in transformer_temporal
53
+ logger = logging.getLogger(__name__)
54
+
55
+
56
+ @Model_Register.register
57
+ class TransformerTemporalModel(ModelMixin, ConfigMixin):
58
+ """
59
+ Transformer model for video-like data.
60
+
61
+ Parameters:
62
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
63
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
64
+ in_channels (`int`, *optional*):
65
+ Pass if the input is continuous. The number of channels in the input and output.
66
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
67
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
68
+ cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
69
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
70
+ Note that this is fixed at training time as it is used for learning a number of position embeddings. See
71
+ `ImagePositionalEmbeddings`.
72
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
73
+ attention_bias (`bool`, *optional*):
74
+ Configure if the TransformerBlocks' attention should contain a bias parameter.
75
+ double_self_attention (`bool`, *optional*):
76
+ Configure if each TransformerBlock should contain two self-attention layers
77
+ """
78
+
79
+ @register_to_config
80
+ def __init__(
81
+ self,
82
+ num_attention_heads: int = 16,
83
+ attention_head_dim: int = 88,
84
+ in_channels: Optional[int] = None,
85
+ out_channels: Optional[int] = None,
86
+ num_layers: int = 1,
87
+ femb_channels: Optional[int] = None,
88
+ dropout: float = 0.0,
89
+ norm_num_groups: int = 32,
90
+ cross_attention_dim: Optional[int] = None,
91
+ attention_bias: bool = False,
92
+ sample_size: Optional[int] = None,
93
+ activation_fn: str = "geglu",
94
+ norm_elementwise_affine: bool = True,
95
+ double_self_attention: bool = True,
96
+ allow_xformers: bool = False,
97
+ only_cross_attention: bool = False,
98
+ keep_content_condition: bool = False,
99
+ need_spatial_position_emb: bool = False,
100
+ need_temporal_weight: bool = True,
101
+ self_attn_mask: str = None,
102
+ # TODO: 运行参数,有待改到forward里面去
103
+ # TODO: running parameters, need to be moved to forward
104
+ image_scale: float = 1.0,
105
+ processor: AttnProcessor | None = None,
106
+ remove_femb_non_linear: bool = False,
107
+ ):
108
+ super().__init__()
109
+
110
+ self.num_attention_heads = num_attention_heads
111
+ self.attention_head_dim = attention_head_dim
112
+
113
+ inner_dim = num_attention_heads * attention_head_dim
114
+ self.inner_dim = inner_dim
115
+ self.in_channels = in_channels
116
+
117
+ self.norm = torch.nn.GroupNorm(
118
+ num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
119
+ )
120
+
121
+ self.proj_in = nn.Linear(in_channels, inner_dim)
122
+
123
+ # 2. Define temporal positional embedding
124
+ self.frame_emb_proj = torch.nn.Linear(femb_channels, inner_dim)
125
+ self.remove_femb_non_linear = remove_femb_non_linear
126
+ if not remove_femb_non_linear:
127
+ self.nonlinearity = nn.SiLU()
128
+
129
+ # spatial_position_emb reuses the femb parameter configuration (femb_channels)
130
+ self.need_spatial_position_emb = need_spatial_position_emb
131
+ if need_spatial_position_emb:
132
+ self.spatial_position_emb_proj = torch.nn.Linear(femb_channels, inner_dim)
133
+ # 3. Define transformers blocks
134
+ # TODO: 该实现方式不好,待优化
135
+ # TODO: bad implementation, need to be optimized
136
+ self.need_ipadapter = False
137
+ self.cross_attn_temporal_cond = False
138
+ self.allow_xformers = allow_xformers
139
+ if processor is not None and isinstance(processor, BaseIPAttnProcessor):
140
+ self.cross_attn_temporal_cond = True
141
+ self.allow_xformers = False
142
+ if "NonParam" not in processor.__class__.__name__:
143
+ self.need_ipadapter = True
144
+
145
+ self.transformer_blocks = nn.ModuleList(
146
+ [
147
+ BasicTransformerBlock(
148
+ inner_dim,
149
+ num_attention_heads,
150
+ attention_head_dim,
151
+ dropout=dropout,
152
+ cross_attention_dim=cross_attention_dim,
153
+ activation_fn=activation_fn,
154
+ attention_bias=attention_bias,
155
+ double_self_attention=double_self_attention,
156
+ norm_elementwise_affine=norm_elementwise_affine,
157
+ allow_xformers=allow_xformers,
158
+ only_cross_attention=only_cross_attention,
159
+ cross_attn_temporal_cond=self.need_ipadapter,
160
+ image_scale=image_scale,
161
+ processor=processor,
162
+ )
163
+ for d in range(num_layers)
164
+ ]
165
+ )
166
+
167
+ self.proj_out = nn.Linear(inner_dim, in_channels)
168
+
169
+ self.need_temporal_weight = need_temporal_weight
170
+ if need_temporal_weight:
171
+ self.temporal_weight = nn.Parameter(
172
+ torch.tensor(
173
+ [
174
+ 1e-5,
175
+ ]
176
+ )
177
+ ) # initialize the temporal weight with a small value (1e-5)
178
+ self.skip_temporal_layers = False # Whether to skip temporal layer
179
+ self.keep_content_condition = keep_content_condition
180
+ self.self_attn_mask = self_attn_mask
181
+ self.only_cross_attention = only_cross_attention
182
+ self.double_self_attention = double_self_attention
183
+ self.cross_attention_dim = cross_attention_dim
184
+ self.image_scale = image_scale
185
+ # zero out the output projection params, so the temporal block starts as an identity
186
+ nn.init.zeros_(self.proj_out.weight)
187
+ nn.init.zeros_(self.proj_out.bias)
188
+
189
+ def forward(
190
+ self,
191
+ hidden_states,
192
+ femb,
193
+ encoder_hidden_states=None,
194
+ timestep=None,
195
+ class_labels=None,
196
+ num_frames=1,
197
+ cross_attention_kwargs=None,
198
+ sample_index: torch.LongTensor = None,
199
+ vision_conditon_frames_sample_index: torch.LongTensor = None,
200
+ spatial_position_emb: torch.Tensor = None,
201
+ return_dict: bool = True,
202
+ ):
203
+ """
204
+ Args:
205
+ hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
206
+ When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
207
+ hidden_states
208
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
209
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
210
+ self-attention.
211
+ timestep ( `torch.long`, *optional*):
212
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
213
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
214
+ Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels
215
+ conditioning.
216
+ return_dict (`bool`, *optional*, defaults to `True`):
217
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
218
+
219
+ Returns:
220
+ [`~models.transformer_2d.TransformerTemporalModelOutput`] or `tuple`:
221
+ [`~models.transformer_2d.TransformerTemporalModelOutput`] if `return_dict` is True, otherwise a `tuple`.
222
+ When returning a tuple, the first element is the sample tensor.
223
+ """
224
+ if self.skip_temporal_layers is True:
225
+ if not return_dict:
226
+ return (hidden_states,)
227
+
228
+ return TransformerTemporalModelOutput(sample=hidden_states)
229
+
230
+ # 1. Input
231
+ batch_frames, channel, height, width = hidden_states.shape
232
+ batch_size = batch_frames // num_frames
233
+
234
+ hidden_states = rearrange(
235
+ hidden_states, "(b t) c h w -> b c t h w", b=batch_size
236
+ )
237
+ residual = hidden_states
238
+
239
+ hidden_states = self.norm(hidden_states)
240
+
241
+ hidden_states = rearrange(hidden_states, "b c t h w -> (b h w) t c")
242
+
243
+ hidden_states = self.proj_in(hidden_states)
244
+
245
+ # 2 Positional embedding
246
+ # adapted from https://github.com/huggingface/diffusers/blob/v0.16.1/src/diffusers/models/resnet.py#L574
247
+ if not self.remove_femb_non_linear:
248
+ femb = self.nonlinearity(femb)
249
+ femb = self.frame_emb_proj(femb)
250
+ femb = align_repeat_tensor_single_dim(femb, hidden_states.shape[0], dim=0)
251
+ hidden_states = hidden_states + femb
252
+
253
+ # 3. Blocks
254
+ if (
255
+ (self.only_cross_attention or not self.double_self_attention)
256
+ and self.cross_attention_dim is not None
257
+ and encoder_hidden_states is not None
258
+ ):
259
+ encoder_hidden_states = align_repeat_tensor_single_dim(
260
+ encoder_hidden_states,
261
+ hidden_states.shape[0],
262
+ dim=0,
263
+ n_src_base_length=batch_size,
264
+ )
265
+
266
+ for i, block in enumerate(self.transformer_blocks):
267
+ hidden_states = block(
268
+ hidden_states,
269
+ encoder_hidden_states=encoder_hidden_states,
270
+ timestep=timestep,
271
+ cross_attention_kwargs=cross_attention_kwargs,
272
+ class_labels=class_labels,
273
+ )
274
+
275
+ # 4. Output
276
+ hidden_states = self.proj_out(hidden_states)
277
+ hidden_states = rearrange(
278
+ hidden_states, "(b h w) t c -> b c t h w", b=batch_size, h=height, w=width
279
+ ).contiguous()
280
+
281
+ # 保留condition对应的frames,便于保持前序内容帧,提升一致性
282
+ # keep the frames corresponding to the condition to maintain the previous content frames and improve consistency
283
+ if (
284
+ vision_conditon_frames_sample_index is not None
285
+ and self.keep_content_condition
286
+ ):
287
+ mask = torch.ones_like(hidden_states, device=hidden_states.device)
288
+ mask = batch_index_fill(
289
+ mask, dim=2, index=vision_conditon_frames_sample_index, value=0
290
+ )
291
+ if self.need_temporal_weight:
292
+ output = (
293
+ residual + torch.abs(self.temporal_weight) * mask * hidden_states
294
+ )
295
+ else:
296
+ output = residual + mask * hidden_states
297
+ else:
298
+ if self.need_temporal_weight:
299
+ output = residual + torch.abs(self.temporal_weight) * hidden_states
300
+ else:
301
+ output = residual + hidden_states
302
+
303
+ # output = torch.abs(self.temporal_weight) * hidden_states + residual
304
+ output = rearrange(output, "b c t h w -> (b t) c h w")
305
+ if not return_dict:
306
+ return (output,)
307
+
308
+ return TransformerTemporalModelOutput(sample=output)
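+
+ # Minimal shape-flow sketch (hypothetical sizes, not from the original file): the
+ # block operates on frames stacked into the batch dimension and needs a per-frame
+ # embedding `femb` whose last dimension equals `femb_channels`.
+ #
+ # model = TransformerTemporalModel(
+ #     num_attention_heads=8, attention_head_dim=40, in_channels=320, femb_channels=320
+ # )
+ # x = torch.randn(2 * 12, 320, 32, 32)   # (batch * num_frames, c, h, w)
+ # femb = torch.randn(2, 12, 320)         # per-frame embedding
+ # out = model(x, femb=femb, num_frames=12).sample
+ # assert out.shape == x.shape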
musev/models/text_model.py ADDED
@@ -0,0 +1,40 @@
1
+ from typing import Any, Dict
2
+ from torch import nn
3
+
4
+
5
+ class TextEmbExtractor(nn.Module):
6
+ def __init__(self, tokenizer, text_encoder) -> None:
7
+ super(TextEmbExtractor, self).__init__()
8
+ self.tokenizer = tokenizer
9
+ self.text_encoder = text_encoder
10
+
11
+ def forward(
12
+ self,
13
+ texts,
14
+ text_params: Dict = None,
15
+ ):
16
+ if text_params is None:
17
+ text_params = {}
18
+ special_prompt_input = self.tokenizer(
19
+ texts,
20
+ max_length=self.tokenizer.model_max_length,
21
+ padding="max_length",
22
+ truncation=True,
23
+ return_tensors="pt",
24
+ )
25
+ if (
26
+ hasattr(self.text_encoder.config, "use_attention_mask")
27
+ and self.text_encoder.config.use_attention_mask
28
+ ):
29
+ attention_mask = special_prompt_input.attention_mask.to(
30
+ self.text_encoder.device
31
+ )
32
+ else:
33
+ attention_mask = None
34
+
35
+ embeddings = self.text_encoder(
36
+ special_prompt_input.input_ids.to(self.text_encoder.device),
37
+ attention_mask=attention_mask,
38
+ **text_params
39
+ )
40
+ return embeddings
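+
+ # Hypothetical usage sketch (the checkpoint path is a placeholder; assumes an
+ # SD-style CLIP text encoder layout):
+ #
+ # from transformers import CLIPTextModel, CLIPTokenizer
+ #
+ # tokenizer = CLIPTokenizer.from_pretrained("path/to/sd", subfolder="tokenizer")
+ # text_encoder = CLIPTextModel.from_pretrained("path/to/sd", subfolder="text_encoder")
+ # extractor = TextEmbExtractor(tokenizer=tokenizer, text_encoder=text_encoder)
+ # emb = extractor(["a girl dancing"])[0]  # last_hidden_state, e.g. (1, 77, 768)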
musev/models/transformer_2d.py ADDED
@@ -0,0 +1,445 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from __future__ import annotations
15
+ from dataclasses import dataclass
16
+ from typing import Any, Dict, List, Literal, Optional
17
+ import logging
18
+
19
+ from einops import rearrange
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from torch import nn
24
+
25
+ from diffusers.models.transformer_2d import (
26
+ Transformer2DModelOutput,
27
+ Transformer2DModel as DiffusersTransformer2DModel,
28
+ )
29
+
30
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
31
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
32
+ from diffusers.utils import BaseOutput, deprecate
33
+ from diffusers.models.attention import (
34
+ BasicTransformerBlock as DiffusersBasicTransformerBlock,
35
+ )
36
+ from diffusers.models.embeddings import PatchEmbed
37
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
38
+ from diffusers.models.modeling_utils import ModelMixin
39
+ from diffusers.utils.constants import USE_PEFT_BACKEND
40
+
41
+ from .attention import BasicTransformerBlock
42
+
43
+ logger = logging.getLogger(__name__)
44
+
45
+ # 本部分 与 diffusers/models/transformer_2d.py 几乎一样
46
+ # 更新部分
47
+ # 1. 替换自定义 BasicTransformerBlock 类
48
+ # 2. 在forward 里增加了 self_attn_block_embs 用于 提取 self_attn 中的emb
49
+
50
+ # this module is almost the same as diffusers/models/transformer_2d.py; the changes are:
+ # 1. use the custom BasicTransformerBlock class defined in this repo
+ # 2. add self_attn_block_embs in forward to extract embeddings from self_attn
53
+
54
+
55
+ class Transformer2DModel(DiffusersTransformer2DModel):
56
+ """
57
+ A 2D Transformer model for image-like data.
58
+
59
+ Parameters:
60
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
61
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
62
+ in_channels (`int`, *optional*):
63
+ The number of channels in the input and output (specify if the input is **continuous**).
64
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
65
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
66
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
67
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
68
+ This is fixed during training since it is used to learn a number of position embeddings.
69
+ num_vector_embeds (`int`, *optional*):
70
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
71
+ Includes the class for the masked latent pixel.
72
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
73
+ num_embeds_ada_norm ( `int`, *optional*):
74
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
75
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
76
+ added to the hidden states.
77
+
78
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
79
+ attention_bias (`bool`, *optional*):
80
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
81
+ """
82
+
83
+ @register_to_config
84
+ def __init__(
85
+ self,
86
+ num_attention_heads: int = 16,
87
+ attention_head_dim: int = 88,
88
+ in_channels: int | None = None,
89
+ out_channels: int | None = None,
90
+ num_layers: int = 1,
91
+ dropout: float = 0,
92
+ norm_num_groups: int = 32,
93
+ cross_attention_dim: int | None = None,
94
+ attention_bias: bool = False,
95
+ sample_size: int | None = None,
96
+ num_vector_embeds: int | None = None,
97
+ patch_size: int | None = None,
98
+ activation_fn: str = "geglu",
99
+ num_embeds_ada_norm: int | None = None,
100
+ use_linear_projection: bool = False,
101
+ only_cross_attention: bool = False,
102
+ double_self_attention: bool = False,
103
+ upcast_attention: bool = False,
104
+ norm_type: str = "layer_norm",
105
+ norm_elementwise_affine: bool = True,
106
+ attention_type: str = "default",
107
+ cross_attn_temporal_cond: bool = False,
108
+ ip_adapter_cross_attn: bool = False,
109
+ need_t2i_facein: bool = False,
110
+ need_t2i_ip_adapter_face: bool = False,
111
+ image_scale: float = 1.0,
112
+ ):
113
+ super().__init__(
114
+ num_attention_heads,
115
+ attention_head_dim,
116
+ in_channels,
117
+ out_channels,
118
+ num_layers,
119
+ dropout,
120
+ norm_num_groups,
121
+ cross_attention_dim,
122
+ attention_bias,
123
+ sample_size,
124
+ num_vector_embeds,
125
+ patch_size,
126
+ activation_fn,
127
+ num_embeds_ada_norm,
128
+ use_linear_projection,
129
+ only_cross_attention,
130
+ double_self_attention,
131
+ upcast_attention,
132
+ norm_type,
133
+ norm_elementwise_affine,
134
+ attention_type,
135
+ )
136
+ inner_dim = num_attention_heads * attention_head_dim
137
+ self.transformer_blocks = nn.ModuleList(
138
+ [
139
+ BasicTransformerBlock(
140
+ inner_dim,
141
+ num_attention_heads,
142
+ attention_head_dim,
143
+ dropout=dropout,
144
+ cross_attention_dim=cross_attention_dim,
145
+ activation_fn=activation_fn,
146
+ num_embeds_ada_norm=num_embeds_ada_norm,
147
+ attention_bias=attention_bias,
148
+ only_cross_attention=only_cross_attention,
149
+ double_self_attention=double_self_attention,
150
+ upcast_attention=upcast_attention,
151
+ norm_type=norm_type,
152
+ norm_elementwise_affine=norm_elementwise_affine,
153
+ attention_type=attention_type,
154
+ cross_attn_temporal_cond=cross_attn_temporal_cond,
155
+ ip_adapter_cross_attn=ip_adapter_cross_attn,
156
+ need_t2i_facein=need_t2i_facein,
157
+ need_t2i_ip_adapter_face=need_t2i_ip_adapter_face,
158
+ image_scale=image_scale,
159
+ )
160
+ for d in range(num_layers)
161
+ ]
162
+ )
163
+ self.num_layers = num_layers
164
+ self.cross_attn_temporal_cond = cross_attn_temporal_cond
165
+ self.ip_adapter_cross_attn = ip_adapter_cross_attn
166
+
167
+ self.need_t2i_facein = need_t2i_facein
168
+ self.need_t2i_ip_adapter_face = need_t2i_ip_adapter_face
169
+ self.image_scale = image_scale
170
+ self.print_idx = 0
171
+
172
+ def forward(
173
+ self,
174
+ hidden_states: torch.Tensor,
175
+ encoder_hidden_states: Optional[torch.Tensor] = None,
176
+ timestep: Optional[torch.LongTensor] = None,
177
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
178
+ class_labels: Optional[torch.LongTensor] = None,
179
+ cross_attention_kwargs: Dict[str, Any] = None,
180
+ attention_mask: Optional[torch.Tensor] = None,
181
+ encoder_attention_mask: Optional[torch.Tensor] = None,
182
+ self_attn_block_embs: Optional[List[torch.Tensor]] = None,
183
+ self_attn_block_embs_mode: Literal["read", "write"] = "write",
184
+ return_dict: bool = True,
185
+ ):
186
+ """
187
+ The [`Transformer2DModel`] forward method.
188
+
189
+ Args:
190
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
191
+ Input `hidden_states`.
192
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
193
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
194
+ self-attention.
195
+ timestep ( `torch.LongTensor`, *optional*):
196
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
197
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
198
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
199
+ `AdaLayerZeroNorm`.
200
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
201
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
202
+ `self.processor` in
203
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
204
+ attention_mask ( `torch.Tensor`, *optional*):
205
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
206
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
207
+ negative values to the attention scores corresponding to "discard" tokens.
208
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
209
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
210
+
211
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
212
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
213
+
214
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
215
+ above. This bias will be added to the cross-attention scores.
216
+ return_dict (`bool`, *optional*, defaults to `True`):
217
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
218
+ tuple.
219
+
220
+ Returns:
221
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
222
+ `tuple` where the first element is the sample tensor.
223
+ """
224
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
225
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
226
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
227
+ # expects mask of shape:
228
+ # [batch, key_tokens]
229
+ # adds singleton query_tokens dimension:
230
+ # [batch, 1, key_tokens]
231
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
232
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
233
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
234
+ if attention_mask is not None and attention_mask.ndim == 2:
235
+ # assume that mask is expressed as:
236
+ # (1 = keep, 0 = discard)
237
+ # convert mask into a bias that can be added to attention scores:
238
+ # (keep = +0, discard = -10000.0)
239
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
240
+ attention_mask = attention_mask.unsqueeze(1)
241
+
242
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
243
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
244
+ encoder_attention_mask = (
245
+ 1 - encoder_attention_mask.to(hidden_states.dtype)
246
+ ) * -10000.0
247
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
248
+
249
+ # Retrieve lora scale.
250
+ lora_scale = (
251
+ cross_attention_kwargs.get("scale", 1.0)
252
+ if cross_attention_kwargs is not None
253
+ else 1.0
254
+ )
255
+
256
+ # 1. Input
257
+ if self.is_input_continuous:
258
+ batch, _, height, width = hidden_states.shape
259
+ residual = hidden_states
260
+
261
+ hidden_states = self.norm(hidden_states)
262
+ if not self.use_linear_projection:
263
+ hidden_states = (
264
+ self.proj_in(hidden_states, scale=lora_scale)
265
+ if not USE_PEFT_BACKEND
266
+ else self.proj_in(hidden_states)
267
+ )
268
+ inner_dim = hidden_states.shape[1]
269
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
270
+ batch, height * width, inner_dim
271
+ )
272
+ else:
273
+ inner_dim = hidden_states.shape[1]
274
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
275
+ batch, height * width, inner_dim
276
+ )
277
+ hidden_states = (
278
+ self.proj_in(hidden_states, scale=lora_scale)
279
+ if not USE_PEFT_BACKEND
280
+ else self.proj_in(hidden_states)
281
+ )
282
+
283
+ elif self.is_input_vectorized:
284
+ hidden_states = self.latent_image_embedding(hidden_states)
285
+ elif self.is_input_patches:
286
+ height, width = (
287
+ hidden_states.shape[-2] // self.patch_size,
288
+ hidden_states.shape[-1] // self.patch_size,
289
+ )
290
+ hidden_states = self.pos_embed(hidden_states)
291
+
292
+ if self.adaln_single is not None:
293
+ if self.use_additional_conditions and added_cond_kwargs is None:
294
+ raise ValueError(
295
+ "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
296
+ )
297
+ batch_size = hidden_states.shape[0]
298
+ timestep, embedded_timestep = self.adaln_single(
299
+ timestep,
300
+ added_cond_kwargs,
301
+ batch_size=batch_size,
302
+ hidden_dtype=hidden_states.dtype,
303
+ )
304
+
305
+ # 2. Blocks
306
+ if self.caption_projection is not None:
307
+ batch_size = hidden_states.shape[0]
308
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
309
+ encoder_hidden_states = encoder_hidden_states.view(
310
+ batch_size, -1, hidden_states.shape[-1]
311
+ )
312
+
313
+ for block in self.transformer_blocks:
314
+ if self.training and self.gradient_checkpointing:
315
+ hidden_states = torch.utils.checkpoint.checkpoint(
316
+ block,
317
+ hidden_states,
318
+ attention_mask,
319
+ encoder_hidden_states,
320
+ encoder_attention_mask,
321
+ timestep,
322
+ cross_attention_kwargs,
323
+ class_labels,
324
+ self_attn_block_embs,
325
+ self_attn_block_embs_mode,
326
+ use_reentrant=False,
327
+ )
328
+ else:
329
+ hidden_states = block(
330
+ hidden_states,
331
+ attention_mask=attention_mask,
332
+ encoder_hidden_states=encoder_hidden_states,
333
+ encoder_attention_mask=encoder_attention_mask,
334
+ timestep=timestep,
335
+ cross_attention_kwargs=cross_attention_kwargs,
336
+ class_labels=class_labels,
337
+ self_attn_block_embs=self_attn_block_embs,
338
+ self_attn_block_embs_mode=self_attn_block_embs_mode,
339
+ )
340
+ # reshape the extracted self_attn embeddings from (b*t, h*w, c) to (b*t, c, h, w)
341
+ if (
342
+ self_attn_block_embs is not None
343
+ and self_attn_block_embs_mode.lower() == "write"
344
+ ):
345
+ self_attn_idx = block.spatial_self_attn_idx
346
+ if self.print_idx == 0:
347
+ logger.debug(
348
+ f"self_attn_block_embs, num={len(self_attn_block_embs)}, before, shape={self_attn_block_embs[self_attn_idx].shape}, height={height}, width={width}"
349
+ )
350
+ self_attn_block_embs[self_attn_idx] = rearrange(
351
+ self_attn_block_embs[self_attn_idx],
352
+ "bt (h w) c->bt c h w",
353
+ h=height,
354
+ w=width,
355
+ )
356
+ if self.print_idx == 0:
357
+ logger.debug(
358
+ f"self_attn_block_embs, num={len(self_attn_block_embs)}, after ,shape={self_attn_block_embs[self_attn_idx].shape}, height={height}, width={width}"
359
+ )
360
+
361
+ if self.proj_out is None:
362
+ return hidden_states
363
+
364
+ # 3. Output
365
+ if self.is_input_continuous:
366
+ if not self.use_linear_projection:
367
+ hidden_states = (
368
+ hidden_states.reshape(batch, height, width, inner_dim)
369
+ .permute(0, 3, 1, 2)
370
+ .contiguous()
371
+ )
372
+ hidden_states = (
373
+ self.proj_out(hidden_states, scale=lora_scale)
374
+ if not USE_PEFT_BACKEND
375
+ else self.proj_out(hidden_states)
376
+ )
377
+ else:
378
+ hidden_states = (
379
+ self.proj_out(hidden_states, scale=lora_scale)
380
+ if not USE_PEFT_BACKEND
381
+ else self.proj_out(hidden_states)
382
+ )
383
+ hidden_states = (
384
+ hidden_states.reshape(batch, height, width, inner_dim)
385
+ .permute(0, 3, 1, 2)
386
+ .contiguous()
387
+ )
388
+
389
+ output = hidden_states + residual
390
+ elif self.is_input_vectorized:
391
+ hidden_states = self.norm_out(hidden_states)
392
+ logits = self.out(hidden_states)
393
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
394
+ logits = logits.permute(0, 2, 1)
395
+
396
+ # log(p(x_0))
397
+ output = F.log_softmax(logits.double(), dim=1).float()
398
+
399
+ if self.is_input_patches:
400
+ if self.config.norm_type != "ada_norm_single":
401
+ conditioning = self.transformer_blocks[0].norm1.emb(
402
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
403
+ )
404
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
405
+ hidden_states = (
406
+ self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
407
+ )
408
+ hidden_states = self.proj_out_2(hidden_states)
409
+ elif self.config.norm_type == "ada_norm_single":
410
+ shift, scale = (
411
+ self.scale_shift_table[None] + embedded_timestep[:, None]
412
+ ).chunk(2, dim=1)
413
+ hidden_states = self.norm_out(hidden_states)
414
+ # Modulation
415
+ hidden_states = hidden_states * (1 + scale) + shift
416
+ hidden_states = self.proj_out(hidden_states)
417
+ hidden_states = hidden_states.squeeze(1)
418
+
419
+ # unpatchify
420
+ if self.adaln_single is None:
421
+ height = width = int(hidden_states.shape[1] ** 0.5)
422
+ hidden_states = hidden_states.reshape(
423
+ shape=(
424
+ -1,
425
+ height,
426
+ width,
427
+ self.patch_size,
428
+ self.patch_size,
429
+ self.out_channels,
430
+ )
431
+ )
432
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
433
+ output = hidden_states.reshape(
434
+ shape=(
435
+ -1,
436
+ self.out_channels,
437
+ height * self.patch_size,
438
+ width * self.patch_size,
439
+ )
440
+ )
441
+ self.print_idx += 1
442
+ if not return_dict:
443
+ return (output,)
444
+
445
+ return Transformer2DModelOutput(sample=output)
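+
+ # Hypothetical sketch of the self_attn_block_embs mechanism added above (names
+ # below are placeholders, not defined in this file): pass a pre-sized list in
+ # "write" mode so every custom BasicTransformerBlock stores its attn1 hidden
+ # states, which this forward then reshapes to (b*t, c, h, w) for later reuse.
+ #
+ # self_attn_block_embs = [None] * num_spatial_self_attn_blocks
+ # _ = transformer_2d(
+ #     hidden_states,
+ #     encoder_hidden_states=encoder_hidden_states,
+ #     self_attn_block_embs=self_attn_block_embs,
+ #     self_attn_block_embs_mode="write",
+ # )
+ # # self_attn_block_embs now holds one tensor per spatial self-attention layer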
musev/models/unet_2d_blocks.py ADDED
@@ -0,0 +1,1537 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, Literal, Optional, Tuple, Union, List
15
+
16
+ import numpy as np
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.utils import is_torch_version, logging
22
+ from diffusers.utils.torch_utils import apply_freeu
23
+ from diffusers.models.activations import get_activation
24
+ from diffusers.models.attention_processor import (
25
+ Attention,
26
+ AttnAddedKVProcessor,
27
+ AttnAddedKVProcessor2_0,
28
+ )
29
+ from diffusers.models.dual_transformer_2d import DualTransformer2DModel
30
+ from diffusers.models.normalization import AdaGroupNorm
31
+ from diffusers.models.resnet import (
32
+ Downsample2D,
33
+ FirDownsample2D,
34
+ FirUpsample2D,
35
+ KDownsample2D,
36
+ KUpsample2D,
37
+ ResnetBlock2D,
38
+ Upsample2D,
39
+ )
40
+ from diffusers.models.unet_2d_blocks import (
41
+ AttnDownBlock2D,
42
+ AttnDownEncoderBlock2D,
43
+ AttnSkipDownBlock2D,
44
+ AttnSkipUpBlock2D,
45
+ AttnUpBlock2D,
46
+ AttnUpDecoderBlock2D,
47
+ DownEncoderBlock2D,
48
+ KCrossAttnDownBlock2D,
49
+ KCrossAttnUpBlock2D,
50
+ KDownBlock2D,
51
+ KUpBlock2D,
52
+ ResnetDownsampleBlock2D,
53
+ ResnetUpsampleBlock2D,
54
+ SimpleCrossAttnDownBlock2D,
55
+ SimpleCrossAttnUpBlock2D,
56
+ SkipDownBlock2D,
57
+ SkipUpBlock2D,
58
+ UpDecoderBlock2D,
59
+ )
60
+
61
+ from .transformer_2d import Transformer2DModel
62
+
63
+
64
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
65
+
66
+
67
+ def get_down_block(
68
+ down_block_type: str,
69
+ num_layers: int,
70
+ in_channels: int,
71
+ out_channels: int,
72
+ temb_channels: int,
73
+ add_downsample: bool,
74
+ resnet_eps: float,
75
+ resnet_act_fn: str,
76
+ transformer_layers_per_block: int = 1,
77
+ num_attention_heads: Optional[int] = None,
78
+ resnet_groups: Optional[int] = None,
79
+ cross_attention_dim: Optional[int] = None,
80
+ downsample_padding: Optional[int] = None,
81
+ dual_cross_attention: bool = False,
82
+ use_linear_projection: bool = False,
83
+ only_cross_attention: bool = False,
84
+ upcast_attention: bool = False,
85
+ resnet_time_scale_shift: str = "default",
86
+ attention_type: str = "default",
87
+ resnet_skip_time_act: bool = False,
88
+ resnet_out_scale_factor: float = 1.0,
89
+ cross_attention_norm: Optional[str] = None,
90
+ attention_head_dim: Optional[int] = None,
91
+ downsample_type: Optional[str] = None,
92
+ dropout: float = 0.0,
93
+ ):
94
+ # If attn head dim is not defined, we default it to the number of heads
95
+ if attention_head_dim is None:
96
+ logger.warn(
97
+ f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
98
+ )
99
+ attention_head_dim = num_attention_heads
100
+
101
+ down_block_type = (
102
+ down_block_type[7:]
103
+ if down_block_type.startswith("UNetRes")
104
+ else down_block_type
105
+ )
106
+ if down_block_type == "DownBlock2D":
107
+ return DownBlock2D(
108
+ num_layers=num_layers,
109
+ in_channels=in_channels,
110
+ out_channels=out_channels,
111
+ temb_channels=temb_channels,
112
+ dropout=dropout,
113
+ add_downsample=add_downsample,
114
+ resnet_eps=resnet_eps,
115
+ resnet_act_fn=resnet_act_fn,
116
+ resnet_groups=resnet_groups,
117
+ downsample_padding=downsample_padding,
118
+ resnet_time_scale_shift=resnet_time_scale_shift,
119
+ )
120
+ elif down_block_type == "ResnetDownsampleBlock2D":
121
+ return ResnetDownsampleBlock2D(
122
+ num_layers=num_layers,
123
+ in_channels=in_channels,
124
+ out_channels=out_channels,
125
+ temb_channels=temb_channels,
126
+ dropout=dropout,
127
+ add_downsample=add_downsample,
128
+ resnet_eps=resnet_eps,
129
+ resnet_act_fn=resnet_act_fn,
130
+ resnet_groups=resnet_groups,
131
+ resnet_time_scale_shift=resnet_time_scale_shift,
132
+ skip_time_act=resnet_skip_time_act,
133
+ output_scale_factor=resnet_out_scale_factor,
134
+ )
135
+ elif down_block_type == "AttnDownBlock2D":
136
+ if add_downsample is False:
137
+ downsample_type = None
138
+ else:
139
+ downsample_type = downsample_type or "conv" # default to 'conv'
140
+ return AttnDownBlock2D(
141
+ num_layers=num_layers,
142
+ in_channels=in_channels,
143
+ out_channels=out_channels,
144
+ temb_channels=temb_channels,
145
+ dropout=dropout,
146
+ resnet_eps=resnet_eps,
147
+ resnet_act_fn=resnet_act_fn,
148
+ resnet_groups=resnet_groups,
149
+ downsample_padding=downsample_padding,
150
+ attention_head_dim=attention_head_dim,
151
+ resnet_time_scale_shift=resnet_time_scale_shift,
152
+ downsample_type=downsample_type,
153
+ )
154
+ elif down_block_type == "CrossAttnDownBlock2D":
155
+ if cross_attention_dim is None:
156
+ raise ValueError(
157
+ "cross_attention_dim must be specified for CrossAttnDownBlock2D"
158
+ )
159
+ return CrossAttnDownBlock2D(
160
+ num_layers=num_layers,
161
+ transformer_layers_per_block=transformer_layers_per_block,
162
+ in_channels=in_channels,
163
+ out_channels=out_channels,
164
+ temb_channels=temb_channels,
165
+ dropout=dropout,
166
+ add_downsample=add_downsample,
167
+ resnet_eps=resnet_eps,
168
+ resnet_act_fn=resnet_act_fn,
169
+ resnet_groups=resnet_groups,
170
+ downsample_padding=downsample_padding,
171
+ cross_attention_dim=cross_attention_dim,
172
+ num_attention_heads=num_attention_heads,
173
+ dual_cross_attention=dual_cross_attention,
174
+ use_linear_projection=use_linear_projection,
175
+ only_cross_attention=only_cross_attention,
176
+ upcast_attention=upcast_attention,
177
+ resnet_time_scale_shift=resnet_time_scale_shift,
178
+ attention_type=attention_type,
179
+ )
180
+ elif down_block_type == "SimpleCrossAttnDownBlock2D":
181
+ if cross_attention_dim is None:
182
+ raise ValueError(
183
+ "cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D"
184
+ )
185
+ return SimpleCrossAttnDownBlock2D(
186
+ num_layers=num_layers,
187
+ in_channels=in_channels,
188
+ out_channels=out_channels,
189
+ temb_channels=temb_channels,
190
+ dropout=dropout,
191
+ add_downsample=add_downsample,
192
+ resnet_eps=resnet_eps,
193
+ resnet_act_fn=resnet_act_fn,
194
+ resnet_groups=resnet_groups,
195
+ cross_attention_dim=cross_attention_dim,
196
+ attention_head_dim=attention_head_dim,
197
+ resnet_time_scale_shift=resnet_time_scale_shift,
198
+ skip_time_act=resnet_skip_time_act,
199
+ output_scale_factor=resnet_out_scale_factor,
200
+ only_cross_attention=only_cross_attention,
201
+ cross_attention_norm=cross_attention_norm,
202
+ )
203
+ elif down_block_type == "SkipDownBlock2D":
204
+ return SkipDownBlock2D(
205
+ num_layers=num_layers,
206
+ in_channels=in_channels,
207
+ out_channels=out_channels,
208
+ temb_channels=temb_channels,
209
+ dropout=dropout,
210
+ add_downsample=add_downsample,
211
+ resnet_eps=resnet_eps,
212
+ resnet_act_fn=resnet_act_fn,
213
+ downsample_padding=downsample_padding,
214
+ resnet_time_scale_shift=resnet_time_scale_shift,
215
+ )
216
+ elif down_block_type == "AttnSkipDownBlock2D":
217
+ return AttnSkipDownBlock2D(
218
+ num_layers=num_layers,
219
+ in_channels=in_channels,
220
+ out_channels=out_channels,
221
+ temb_channels=temb_channels,
222
+ dropout=dropout,
223
+ add_downsample=add_downsample,
224
+ resnet_eps=resnet_eps,
225
+ resnet_act_fn=resnet_act_fn,
226
+ attention_head_dim=attention_head_dim,
227
+ resnet_time_scale_shift=resnet_time_scale_shift,
228
+ )
229
+ elif down_block_type == "DownEncoderBlock2D":
230
+ return DownEncoderBlock2D(
231
+ num_layers=num_layers,
232
+ in_channels=in_channels,
233
+ out_channels=out_channels,
234
+ dropout=dropout,
235
+ add_downsample=add_downsample,
236
+ resnet_eps=resnet_eps,
237
+ resnet_act_fn=resnet_act_fn,
238
+ resnet_groups=resnet_groups,
239
+ downsample_padding=downsample_padding,
240
+ resnet_time_scale_shift=resnet_time_scale_shift,
241
+ )
242
+ elif down_block_type == "AttnDownEncoderBlock2D":
243
+ return AttnDownEncoderBlock2D(
244
+ num_layers=num_layers,
245
+ in_channels=in_channels,
246
+ out_channels=out_channels,
247
+ dropout=dropout,
248
+ add_downsample=add_downsample,
249
+ resnet_eps=resnet_eps,
250
+ resnet_act_fn=resnet_act_fn,
251
+ resnet_groups=resnet_groups,
252
+ downsample_padding=downsample_padding,
253
+ attention_head_dim=attention_head_dim,
254
+ resnet_time_scale_shift=resnet_time_scale_shift,
255
+ )
256
+ elif down_block_type == "KDownBlock2D":
257
+ return KDownBlock2D(
258
+ num_layers=num_layers,
259
+ in_channels=in_channels,
260
+ out_channels=out_channels,
261
+ temb_channels=temb_channels,
262
+ dropout=dropout,
263
+ add_downsample=add_downsample,
264
+ resnet_eps=resnet_eps,
265
+ resnet_act_fn=resnet_act_fn,
266
+ )
267
+ elif down_block_type == "KCrossAttnDownBlock2D":
268
+ return KCrossAttnDownBlock2D(
269
+ num_layers=num_layers,
270
+ in_channels=in_channels,
271
+ out_channels=out_channels,
272
+ temb_channels=temb_channels,
273
+ dropout=dropout,
274
+ add_downsample=add_downsample,
275
+ resnet_eps=resnet_eps,
276
+ resnet_act_fn=resnet_act_fn,
277
+ cross_attention_dim=cross_attention_dim,
278
+ attention_head_dim=attention_head_dim,
279
+ add_self_attention=True if not add_downsample else False,
280
+ )
281
+ raise ValueError(f"{down_block_type} does not exist.")
282
+
283
+
284
+ def get_up_block(
285
+ up_block_type: str,
286
+ num_layers: int,
287
+ in_channels: int,
288
+ out_channels: int,
289
+ prev_output_channel: int,
290
+ temb_channels: int,
291
+ add_upsample: bool,
292
+ resnet_eps: float,
293
+ resnet_act_fn: str,
294
+ resolution_idx: Optional[int] = None,
295
+ transformer_layers_per_block: int = 1,
296
+ num_attention_heads: Optional[int] = None,
297
+ resnet_groups: Optional[int] = None,
298
+ cross_attention_dim: Optional[int] = None,
299
+ dual_cross_attention: bool = False,
300
+ use_linear_projection: bool = False,
301
+ only_cross_attention: bool = False,
302
+ upcast_attention: bool = False,
303
+ resnet_time_scale_shift: str = "default",
304
+ attention_type: str = "default",
305
+ resnet_skip_time_act: bool = False,
306
+ resnet_out_scale_factor: float = 1.0,
307
+ cross_attention_norm: Optional[str] = None,
308
+ attention_head_dim: Optional[int] = None,
309
+ upsample_type: Optional[str] = None,
310
+ dropout: float = 0.0,
311
+ ) -> nn.Module:
312
+ # If attn head dim is not defined, we default it to the number of heads
313
+ if attention_head_dim is None:
314
+ logger.warn(
315
+ f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
316
+ )
317
+ attention_head_dim = num_attention_heads
318
+
319
+ up_block_type = (
320
+ up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
321
+ )
322
+ if up_block_type == "UpBlock2D":
323
+ return UpBlock2D(
324
+ num_layers=num_layers,
325
+ in_channels=in_channels,
326
+ out_channels=out_channels,
327
+ prev_output_channel=prev_output_channel,
328
+ temb_channels=temb_channels,
329
+ resolution_idx=resolution_idx,
330
+ dropout=dropout,
331
+ add_upsample=add_upsample,
332
+ resnet_eps=resnet_eps,
333
+ resnet_act_fn=resnet_act_fn,
334
+ resnet_groups=resnet_groups,
335
+ resnet_time_scale_shift=resnet_time_scale_shift,
336
+ )
337
+ elif up_block_type == "ResnetUpsampleBlock2D":
338
+ return ResnetUpsampleBlock2D(
339
+ num_layers=num_layers,
340
+ in_channels=in_channels,
341
+ out_channels=out_channels,
342
+ prev_output_channel=prev_output_channel,
343
+ temb_channels=temb_channels,
344
+ resolution_idx=resolution_idx,
345
+ dropout=dropout,
346
+ add_upsample=add_upsample,
347
+ resnet_eps=resnet_eps,
348
+ resnet_act_fn=resnet_act_fn,
349
+ resnet_groups=resnet_groups,
350
+ resnet_time_scale_shift=resnet_time_scale_shift,
351
+ skip_time_act=resnet_skip_time_act,
352
+ output_scale_factor=resnet_out_scale_factor,
353
+ )
354
+ elif up_block_type == "CrossAttnUpBlock2D":
355
+ if cross_attention_dim is None:
356
+ raise ValueError(
357
+ "cross_attention_dim must be specified for CrossAttnUpBlock2D"
358
+ )
359
+ return CrossAttnUpBlock2D(
360
+ num_layers=num_layers,
361
+ transformer_layers_per_block=transformer_layers_per_block,
362
+ in_channels=in_channels,
363
+ out_channels=out_channels,
364
+ prev_output_channel=prev_output_channel,
365
+ temb_channels=temb_channels,
366
+ resolution_idx=resolution_idx,
367
+ dropout=dropout,
368
+ add_upsample=add_upsample,
369
+ resnet_eps=resnet_eps,
370
+ resnet_act_fn=resnet_act_fn,
371
+ resnet_groups=resnet_groups,
372
+ cross_attention_dim=cross_attention_dim,
373
+ num_attention_heads=num_attention_heads,
374
+ dual_cross_attention=dual_cross_attention,
375
+ use_linear_projection=use_linear_projection,
376
+ only_cross_attention=only_cross_attention,
377
+ upcast_attention=upcast_attention,
378
+ resnet_time_scale_shift=resnet_time_scale_shift,
379
+ attention_type=attention_type,
380
+ )
381
+ elif up_block_type == "SimpleCrossAttnUpBlock2D":
382
+ if cross_attention_dim is None:
383
+ raise ValueError(
384
+ "cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D"
385
+ )
386
+ return SimpleCrossAttnUpBlock2D(
387
+ num_layers=num_layers,
388
+ in_channels=in_channels,
389
+ out_channels=out_channels,
390
+ prev_output_channel=prev_output_channel,
391
+ temb_channels=temb_channels,
392
+ resolution_idx=resolution_idx,
393
+ dropout=dropout,
394
+ add_upsample=add_upsample,
395
+ resnet_eps=resnet_eps,
396
+ resnet_act_fn=resnet_act_fn,
397
+ resnet_groups=resnet_groups,
398
+ cross_attention_dim=cross_attention_dim,
399
+ attention_head_dim=attention_head_dim,
400
+ resnet_time_scale_shift=resnet_time_scale_shift,
401
+ skip_time_act=resnet_skip_time_act,
402
+ output_scale_factor=resnet_out_scale_factor,
403
+ only_cross_attention=only_cross_attention,
404
+ cross_attention_norm=cross_attention_norm,
405
+ )
406
+ elif up_block_type == "AttnUpBlock2D":
407
+ if add_upsample is False:
408
+ upsample_type = None
409
+ else:
410
+ upsample_type = upsample_type or "conv" # default to 'conv'
411
+
412
+ return AttnUpBlock2D(
413
+ num_layers=num_layers,
414
+ in_channels=in_channels,
415
+ out_channels=out_channels,
416
+ prev_output_channel=prev_output_channel,
417
+ temb_channels=temb_channels,
418
+ resolution_idx=resolution_idx,
419
+ dropout=dropout,
420
+ resnet_eps=resnet_eps,
421
+ resnet_act_fn=resnet_act_fn,
422
+ resnet_groups=resnet_groups,
423
+ attention_head_dim=attention_head_dim,
424
+ resnet_time_scale_shift=resnet_time_scale_shift,
425
+ upsample_type=upsample_type,
426
+ )
427
+ elif up_block_type == "SkipUpBlock2D":
428
+ return SkipUpBlock2D(
429
+ num_layers=num_layers,
430
+ in_channels=in_channels,
431
+ out_channels=out_channels,
432
+ prev_output_channel=prev_output_channel,
433
+ temb_channels=temb_channels,
434
+ resolution_idx=resolution_idx,
435
+ dropout=dropout,
436
+ add_upsample=add_upsample,
437
+ resnet_eps=resnet_eps,
438
+ resnet_act_fn=resnet_act_fn,
439
+ resnet_time_scale_shift=resnet_time_scale_shift,
440
+ )
441
+ elif up_block_type == "AttnSkipUpBlock2D":
442
+ return AttnSkipUpBlock2D(
443
+ num_layers=num_layers,
444
+ in_channels=in_channels,
445
+ out_channels=out_channels,
446
+ prev_output_channel=prev_output_channel,
447
+ temb_channels=temb_channels,
448
+ resolution_idx=resolution_idx,
449
+ dropout=dropout,
450
+ add_upsample=add_upsample,
451
+ resnet_eps=resnet_eps,
452
+ resnet_act_fn=resnet_act_fn,
453
+ attention_head_dim=attention_head_dim,
454
+ resnet_time_scale_shift=resnet_time_scale_shift,
455
+ )
456
+ elif up_block_type == "UpDecoderBlock2D":
457
+ return UpDecoderBlock2D(
458
+ num_layers=num_layers,
459
+ in_channels=in_channels,
460
+ out_channels=out_channels,
461
+ resolution_idx=resolution_idx,
462
+ dropout=dropout,
463
+ add_upsample=add_upsample,
464
+ resnet_eps=resnet_eps,
465
+ resnet_act_fn=resnet_act_fn,
466
+ resnet_groups=resnet_groups,
467
+ resnet_time_scale_shift=resnet_time_scale_shift,
468
+ temb_channels=temb_channels,
469
+ )
470
+ elif up_block_type == "AttnUpDecoderBlock2D":
471
+ return AttnUpDecoderBlock2D(
472
+ num_layers=num_layers,
473
+ in_channels=in_channels,
474
+ out_channels=out_channels,
475
+ resolution_idx=resolution_idx,
476
+ dropout=dropout,
477
+ add_upsample=add_upsample,
478
+ resnet_eps=resnet_eps,
479
+ resnet_act_fn=resnet_act_fn,
480
+ resnet_groups=resnet_groups,
481
+ attention_head_dim=attention_head_dim,
482
+ resnet_time_scale_shift=resnet_time_scale_shift,
483
+ temb_channels=temb_channels,
484
+ )
485
+ elif up_block_type == "KUpBlock2D":
486
+ return KUpBlock2D(
487
+ num_layers=num_layers,
488
+ in_channels=in_channels,
489
+ out_channels=out_channels,
490
+ temb_channels=temb_channels,
491
+ resolution_idx=resolution_idx,
492
+ dropout=dropout,
493
+ add_upsample=add_upsample,
494
+ resnet_eps=resnet_eps,
495
+ resnet_act_fn=resnet_act_fn,
496
+ )
497
+ elif up_block_type == "KCrossAttnUpBlock2D":
498
+ return KCrossAttnUpBlock2D(
499
+ num_layers=num_layers,
500
+ in_channels=in_channels,
501
+ out_channels=out_channels,
502
+ temb_channels=temb_channels,
503
+ resolution_idx=resolution_idx,
504
+ dropout=dropout,
505
+ add_upsample=add_upsample,
506
+ resnet_eps=resnet_eps,
507
+ resnet_act_fn=resnet_act_fn,
508
+ cross_attention_dim=cross_attention_dim,
509
+ attention_head_dim=attention_head_dim,
510
+ )
511
+
512
+ raise ValueError(f"{up_block_type} does not exist.")
513
+
514
+
515
+ class UNetMidBlock2D(nn.Module):
516
+ """
517
+ A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks.
518
+
519
+ Args:
520
+ in_channels (`int`): The number of input channels.
521
+ temb_channels (`int`): The number of temporal embedding channels.
522
+ dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
523
+ num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
524
+ resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
525
+ resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
526
+ The type of normalization to apply to the time embeddings. This can help to improve the performance of the
527
+ model on tasks with long-range temporal dependencies.
528
+ resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
529
+ resnet_groups (`int`, *optional*, defaults to 32):
530
+ The number of groups to use in the group normalization layers of the resnet blocks.
531
+ attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
532
+ resnet_pre_norm (`bool`, *optional*, defaults to `True`):
533
+ Whether to use pre-normalization for the resnet blocks.
534
+ add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
535
+ attention_head_dim (`int`, *optional*, defaults to 1):
536
+ Dimension of a single attention head. The number of attention heads is determined based on this value and
537
+ the number of input channels.
538
+ output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
539
+
540
+ Returns:
541
+ `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
542
+ in_channels, height, width)`.
543
+
544
+ """
545
+
546
+ def __init__(
547
+ self,
548
+ in_channels: int,
549
+ temb_channels: int,
550
+ dropout: float = 0.0,
551
+ num_layers: int = 1,
552
+ resnet_eps: float = 1e-6,
553
+ resnet_time_scale_shift: str = "default", # default, spatial
554
+ resnet_act_fn: str = "swish",
555
+ resnet_groups: int = 32,
556
+ attn_groups: Optional[int] = None,
557
+ resnet_pre_norm: bool = True,
558
+ add_attention: bool = True,
559
+ attention_head_dim: int = 1,
560
+ output_scale_factor: float = 1.0,
561
+ ):
562
+ super().__init__()
563
+ resnet_groups = (
564
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
565
+ )
566
+ self.add_attention = add_attention
567
+
568
+ if attn_groups is None:
569
+ attn_groups = (
570
+ resnet_groups if resnet_time_scale_shift == "default" else None
571
+ )
572
+
573
+ # there is always at least one resnet
574
+ resnets = [
575
+ ResnetBlock2D(
576
+ in_channels=in_channels,
577
+ out_channels=in_channels,
578
+ temb_channels=temb_channels,
579
+ eps=resnet_eps,
580
+ groups=resnet_groups,
581
+ dropout=dropout,
582
+ time_embedding_norm=resnet_time_scale_shift,
583
+ non_linearity=resnet_act_fn,
584
+ output_scale_factor=output_scale_factor,
585
+ pre_norm=resnet_pre_norm,
586
+ )
587
+ ]
588
+ attentions = []
589
+
590
+ if attention_head_dim is None:
591
+ logger.warn(
592
+ f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
593
+ )
594
+ attention_head_dim = in_channels
595
+
596
+ for _ in range(num_layers):
597
+ if self.add_attention:
598
+ attentions.append(
599
+ Attention(
600
+ in_channels,
601
+ heads=in_channels // attention_head_dim,
602
+ dim_head=attention_head_dim,
603
+ rescale_output_factor=output_scale_factor,
604
+ eps=resnet_eps,
605
+ norm_num_groups=attn_groups,
606
+ spatial_norm_dim=temb_channels
607
+ if resnet_time_scale_shift == "spatial"
608
+ else None,
609
+ residual_connection=True,
610
+ bias=True,
611
+ upcast_softmax=True,
612
+ _from_deprecated_attn_block=True,
613
+ )
614
+ )
615
+ else:
616
+ attentions.append(None)
617
+
618
+ resnets.append(
619
+ ResnetBlock2D(
620
+ in_channels=in_channels,
621
+ out_channels=in_channels,
622
+ temb_channels=temb_channels,
623
+ eps=resnet_eps,
624
+ groups=resnet_groups,
625
+ dropout=dropout,
626
+ time_embedding_norm=resnet_time_scale_shift,
627
+ non_linearity=resnet_act_fn,
628
+ output_scale_factor=output_scale_factor,
629
+ pre_norm=resnet_pre_norm,
630
+ )
631
+ )
632
+
633
+ self.attentions = nn.ModuleList(attentions)
634
+ self.resnets = nn.ModuleList(resnets)
635
+
636
+ def forward(
637
+ self,
638
+ hidden_states: torch.FloatTensor,
639
+ temb: Optional[torch.FloatTensor] = None,
640
+ self_attn_block_embs: Optional[List[torch.Tensor]] = None,
641
+ self_attn_block_embs_mode: Literal["read", "write"] = "write",
642
+ ) -> torch.FloatTensor:
643
+ hidden_states = self.resnets[0](hidden_states, temb)
644
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
645
+ if attn is not None:
646
+ hidden_states = attn(
647
+ hidden_states,
648
+ temb=temb,
649
+ self_attn_block_embs=self_attn_block_embs,
650
+ self_attn_block_embs_mode=self_attn_block_embs_mode,
651
+ )
652
+ hidden_states = resnet(hidden_states, temb)
653
+
654
+ return hidden_states
655
+
656
+
657
+ class UNetMidBlock2DCrossAttn(nn.Module):
658
+ def __init__(
659
+ self,
660
+ in_channels: int,
661
+ temb_channels: int,
662
+ dropout: float = 0.0,
663
+ num_layers: int = 1,
664
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
665
+ resnet_eps: float = 1e-6,
666
+ resnet_time_scale_shift: str = "default",
667
+ resnet_act_fn: str = "swish",
668
+ resnet_groups: int = 32,
669
+ resnet_pre_norm: bool = True,
670
+ num_attention_heads: int = 1,
671
+ output_scale_factor: float = 1.0,
672
+ cross_attention_dim: int = 1280,
673
+ dual_cross_attention: bool = False,
674
+ use_linear_projection: bool = False,
675
+ upcast_attention: bool = False,
676
+ attention_type: str = "default",
677
+ ):
678
+ super().__init__()
679
+
680
+ self.has_cross_attention = True
681
+ self.num_attention_heads = num_attention_heads
682
+ resnet_groups = (
683
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
684
+ )
685
+
686
+ # support for variable transformer layers per block
687
+ if isinstance(transformer_layers_per_block, int):
688
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
689
+
690
+ # there is always at least one resnet
691
+ resnets = [
692
+ ResnetBlock2D(
693
+ in_channels=in_channels,
694
+ out_channels=in_channels,
695
+ temb_channels=temb_channels,
696
+ eps=resnet_eps,
697
+ groups=resnet_groups,
698
+ dropout=dropout,
699
+ time_embedding_norm=resnet_time_scale_shift,
700
+ non_linearity=resnet_act_fn,
701
+ output_scale_factor=output_scale_factor,
702
+ pre_norm=resnet_pre_norm,
703
+ )
704
+ ]
705
+ attentions = []
706
+
707
+ for i in range(num_layers):
708
+ if not dual_cross_attention:
709
+ attentions.append(
710
+ Transformer2DModel(
711
+ num_attention_heads,
712
+ in_channels // num_attention_heads,
713
+ in_channels=in_channels,
714
+ num_layers=transformer_layers_per_block[i],
715
+ cross_attention_dim=cross_attention_dim,
716
+ norm_num_groups=resnet_groups,
717
+ use_linear_projection=use_linear_projection,
718
+ upcast_attention=upcast_attention,
719
+ attention_type=attention_type,
720
+ )
721
+ )
722
+ else:
723
+ attentions.append(
724
+ DualTransformer2DModel(
725
+ num_attention_heads,
726
+ in_channels // num_attention_heads,
727
+ in_channels=in_channels,
728
+ num_layers=1,
729
+ cross_attention_dim=cross_attention_dim,
730
+ norm_num_groups=resnet_groups,
731
+ )
732
+ )
733
+ resnets.append(
734
+ ResnetBlock2D(
735
+ in_channels=in_channels,
736
+ out_channels=in_channels,
737
+ temb_channels=temb_channels,
738
+ eps=resnet_eps,
739
+ groups=resnet_groups,
740
+ dropout=dropout,
741
+ time_embedding_norm=resnet_time_scale_shift,
742
+ non_linearity=resnet_act_fn,
743
+ output_scale_factor=output_scale_factor,
744
+ pre_norm=resnet_pre_norm,
745
+ )
746
+ )
747
+
748
+ self.attentions = nn.ModuleList(attentions)
749
+ self.resnets = nn.ModuleList(resnets)
750
+
751
+ self.gradient_checkpointing = False
752
+
753
+ def forward(
754
+ self,
755
+ hidden_states: torch.FloatTensor,
756
+ temb: Optional[torch.FloatTensor] = None,
757
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
758
+ attention_mask: Optional[torch.FloatTensor] = None,
759
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
760
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
761
+ self_attn_block_embs: Optional[List[torch.Tensor]] = None,
762
+ self_attn_block_embs_mode: Literal["read", "write"] = "write",
763
+ ) -> torch.FloatTensor:
764
+ lora_scale = (
765
+ cross_attention_kwargs.get("scale", 1.0)
766
+ if cross_attention_kwargs is not None
767
+ else 1.0
768
+ )
769
+ hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
770
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
771
+ if self.training and self.gradient_checkpointing:
772
+
773
+ def create_custom_forward(module, return_dict=None):
774
+ def custom_forward(*inputs):
775
+ if return_dict is not None:
776
+ return module(*inputs, return_dict=return_dict)
777
+ else:
778
+ return module(*inputs)
779
+
780
+ return custom_forward
781
+
782
+ ckpt_kwargs: Dict[str, Any] = (
783
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
784
+ )
785
+ hidden_states = attn(
786
+ hidden_states,
787
+ encoder_hidden_states=encoder_hidden_states,
788
+ cross_attention_kwargs=cross_attention_kwargs,
789
+ attention_mask=attention_mask,
790
+ encoder_attention_mask=encoder_attention_mask,
791
+ return_dict=False,
792
+ self_attn_block_embs=self_attn_block_embs,
793
+ self_attn_block_embs_mode=self_attn_block_embs_mode,
794
+ )[0]
795
+ hidden_states = torch.utils.checkpoint.checkpoint(
796
+ create_custom_forward(resnet),
797
+ hidden_states,
798
+ temb,
799
+ **ckpt_kwargs,
800
+ )
801
+ else:
802
+ hidden_states = attn(
803
+ hidden_states,
804
+ encoder_hidden_states=encoder_hidden_states,
805
+ cross_attention_kwargs=cross_attention_kwargs,
806
+ attention_mask=attention_mask,
807
+ encoder_attention_mask=encoder_attention_mask,
808
+ return_dict=False,
809
+ self_attn_block_embs=self_attn_block_embs,
810
+ )[0]
811
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
812
+
813
+ return hidden_states
814
+
815
+
816
+ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
817
+ def __init__(
818
+ self,
819
+ in_channels: int,
820
+ temb_channels: int,
821
+ dropout: float = 0.0,
822
+ num_layers: int = 1,
823
+ resnet_eps: float = 1e-6,
824
+ resnet_time_scale_shift: str = "default",
825
+ resnet_act_fn: str = "swish",
826
+ resnet_groups: int = 32,
827
+ resnet_pre_norm: bool = True,
828
+ attention_head_dim: int = 1,
829
+ output_scale_factor: float = 1.0,
830
+ cross_attention_dim: int = 1280,
831
+ skip_time_act: bool = False,
832
+ only_cross_attention: bool = False,
833
+ cross_attention_norm: Optional[str] = None,
834
+ ):
835
+ super().__init__()
836
+
837
+ self.has_cross_attention = True
838
+
839
+ self.attention_head_dim = attention_head_dim
840
+ resnet_groups = (
841
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
842
+ )
843
+
844
+ self.num_heads = in_channels // self.attention_head_dim
845
+
846
+ # there is always at least one resnet
847
+ resnets = [
848
+ ResnetBlock2D(
849
+ in_channels=in_channels,
850
+ out_channels=in_channels,
851
+ temb_channels=temb_channels,
852
+ eps=resnet_eps,
853
+ groups=resnet_groups,
854
+ dropout=dropout,
855
+ time_embedding_norm=resnet_time_scale_shift,
856
+ non_linearity=resnet_act_fn,
857
+ output_scale_factor=output_scale_factor,
858
+ pre_norm=resnet_pre_norm,
859
+ skip_time_act=skip_time_act,
860
+ )
861
+ ]
862
+ attentions = []
863
+
864
+ for _ in range(num_layers):
865
+ processor = (
866
+ AttnAddedKVProcessor2_0()
867
+ if hasattr(F, "scaled_dot_product_attention")
868
+ else AttnAddedKVProcessor()
869
+ )
870
+
871
+ attentions.append(
872
+ Attention(
873
+ query_dim=in_channels,
874
+ cross_attention_dim=in_channels,
875
+ heads=self.num_heads,
876
+ dim_head=self.attention_head_dim,
877
+ added_kv_proj_dim=cross_attention_dim,
878
+ norm_num_groups=resnet_groups,
879
+ bias=True,
880
+ upcast_softmax=True,
881
+ only_cross_attention=only_cross_attention,
882
+ cross_attention_norm=cross_attention_norm,
883
+ processor=processor,
884
+ )
885
+ )
886
+ resnets.append(
887
+ ResnetBlock2D(
888
+ in_channels=in_channels,
889
+ out_channels=in_channels,
890
+ temb_channels=temb_channels,
891
+ eps=resnet_eps,
892
+ groups=resnet_groups,
893
+ dropout=dropout,
894
+ time_embedding_norm=resnet_time_scale_shift,
895
+ non_linearity=resnet_act_fn,
896
+ output_scale_factor=output_scale_factor,
897
+ pre_norm=resnet_pre_norm,
898
+ skip_time_act=skip_time_act,
899
+ )
900
+ )
901
+
902
+ self.attentions = nn.ModuleList(attentions)
903
+ self.resnets = nn.ModuleList(resnets)
904
+
905
+ def forward(
906
+ self,
907
+ hidden_states: torch.FloatTensor,
908
+ temb: Optional[torch.FloatTensor] = None,
909
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
910
+ attention_mask: Optional[torch.FloatTensor] = None,
911
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
912
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
913
+ self_attn_block_embs: Optional[List[torch.Tensor]] = None,
914
+ self_attn_block_embs_mode: Literal["read", "write"] = "write",
915
+ ) -> torch.FloatTensor:
916
+ cross_attention_kwargs = (
917
+ cross_attention_kwargs if cross_attention_kwargs is not None else {}
918
+ )
919
+ lora_scale = cross_attention_kwargs.get("scale", 1.0)
920
+
921
+ if attention_mask is None:
922
+ # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
923
+ mask = None if encoder_hidden_states is None else encoder_attention_mask
924
+ else:
925
+ # when attention_mask is defined: we don't even check for encoder_attention_mask.
926
+ # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
927
+ # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
928
+ # then we can simplify this whole if/else block to:
929
+ # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
930
+ mask = attention_mask
931
+
932
+ hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
933
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
934
+ # attn
935
+ hidden_states = attn(
936
+ hidden_states,
937
+ encoder_hidden_states=encoder_hidden_states,
938
+ attention_mask=mask,
939
+ **cross_attention_kwargs,
940
+ self_attn_block_embs=self_attn_block_embs,
941
+ self_attn_block_embs_mode=self_attn_block_embs_mode,
942
+ )
943
+
944
+ # resnet
945
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
946
+
947
+ return hidden_states
948
+
949
+
950
+ class CrossAttnDownBlock2D(nn.Module):
951
+ print_idx = 0
952
+
953
+ def __init__(
954
+ self,
955
+ in_channels: int,
956
+ out_channels: int,
957
+ temb_channels: int,
958
+ dropout: float = 0.0,
959
+ num_layers: int = 1,
960
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
961
+ resnet_eps: float = 1e-6,
962
+ resnet_time_scale_shift: str = "default",
963
+ resnet_act_fn: str = "swish",
964
+ resnet_groups: int = 32,
965
+ resnet_pre_norm: bool = True,
966
+ num_attention_heads: int = 1,
967
+ cross_attention_dim: int = 1280,
968
+ output_scale_factor: float = 1.0,
969
+ downsample_padding: int = 1,
970
+ add_downsample: bool = True,
971
+ dual_cross_attention: bool = False,
972
+ use_linear_projection: bool = False,
973
+ only_cross_attention: bool = False,
974
+ upcast_attention: bool = False,
975
+ attention_type: str = "default",
976
+ ):
977
+ super().__init__()
978
+ resnets = []
979
+ attentions = []
980
+
981
+ self.has_cross_attention = True
982
+ self.num_attention_heads = num_attention_heads
983
+ if isinstance(transformer_layers_per_block, int):
984
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
985
+
986
+ for i in range(num_layers):
987
+ in_channels = in_channels if i == 0 else out_channels
988
+ resnets.append(
989
+ ResnetBlock2D(
990
+ in_channels=in_channels,
991
+ out_channels=out_channels,
992
+ temb_channels=temb_channels,
993
+ eps=resnet_eps,
994
+ groups=resnet_groups,
995
+ dropout=dropout,
996
+ time_embedding_norm=resnet_time_scale_shift,
997
+ non_linearity=resnet_act_fn,
998
+ output_scale_factor=output_scale_factor,
999
+ pre_norm=resnet_pre_norm,
1000
+ )
1001
+ )
1002
+ if not dual_cross_attention:
1003
+ attentions.append(
1004
+ Transformer2DModel(
1005
+ num_attention_heads,
1006
+ out_channels // num_attention_heads,
1007
+ in_channels=out_channels,
1008
+ num_layers=transformer_layers_per_block[i],
1009
+ cross_attention_dim=cross_attention_dim,
1010
+ norm_num_groups=resnet_groups,
1011
+ use_linear_projection=use_linear_projection,
1012
+ only_cross_attention=only_cross_attention,
1013
+ upcast_attention=upcast_attention,
1014
+ attention_type=attention_type,
1015
+ )
1016
+ )
1017
+ else:
1018
+ attentions.append(
1019
+ DualTransformer2DModel(
1020
+ num_attention_heads,
1021
+ out_channels // num_attention_heads,
1022
+ in_channels=out_channels,
1023
+ num_layers=1,
1024
+ cross_attention_dim=cross_attention_dim,
1025
+ norm_num_groups=resnet_groups,
1026
+ )
1027
+ )
1028
+ self.attentions = nn.ModuleList(attentions)
1029
+ self.resnets = nn.ModuleList(resnets)
1030
+
1031
+ if add_downsample:
1032
+ self.downsamplers = nn.ModuleList(
1033
+ [
1034
+ Downsample2D(
1035
+ out_channels,
1036
+ use_conv=True,
1037
+ out_channels=out_channels,
1038
+ padding=downsample_padding,
1039
+ name="op",
1040
+ )
1041
+ ]
1042
+ )
1043
+ else:
1044
+ self.downsamplers = None
1045
+
1046
+ self.gradient_checkpointing = False
1047
+
1048
+ def forward(
1049
+ self,
1050
+ hidden_states: torch.FloatTensor,
1051
+ temb: Optional[torch.FloatTensor] = None,
1052
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1053
+ attention_mask: Optional[torch.FloatTensor] = None,
1054
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1055
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1056
+ additional_residuals: Optional[torch.FloatTensor] = None,
1057
+ self_attn_block_embs: Optional[List[torch.Tensor]] = None,
1058
+ self_attn_block_embs_mode: Literal["read", "write"] = "write",
1059
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
1060
+ output_states = ()
1061
+
1062
+ lora_scale = (
1063
+ cross_attention_kwargs.get("scale", 1.0)
1064
+ if cross_attention_kwargs is not None
1065
+ else 1.0
1066
+ )
1067
+
1068
+ blocks = list(zip(self.resnets, self.attentions))
1069
+
1070
+ for i, (resnet, attn) in enumerate(blocks):
1071
+ if self.training and self.gradient_checkpointing:
1072
+
1073
+ def create_custom_forward(module, return_dict=None):
1074
+ def custom_forward(*inputs):
1075
+ if return_dict is not None:
1076
+ return module(*inputs, return_dict=return_dict)
1077
+ else:
1078
+ return module(*inputs)
1079
+
1080
+ return custom_forward
1081
+
1082
+ ckpt_kwargs: Dict[str, Any] = (
1083
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1084
+ )
1085
+ hidden_states = torch.utils.checkpoint.checkpoint(
1086
+ create_custom_forward(resnet),
1087
+ hidden_states,
1088
+ temb,
1089
+ **ckpt_kwargs,
1090
+ )
1091
+ if self.print_idx == 0:
1092
+ logger.debug(f"unet3d after resnet {hidden_states.mean()}")
1093
+
1094
+ hidden_states = attn(
1095
+ hidden_states,
1096
+ encoder_hidden_states=encoder_hidden_states,
1097
+ cross_attention_kwargs=cross_attention_kwargs,
1098
+ attention_mask=attention_mask,
1099
+ encoder_attention_mask=encoder_attention_mask,
1100
+ return_dict=False,
1101
+ self_attn_block_embs=self_attn_block_embs,
1102
+ self_attn_block_embs_mode=self_attn_block_embs_mode,
1103
+ )[0]
1104
+ else:
1105
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
1106
+ if self.print_idx == 0:
1107
+ logger.debug(f"unet3d after resnet {hidden_states.mean()}")
1108
+ hidden_states = attn(
1109
+ hidden_states,
1110
+ encoder_hidden_states=encoder_hidden_states,
1111
+ cross_attention_kwargs=cross_attention_kwargs,
1112
+ attention_mask=attention_mask,
1113
+ encoder_attention_mask=encoder_attention_mask,
1114
+ return_dict=False,
1115
+ self_attn_block_embs=self_attn_block_embs,
1116
+ self_attn_block_embs_mode=self_attn_block_embs_mode,
1117
+ )[0]
1118
+
1119
+ # apply additional residuals to the output of the last pair of resnet and attention blocks
1120
+ if i == len(blocks) - 1 and additional_residuals is not None:
1121
+ hidden_states = hidden_states + additional_residuals
1122
+
1123
+ output_states = output_states + (hidden_states,)
1124
+
1125
+ if self.downsamplers is not None:
1126
+ for downsampler in self.downsamplers:
1127
+ hidden_states = downsampler(hidden_states, scale=lora_scale)
1128
+
1129
+ output_states = output_states + (hidden_states,)
1130
+
1131
+ self.print_idx += 1
1132
+ return hidden_states, output_states
1133
+
1134
+
1135
+ class DownBlock2D(nn.Module):
1136
+ def __init__(
1137
+ self,
1138
+ in_channels: int,
1139
+ out_channels: int,
1140
+ temb_channels: int,
1141
+ dropout: float = 0.0,
1142
+ num_layers: int = 1,
1143
+ resnet_eps: float = 1e-6,
1144
+ resnet_time_scale_shift: str = "default",
1145
+ resnet_act_fn: str = "swish",
1146
+ resnet_groups: int = 32,
1147
+ resnet_pre_norm: bool = True,
1148
+ output_scale_factor: float = 1.0,
1149
+ add_downsample: bool = True,
1150
+ downsample_padding: int = 1,
1151
+ ):
1152
+ super().__init__()
1153
+ resnets = []
1154
+
1155
+ for i in range(num_layers):
1156
+ in_channels = in_channels if i == 0 else out_channels
1157
+ resnets.append(
1158
+ ResnetBlock2D(
1159
+ in_channels=in_channels,
1160
+ out_channels=out_channels,
1161
+ temb_channels=temb_channels,
1162
+ eps=resnet_eps,
1163
+ groups=resnet_groups,
1164
+ dropout=dropout,
1165
+ time_embedding_norm=resnet_time_scale_shift,
1166
+ non_linearity=resnet_act_fn,
1167
+ output_scale_factor=output_scale_factor,
1168
+ pre_norm=resnet_pre_norm,
1169
+ )
1170
+ )
1171
+
1172
+ self.resnets = nn.ModuleList(resnets)
1173
+
1174
+ if add_downsample:
1175
+ self.downsamplers = nn.ModuleList(
1176
+ [
1177
+ Downsample2D(
1178
+ out_channels,
1179
+ use_conv=True,
1180
+ out_channels=out_channels,
1181
+ padding=downsample_padding,
1182
+ name="op",
1183
+ )
1184
+ ]
1185
+ )
1186
+ else:
1187
+ self.downsamplers = None
1188
+
1189
+ self.gradient_checkpointing = False
1190
+
1191
+ def forward(
1192
+ self,
1193
+ hidden_states: torch.FloatTensor,
1194
+ temb: Optional[torch.FloatTensor] = None,
1195
+ scale: float = 1.0,
1196
+ self_attn_block_embs: Optional[List[torch.Tensor]] = None,
1197
+ self_attn_block_embs_mode: Literal["read", "write"] = "write",
1198
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
1199
+ output_states = ()
1200
+
1201
+ for resnet in self.resnets:
1202
+ if self.training and self.gradient_checkpointing:
1203
+
1204
+ def create_custom_forward(module):
1205
+ def custom_forward(*inputs):
1206
+ return module(*inputs)
1207
+
1208
+ return custom_forward
1209
+
1210
+ if is_torch_version(">=", "1.11.0"):
1211
+ hidden_states = torch.utils.checkpoint.checkpoint(
1212
+ create_custom_forward(resnet),
1213
+ hidden_states,
1214
+ temb,
1215
+ use_reentrant=False,
1216
+ )
1217
+ else:
1218
+ hidden_states = torch.utils.checkpoint.checkpoint(
1219
+ create_custom_forward(resnet), hidden_states, temb
1220
+ )
1221
+ else:
1222
+ hidden_states = resnet(hidden_states, temb, scale=scale)
1223
+
1224
+ output_states = output_states + (hidden_states,)
1225
+
1226
+ if self.downsamplers is not None:
1227
+ for downsampler in self.downsamplers:
1228
+ hidden_states = downsampler(hidden_states, scale=scale)
1229
+
1230
+ output_states = output_states + (hidden_states,)
1231
+
1232
+ return hidden_states, output_states
1233
+
1234
+
1235
+ class CrossAttnUpBlock2D(nn.Module):
1236
+ def __init__(
1237
+ self,
1238
+ in_channels: int,
1239
+ out_channels: int,
1240
+ prev_output_channel: int,
1241
+ temb_channels: int,
1242
+ resolution_idx: Optional[int] = None,
1243
+ dropout: float = 0.0,
1244
+ num_layers: int = 1,
1245
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
1246
+ resnet_eps: float = 1e-6,
1247
+ resnet_time_scale_shift: str = "default",
1248
+ resnet_act_fn: str = "swish",
1249
+ resnet_groups: int = 32,
1250
+ resnet_pre_norm: bool = True,
1251
+ num_attention_heads: int = 1,
1252
+ cross_attention_dim: int = 1280,
1253
+ output_scale_factor: float = 1.0,
1254
+ add_upsample: bool = True,
1255
+ dual_cross_attention: bool = False,
1256
+ use_linear_projection: bool = False,
1257
+ only_cross_attention: bool = False,
1258
+ upcast_attention: bool = False,
1259
+ attention_type: str = "default",
1260
+ ):
1261
+ super().__init__()
1262
+ resnets = []
1263
+ attentions = []
1264
+
1265
+ self.has_cross_attention = True
1266
+ self.num_attention_heads = num_attention_heads
1267
+
1268
+ if isinstance(transformer_layers_per_block, int):
1269
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
1270
+
1271
+ for i in range(num_layers):
1272
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1273
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1274
+
1275
+ resnets.append(
1276
+ ResnetBlock2D(
1277
+ in_channels=resnet_in_channels + res_skip_channels,
1278
+ out_channels=out_channels,
1279
+ temb_channels=temb_channels,
1280
+ eps=resnet_eps,
1281
+ groups=resnet_groups,
1282
+ dropout=dropout,
1283
+ time_embedding_norm=resnet_time_scale_shift,
1284
+ non_linearity=resnet_act_fn,
1285
+ output_scale_factor=output_scale_factor,
1286
+ pre_norm=resnet_pre_norm,
1287
+ )
1288
+ )
1289
+ if not dual_cross_attention:
1290
+ attentions.append(
1291
+ Transformer2DModel(
1292
+ num_attention_heads,
1293
+ out_channels // num_attention_heads,
1294
+ in_channels=out_channels,
1295
+ num_layers=transformer_layers_per_block[i],
1296
+ cross_attention_dim=cross_attention_dim,
1297
+ norm_num_groups=resnet_groups,
1298
+ use_linear_projection=use_linear_projection,
1299
+ only_cross_attention=only_cross_attention,
1300
+ upcast_attention=upcast_attention,
1301
+ attention_type=attention_type,
1302
+ )
1303
+ )
1304
+ else:
1305
+ attentions.append(
1306
+ DualTransformer2DModel(
1307
+ num_attention_heads,
1308
+ out_channels // num_attention_heads,
1309
+ in_channels=out_channels,
1310
+ num_layers=1,
1311
+ cross_attention_dim=cross_attention_dim,
1312
+ norm_num_groups=resnet_groups,
1313
+ )
1314
+ )
1315
+ self.attentions = nn.ModuleList(attentions)
1316
+ self.resnets = nn.ModuleList(resnets)
1317
+
1318
+ if add_upsample:
1319
+ self.upsamplers = nn.ModuleList(
1320
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
1321
+ )
1322
+ else:
1323
+ self.upsamplers = None
1324
+
1325
+ self.gradient_checkpointing = False
1326
+ self.resolution_idx = resolution_idx
1327
+
1328
+ def forward(
1329
+ self,
1330
+ hidden_states: torch.FloatTensor,
1331
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
1332
+ temb: Optional[torch.FloatTensor] = None,
1333
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1334
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1335
+ upsample_size: Optional[int] = None,
1336
+ attention_mask: Optional[torch.FloatTensor] = None,
1337
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1338
+ self_attn_block_embs: Optional[List[torch.Tensor]] = None,
1339
+ self_attn_block_embs_mode: Literal["read", "write"] = "write",
1340
+ ) -> torch.FloatTensor:
1341
+ lora_scale = (
1342
+ cross_attention_kwargs.get("scale", 1.0)
1343
+ if cross_attention_kwargs is not None
1344
+ else 1.0
1345
+ )
1346
+ is_freeu_enabled = (
1347
+ getattr(self, "s1", None)
1348
+ and getattr(self, "s2", None)
1349
+ and getattr(self, "b1", None)
1350
+ and getattr(self, "b2", None)
1351
+ )
1352
+
1353
+ for resnet, attn in zip(self.resnets, self.attentions):
1354
+ # pop res hidden states
1355
+ res_hidden_states = res_hidden_states_tuple[-1]
1356
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1357
+
1358
+ # FreeU: Only operate on the first two stages
1359
+ if is_freeu_enabled:
1360
+ hidden_states, res_hidden_states = apply_freeu(
1361
+ self.resolution_idx,
1362
+ hidden_states,
1363
+ res_hidden_states,
1364
+ s1=self.s1,
1365
+ s2=self.s2,
1366
+ b1=self.b1,
1367
+ b2=self.b2,
1368
+ )
1369
+
1370
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1371
+
1372
+ if self.training and self.gradient_checkpointing:
1373
+
1374
+ def create_custom_forward(module, return_dict=None):
1375
+ def custom_forward(*inputs):
1376
+ if return_dict is not None:
1377
+ return module(*inputs, return_dict=return_dict)
1378
+ else:
1379
+ return module(*inputs)
1380
+
1381
+ return custom_forward
1382
+
1383
+ ckpt_kwargs: Dict[str, Any] = (
1384
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1385
+ )
1386
+ hidden_states = torch.utils.checkpoint.checkpoint(
1387
+ create_custom_forward(resnet),
1388
+ hidden_states,
1389
+ temb,
1390
+ **ckpt_kwargs,
1391
+ )
1392
+ hidden_states = attn(
1393
+ hidden_states,
1394
+ encoder_hidden_states=encoder_hidden_states,
1395
+ cross_attention_kwargs=cross_attention_kwargs,
1396
+ attention_mask=attention_mask,
1397
+ encoder_attention_mask=encoder_attention_mask,
1398
+ return_dict=False,
1399
+ self_attn_block_embs=self_attn_block_embs,
1400
+ self_attn_block_embs_mode=self_attn_block_embs_mode,
1401
+ )[0]
1402
+ else:
1403
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
1404
+ hidden_states = attn(
1405
+ hidden_states,
1406
+ encoder_hidden_states=encoder_hidden_states,
1407
+ cross_attention_kwargs=cross_attention_kwargs,
1408
+ attention_mask=attention_mask,
1409
+ encoder_attention_mask=encoder_attention_mask,
1410
+ return_dict=False,
1411
+ self_attn_block_embs=self_attn_block_embs,
1412
+ )[0]
1413
+
1414
+ if self.upsamplers is not None:
1415
+ for upsampler in self.upsamplers:
1416
+ hidden_states = upsampler(
1417
+ hidden_states, upsample_size, scale=lora_scale
1418
+ )
1419
+
1420
+ return hidden_states
1421
+
1422
+
1423
+ class UpBlock2D(nn.Module):
1424
+ def __init__(
1425
+ self,
1426
+ in_channels: int,
1427
+ prev_output_channel: int,
1428
+ out_channels: int,
1429
+ temb_channels: int,
1430
+ resolution_idx: Optional[int] = None,
1431
+ dropout: float = 0.0,
1432
+ num_layers: int = 1,
1433
+ resnet_eps: float = 1e-6,
1434
+ resnet_time_scale_shift: str = "default",
1435
+ resnet_act_fn: str = "swish",
1436
+ resnet_groups: int = 32,
1437
+ resnet_pre_norm: bool = True,
1438
+ output_scale_factor: float = 1.0,
1439
+ add_upsample: bool = True,
1440
+ ):
1441
+ super().__init__()
1442
+ resnets = []
1443
+
1444
+ for i in range(num_layers):
1445
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1446
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1447
+
1448
+ resnets.append(
1449
+ ResnetBlock2D(
1450
+ in_channels=resnet_in_channels + res_skip_channels,
1451
+ out_channels=out_channels,
1452
+ temb_channels=temb_channels,
1453
+ eps=resnet_eps,
1454
+ groups=resnet_groups,
1455
+ dropout=dropout,
1456
+ time_embedding_norm=resnet_time_scale_shift,
1457
+ non_linearity=resnet_act_fn,
1458
+ output_scale_factor=output_scale_factor,
1459
+ pre_norm=resnet_pre_norm,
1460
+ )
1461
+ )
1462
+
1463
+ self.resnets = nn.ModuleList(resnets)
1464
+
1465
+ if add_upsample:
1466
+ self.upsamplers = nn.ModuleList(
1467
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
1468
+ )
1469
+ else:
1470
+ self.upsamplers = None
1471
+
1472
+ self.gradient_checkpointing = False
1473
+ self.resolution_idx = resolution_idx
1474
+
1475
+ def forward(
1476
+ self,
1477
+ hidden_states: torch.FloatTensor,
1478
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
1479
+ temb: Optional[torch.FloatTensor] = None,
1480
+ upsample_size: Optional[int] = None,
1481
+ scale: float = 1.0,
1482
+ self_attn_block_embs: Optional[List[torch.Tensor]] = None,
1483
+ self_attn_block_embs_mode: Literal["read", "write"] = "write",
1484
+ ) -> torch.FloatTensor:
1485
+ is_freeu_enabled = (
1486
+ getattr(self, "s1", None)
1487
+ and getattr(self, "s2", None)
1488
+ and getattr(self, "b1", None)
1489
+ and getattr(self, "b2", None)
1490
+ )
1491
+
1492
+ for resnet in self.resnets:
1493
+ # pop res hidden states
1494
+ res_hidden_states = res_hidden_states_tuple[-1]
1495
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1496
+
1497
+ # FreeU: Only operate on the first two stages
1498
+ if is_freeu_enabled:
1499
+ hidden_states, res_hidden_states = apply_freeu(
1500
+ self.resolution_idx,
1501
+ hidden_states,
1502
+ res_hidden_states,
1503
+ s1=self.s1,
1504
+ s2=self.s2,
1505
+ b1=self.b1,
1506
+ b2=self.b2,
1507
+ )
1508
+
1509
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1510
+
1511
+ if self.training and self.gradient_checkpointing:
1512
+
1513
+ def create_custom_forward(module):
1514
+ def custom_forward(*inputs):
1515
+ return module(*inputs)
1516
+
1517
+ return custom_forward
1518
+
1519
+ if is_torch_version(">=", "1.11.0"):
1520
+ hidden_states = torch.utils.checkpoint.checkpoint(
1521
+ create_custom_forward(resnet),
1522
+ hidden_states,
1523
+ temb,
1524
+ use_reentrant=False,
1525
+ )
1526
+ else:
1527
+ hidden_states = torch.utils.checkpoint.checkpoint(
1528
+ create_custom_forward(resnet), hidden_states, temb
1529
+ )
1530
+ else:
1531
+ hidden_states = resnet(hidden_states, temb, scale=scale)
1532
+
1533
+ if self.upsamplers is not None:
1534
+ for upsampler in self.upsamplers:
1535
+ hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
1536
+
1537
+ return hidden_states
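
Every block forward in this file repeats the same gradient-checkpointing wrapper: when `self.training and self.gradient_checkpointing`, the resnet call is routed through `torch.utils.checkpoint.checkpoint`, with `create_custom_forward` binding the module (and, where needed, `return_dict`) so only tensors are passed positionally, and `use_reentrant=False` selected on torch >= 1.11. A minimal standalone sketch of that pattern, assuming only PyTorch; the toy `nn.Linear` layer is illustrative, not one of the real blocks:

import torch
import torch.utils.checkpoint
from torch import nn

def create_custom_forward(module):
    # bind the module so the checkpoint API only sees positional tensor inputs
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward

layer = nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)

# recompute `layer` during the backward pass instead of storing its activations,
# trading extra compute for lower memory use
y = torch.utils.checkpoint.checkpoint(
    create_custom_forward(layer), x, use_reentrant=False
)
y.sum().backward()
assert x.grad is not None and x.grad.shape == x.shape

The real blocks wrap each `ResnetBlock2D` this way while calling the attention transformers directly, which matches the branches shown above.
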
musev/models/unet_3d_blocks.py ADDED
@@ -0,0 +1,1413 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Adapted from https://github.com/huggingface/diffusers/blob/v0.16.1/src/diffusers/models/unet_3d_blocks.py
16
+
17
+ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
18
+ import logging
19
+
20
+ import torch
21
+ from torch import nn
22
+
23
+ from diffusers.utils import is_torch_version
24
+ from diffusers.models.transformer_2d import (
25
+ Transformer2DModel as DiffusersTransformer2DModel,
26
+ )
27
+ from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
28
+ from ..data.data_util import batch_adain_conditioned_tensor
29
+
30
+ from .resnet import TemporalConvLayer
31
+ from .temporal_transformer import TransformerTemporalModel
32
+ from .transformer_2d import Transformer2DModel
33
+ from .attention_processor import ReferEmbFuseAttention
34
+
35
+
36
+ logger = logging.getLogger(__name__)
37
+
38
+ # Notes on changes relative to the original diffusers code:
+ # 1. The original default of `use_linear_projection` is True, which does not match the 2D SD model and fails at load time, so it is changed to False everywhere.
+ # 2. The original code passes arguments to `Transformer2DModel` in the order n_channels // attn_num_head_channels, attn_num_head_channels, which does not match the 2D SD model, so the order is swapped.
+ # 3. Added a frame-embedding input used by the temporal attention.
48
+
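+ # A minimal sketch of note 2, using hypothetical values rather than a real config:
+ # for a block with out_channels=320 and attn_num_head_channels=8, the spatial
+ # transformer below is constructed as
+ #     Transformer2DModel(
+ #         8,          # number of attention heads
+ #         320 // 8,   # channels per head = 40
+ #         in_channels=320,
+ #         use_linear_projection=False,  # see note 1
+ #     )
+ # i.e. heads first, then channels per head, matching the 2D SD checkpoints.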
49
+
50
+ def get_down_block(
51
+ down_block_type,
52
+ num_layers,
53
+ in_channels,
54
+ out_channels,
55
+ temb_channels,
56
+ femb_channels,
57
+ add_downsample,
58
+ resnet_eps,
59
+ resnet_act_fn,
60
+ attn_num_head_channels,
61
+ resnet_groups=None,
62
+ cross_attention_dim=None,
63
+ downsample_padding=None,
64
+ dual_cross_attention=False,
65
+ use_linear_projection=False,
66
+ only_cross_attention=False,
67
+ upcast_attention=False,
68
+ resnet_time_scale_shift="default",
69
+ temporal_transformer: Union[nn.Module, None] = TransformerTemporalModel,
70
+ temporal_conv_block: Union[nn.Module, None] = TemporalConvLayer,
71
+ need_spatial_position_emb: bool = False,
72
+ need_t2i_ip_adapter: bool = False,
73
+ ip_adapter_cross_attn: bool = False,
74
+ need_t2i_facein: bool = False,
75
+ need_t2i_ip_adapter_face: bool = False,
76
+ need_adain_temporal_cond: bool = False,
77
+ resnet_2d_skip_time_act: bool = False,
78
+ need_refer_emb: bool = False,
79
+ ):
80
+ if (isinstance(down_block_type, str) and down_block_type == "DownBlock3D") or (
81
+ isinstance(down_block_type, nn.Module)
82
+ and down_block_type.__name__ == "DownBlock3D"
83
+ ):
84
+ return DownBlock3D(
85
+ num_layers=num_layers,
86
+ in_channels=in_channels,
87
+ out_channels=out_channels,
88
+ temb_channels=temb_channels,
89
+ femb_channels=femb_channels,
90
+ add_downsample=add_downsample,
91
+ resnet_eps=resnet_eps,
92
+ resnet_act_fn=resnet_act_fn,
93
+ resnet_groups=resnet_groups,
94
+ downsample_padding=downsample_padding,
95
+ resnet_time_scale_shift=resnet_time_scale_shift,
96
+ temporal_conv_block=temporal_conv_block,
97
+ need_adain_temporal_cond=need_adain_temporal_cond,
98
+ resnet_2d_skip_time_act=resnet_2d_skip_time_act,
99
+ need_refer_emb=need_refer_emb,
100
+ attn_num_head_channels=attn_num_head_channels,
101
+ )
102
+ elif (
103
+ isinstance(down_block_type, str) and down_block_type == "CrossAttnDownBlock3D"
104
+ ) or (
105
+ isinstance(down_block_type, nn.Module)
106
+ and down_block_type.__name__ == "CrossAttnDownBlock3D"
107
+ ):
108
+ if cross_attention_dim is None:
109
+ raise ValueError(
110
+ "cross_attention_dim must be specified for CrossAttnDownBlock3D"
111
+ )
112
+ return CrossAttnDownBlock3D(
113
+ num_layers=num_layers,
114
+ in_channels=in_channels,
115
+ out_channels=out_channels,
116
+ temb_channels=temb_channels,
117
+ femb_channels=femb_channels,
118
+ add_downsample=add_downsample,
119
+ resnet_eps=resnet_eps,
120
+ resnet_act_fn=resnet_act_fn,
121
+ resnet_groups=resnet_groups,
122
+ downsample_padding=downsample_padding,
123
+ cross_attention_dim=cross_attention_dim,
124
+ attn_num_head_channels=attn_num_head_channels,
125
+ dual_cross_attention=dual_cross_attention,
126
+ use_linear_projection=use_linear_projection,
127
+ only_cross_attention=only_cross_attention,
128
+ upcast_attention=upcast_attention,
129
+ resnet_time_scale_shift=resnet_time_scale_shift,
130
+ temporal_conv_block=temporal_conv_block,
131
+ temporal_transformer=temporal_transformer,
132
+ need_spatial_position_emb=need_spatial_position_emb,
133
+ need_t2i_ip_adapter=need_t2i_ip_adapter,
134
+ ip_adapter_cross_attn=ip_adapter_cross_attn,
135
+ need_t2i_facein=need_t2i_facein,
136
+ need_t2i_ip_adapter_face=need_t2i_ip_adapter_face,
137
+ need_adain_temporal_cond=need_adain_temporal_cond,
138
+ resnet_2d_skip_time_act=resnet_2d_skip_time_act,
139
+ need_refer_emb=need_refer_emb,
140
+ )
141
+ raise ValueError(f"{down_block_type} does not exist.")
142
+
143
+
144
+ def get_up_block(
145
+ up_block_type,
146
+ num_layers,
147
+ in_channels,
148
+ out_channels,
149
+ prev_output_channel,
150
+ temb_channels,
151
+ femb_channels,
152
+ add_upsample,
153
+ resnet_eps,
154
+ resnet_act_fn,
155
+ attn_num_head_channels,
156
+ resnet_groups=None,
157
+ cross_attention_dim=None,
158
+ dual_cross_attention=False,
159
+ use_linear_projection=False,
160
+ only_cross_attention=False,
161
+ upcast_attention=False,
162
+ resnet_time_scale_shift="default",
163
+ temporal_conv_block: Union[nn.Module, None] = TemporalConvLayer,
164
+ temporal_transformer: Union[nn.Module, None] = TransformerTemporalModel,
165
+ need_spatial_position_emb: bool = False,
166
+ need_t2i_ip_adapter: bool = False,
167
+ ip_adapter_cross_attn: bool = False,
168
+ need_t2i_facein: bool = False,
169
+ need_t2i_ip_adapter_face: bool = False,
170
+ need_adain_temporal_cond: bool = False,
171
+ resnet_2d_skip_time_act: bool = False,
172
+ ):
173
+ if (isinstance(up_block_type, str) and up_block_type == "UpBlock3D") or (
174
+ isinstance(up_block_type, nn.Module) and up_block_type.__name__ == "UpBlock3D"
175
+ ):
176
+ return UpBlock3D(
177
+ num_layers=num_layers,
178
+ in_channels=in_channels,
179
+ out_channels=out_channels,
180
+ prev_output_channel=prev_output_channel,
181
+ temb_channels=temb_channels,
182
+ femb_channels=femb_channels,
183
+ add_upsample=add_upsample,
184
+ resnet_eps=resnet_eps,
185
+ resnet_act_fn=resnet_act_fn,
186
+ resnet_groups=resnet_groups,
187
+ resnet_time_scale_shift=resnet_time_scale_shift,
188
+ temporal_conv_block=temporal_conv_block,
189
+ need_adain_temporal_cond=need_adain_temporal_cond,
190
+ resnet_2d_skip_time_act=resnet_2d_skip_time_act,
191
+ )
192
+ elif (isinstance(up_block_type, str) and up_block_type == "CrossAttnUpBlock3D") or (
193
+ isinstance(up_block_type, nn.Module)
194
+ and up_block_type.__name__ == "CrossAttnUpBlock3D"
195
+ ):
196
+ if cross_attention_dim is None:
197
+ raise ValueError(
198
+ "cross_attention_dim must be specified for CrossAttnUpBlock3D"
199
+ )
200
+ return CrossAttnUpBlock3D(
201
+ num_layers=num_layers,
202
+ in_channels=in_channels,
203
+ out_channels=out_channels,
204
+ prev_output_channel=prev_output_channel,
205
+ temb_channels=temb_channels,
206
+ femb_channels=femb_channels,
207
+ add_upsample=add_upsample,
208
+ resnet_eps=resnet_eps,
209
+ resnet_act_fn=resnet_act_fn,
210
+ resnet_groups=resnet_groups,
211
+ cross_attention_dim=cross_attention_dim,
212
+ attn_num_head_channels=attn_num_head_channels,
213
+ dual_cross_attention=dual_cross_attention,
214
+ use_linear_projection=use_linear_projection,
215
+ only_cross_attention=only_cross_attention,
216
+ upcast_attention=upcast_attention,
217
+ resnet_time_scale_shift=resnet_time_scale_shift,
218
+ temporal_conv_block=temporal_conv_block,
219
+ temporal_transformer=temporal_transformer,
220
+ need_spatial_position_emb=need_spatial_position_emb,
221
+ need_t2i_ip_adapter=need_t2i_ip_adapter,
222
+ ip_adapter_cross_attn=ip_adapter_cross_attn,
223
+ need_t2i_facein=need_t2i_facein,
224
+ need_t2i_ip_adapter_face=need_t2i_ip_adapter_face,
225
+ need_adain_temporal_cond=need_adain_temporal_cond,
226
+ resnet_2d_skip_time_act=resnet_2d_skip_time_act,
227
+ )
228
+ raise ValueError(f"{up_block_type} does not exist.")
229
+
230
+
231
+ class UNetMidBlock3DCrossAttn(nn.Module):
232
+ print_idx = 0
233
+
234
+ def __init__(
235
+ self,
236
+ in_channels: int,
237
+ temb_channels: int,
238
+ femb_channels: int,
239
+ dropout: float = 0.0,
240
+ num_layers: int = 1,
241
+ resnet_eps: float = 1e-6,
242
+ resnet_time_scale_shift: str = "default",
243
+ resnet_act_fn: str = "swish",
244
+ resnet_groups: int = 32,
245
+ resnet_pre_norm: bool = True,
246
+ attn_num_head_channels=1,
247
+ output_scale_factor=1.0,
248
+ cross_attention_dim=1280,
249
+ dual_cross_attention=False,
250
+ use_linear_projection=False,
251
+ upcast_attention=False,
252
+ temporal_conv_block: Union[nn.Module, None] = TemporalConvLayer,
253
+ temporal_transformer: Union[nn.Module, None] = TransformerTemporalModel,
254
+ need_spatial_position_emb: bool = False,
255
+ need_t2i_ip_adapter: bool = False,
256
+ ip_adapter_cross_attn: bool = False,
257
+ need_t2i_facein: bool = False,
258
+ need_t2i_ip_adapter_face: bool = False,
259
+ need_adain_temporal_cond: bool = False,
260
+ resnet_2d_skip_time_act: bool = False,
261
+ ):
262
+ super().__init__()
263
+
264
+ self.has_cross_attention = True
265
+ self.attn_num_head_channels = attn_num_head_channels
266
+ resnet_groups = (
267
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
268
+ )
269
+
270
+ # there is always at least one resnet
271
+ resnets = [
272
+ ResnetBlock2D(
273
+ in_channels=in_channels,
274
+ out_channels=in_channels,
275
+ temb_channels=temb_channels,
276
+ eps=resnet_eps,
277
+ groups=resnet_groups,
278
+ dropout=dropout,
279
+ time_embedding_norm=resnet_time_scale_shift,
280
+ non_linearity=resnet_act_fn,
281
+ output_scale_factor=output_scale_factor,
282
+ pre_norm=resnet_pre_norm,
283
+ skip_time_act=resnet_2d_skip_time_act,
284
+ )
285
+ ]
286
+ if temporal_conv_block is not None:
287
+ temp_convs = [
288
+ temporal_conv_block(
289
+ in_channels,
290
+ in_channels,
291
+ dropout=0.1,
292
+ femb_channels=femb_channels,
293
+ )
294
+ ]
295
+ else:
296
+ temp_convs = [None]
297
+ attentions = []
298
+ temp_attentions = []
299
+
300
+ for _ in range(num_layers):
301
+ attentions.append(
302
+ Transformer2DModel(
303
+ attn_num_head_channels,
304
+ in_channels // attn_num_head_channels,
305
+ in_channels=in_channels,
306
+ num_layers=1,
307
+ cross_attention_dim=cross_attention_dim,
308
+ norm_num_groups=resnet_groups,
309
+ use_linear_projection=use_linear_projection,
310
+ upcast_attention=upcast_attention,
311
+ cross_attn_temporal_cond=need_t2i_ip_adapter,
312
+ ip_adapter_cross_attn=ip_adapter_cross_attn,
313
+ need_t2i_facein=need_t2i_facein,
314
+ need_t2i_ip_adapter_face=need_t2i_ip_adapter_face,
315
+ )
316
+ )
317
+ if temporal_transformer is not None:
318
+ temp_attention = temporal_transformer(
319
+ attn_num_head_channels,
320
+ in_channels // attn_num_head_channels,
321
+ in_channels=in_channels,
322
+ num_layers=1,
323
+ femb_channels=femb_channels,
324
+ cross_attention_dim=cross_attention_dim,
325
+ norm_num_groups=resnet_groups,
326
+ need_spatial_position_emb=need_spatial_position_emb,
327
+ )
328
+ else:
329
+ temp_attention = None
330
+ temp_attentions.append(temp_attention)
331
+ resnets.append(
332
+ ResnetBlock2D(
333
+ in_channels=in_channels,
334
+ out_channels=in_channels,
335
+ temb_channels=temb_channels,
336
+ eps=resnet_eps,
337
+ groups=resnet_groups,
338
+ dropout=dropout,
339
+ time_embedding_norm=resnet_time_scale_shift,
340
+ non_linearity=resnet_act_fn,
341
+ output_scale_factor=output_scale_factor,
342
+ pre_norm=resnet_pre_norm,
343
+ skip_time_act=resnet_2d_skip_time_act,
344
+ )
345
+ )
346
+ if temporal_conv_block is not None:
347
+ temp_convs.append(
348
+ temporal_conv_block(
349
+ in_channels,
350
+ in_channels,
351
+ dropout=0.1,
352
+ femb_channels=femb_channels,
353
+ )
354
+ )
355
+ else:
356
+ temp_convs.append(None)
357
+
358
+ self.resnets = nn.ModuleList(resnets)
359
+ self.temp_convs = nn.ModuleList(temp_convs)
360
+ self.attentions = nn.ModuleList(attentions)
361
+ self.temp_attentions = nn.ModuleList(temp_attentions)
362
+ self.need_adain_temporal_cond = need_adain_temporal_cond
363
+
364
+ def forward(
365
+ self,
366
+ hidden_states,
367
+ temb=None,
368
+ femb=None,
369
+ encoder_hidden_states=None,
370
+ attention_mask=None,
371
+ num_frames=1,
372
+ cross_attention_kwargs=None,
373
+ sample_index: torch.LongTensor = None,
374
+ vision_conditon_frames_sample_index: torch.LongTensor = None,
375
+ spatial_position_emb: torch.Tensor = None,
376
+ refer_self_attn_emb: List[torch.Tensor] = None,
377
+ refer_self_attn_emb_mode: Literal["read", "write"] = "read",
378
+ ):
379
+ hidden_states = self.resnets[0](hidden_states, temb)
380
+ if self.temp_convs[0] is not None:
381
+ hidden_states = self.temp_convs[0](
382
+ hidden_states,
383
+ femb=femb,
384
+ num_frames=num_frames,
385
+ sample_index=sample_index,
386
+ vision_conditon_frames_sample_index=vision_conditon_frames_sample_index,
387
+ )
388
+ for attn, temp_attn, resnet, temp_conv in zip(
389
+ self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:]
390
+ ):
391
+ hidden_states = attn(
392
+ hidden_states,
393
+ encoder_hidden_states=encoder_hidden_states,
394
+ cross_attention_kwargs=cross_attention_kwargs,
395
+ self_attn_block_embs=refer_self_attn_emb,
396
+ self_attn_block_embs_mode=refer_self_attn_emb_mode,
397
+ ).sample
398
+ if temp_attn is not None:
399
+ hidden_states = temp_attn(
400
+ hidden_states,
401
+ femb=femb,
402
+ num_frames=num_frames,
403
+ cross_attention_kwargs=cross_attention_kwargs,
404
+ encoder_hidden_states=encoder_hidden_states,
405
+ sample_index=sample_index,
406
+ vision_conditon_frames_sample_index=vision_conditon_frames_sample_index,
407
+ spatial_position_emb=spatial_position_emb,
408
+ ).sample
409
+ hidden_states = resnet(hidden_states, temb)
410
+ if temp_conv is not None:
411
+ hidden_states = temp_conv(
412
+ hidden_states,
413
+ femb=femb,
414
+ num_frames=num_frames,
415
+ sample_index=sample_index,
416
+ vision_conditon_frames_sample_index=vision_conditon_frames_sample_index,
417
+ )
418
+ if (
419
+ self.need_adain_temporal_cond
420
+ and num_frames > 1
421
+ and sample_index is not None
422
+ ):
423
+ if self.print_idx == 0:
424
+ logger.debug(f"adain to vision_condition")
425
+ hidden_states = batch_adain_conditioned_tensor(
426
+ hidden_states,
427
+ num_frames=num_frames,
428
+ need_style_fidelity=False,
429
+ src_index=sample_index,
430
+ dst_index=vision_conditon_frames_sample_index,
431
+ )
432
+ self.print_idx += 1
433
+ return hidden_states
434
+
435
+
436
+ class CrossAttnDownBlock3D(nn.Module):
437
+ print_idx = 0
438
+
439
+ def __init__(
440
+ self,
441
+ in_channels: int,
442
+ out_channels: int,
443
+ temb_channels: int,
444
+ femb_channels: int,
445
+ dropout: float = 0.0,
446
+ num_layers: int = 1,
447
+ resnet_eps: float = 1e-6,
448
+ resnet_time_scale_shift: str = "default",
449
+ resnet_act_fn: str = "swish",
450
+ resnet_groups: int = 32,
451
+ resnet_pre_norm: bool = True,
452
+ attn_num_head_channels=1,
453
+ cross_attention_dim=1280,
454
+ output_scale_factor=1.0,
455
+ downsample_padding=1,
456
+ add_downsample=True,
457
+ dual_cross_attention=False,
458
+ use_linear_projection=False,
459
+ only_cross_attention=False,
460
+ upcast_attention=False,
461
+ temporal_conv_block: Union[nn.Module, None] = TemporalConvLayer,
462
+ temporal_transformer: Union[nn.Module, None] = TransformerTemporalModel,
463
+ need_spatial_position_emb: bool = False,
464
+ need_t2i_ip_adapter: bool = False,
465
+ ip_adapter_cross_attn: bool = False,
466
+ need_t2i_facein: bool = False,
467
+ need_t2i_ip_adapter_face: bool = False,
468
+ need_adain_temporal_cond: bool = False,
469
+ resnet_2d_skip_time_act: bool = False,
470
+ need_refer_emb: bool = False,
471
+ ):
472
+ super().__init__()
473
+ resnets = []
474
+ attentions = []
475
+ temp_attentions = []
476
+ temp_convs = []
477
+
478
+ self.has_cross_attention = True
479
+ self.attn_num_head_channels = attn_num_head_channels
480
+ self.need_refer_emb = need_refer_emb
481
+ if need_refer_emb:
482
+ refer_emb_attns = []
483
+ for i in range(num_layers):
484
+ in_channels = in_channels if i == 0 else out_channels
485
+ resnets.append(
486
+ ResnetBlock2D(
487
+ in_channels=in_channels,
488
+ out_channels=out_channels,
489
+ temb_channels=temb_channels,
490
+ eps=resnet_eps,
491
+ groups=resnet_groups,
492
+ dropout=dropout,
493
+ time_embedding_norm=resnet_time_scale_shift,
494
+ non_linearity=resnet_act_fn,
495
+ output_scale_factor=output_scale_factor,
496
+ pre_norm=resnet_pre_norm,
497
+ skip_time_act=resnet_2d_skip_time_act,
498
+ )
499
+ )
500
+ if temporal_conv_block is not None:
501
+ temp_convs.append(
502
+ temporal_conv_block(
503
+ out_channels,
504
+ out_channels,
505
+ dropout=0.1,
506
+ femb_channels=femb_channels,
507
+ )
508
+ )
509
+ else:
510
+ temp_convs.append(None)
511
+ attentions.append(
512
+ Transformer2DModel(
513
+ attn_num_head_channels,
514
+ out_channels // attn_num_head_channels,
515
+ in_channels=out_channels,
516
+ num_layers=1,
517
+ cross_attention_dim=cross_attention_dim,
518
+ norm_num_groups=resnet_groups,
519
+ use_linear_projection=use_linear_projection,
520
+ only_cross_attention=only_cross_attention,
521
+ upcast_attention=upcast_attention,
522
+ cross_attn_temporal_cond=need_t2i_ip_adapter,
523
+ ip_adapter_cross_attn=ip_adapter_cross_attn,
524
+ need_t2i_facein=need_t2i_facein,
525
+ need_t2i_ip_adapter_face=need_t2i_ip_adapter_face,
526
+ )
527
+ )
528
+ if temporal_transformer is not None:
529
+ temp_attention = temporal_transformer(
530
+ attn_num_head_channels,
531
+ out_channels // attn_num_head_channels,
532
+ in_channels=out_channels,
533
+ num_layers=1,
534
+ femb_channels=femb_channels,
535
+ cross_attention_dim=cross_attention_dim,
536
+ norm_num_groups=resnet_groups,
537
+ need_spatial_position_emb=need_spatial_position_emb,
538
+ )
539
+ else:
540
+ temp_attention = None
541
+ temp_attentions.append(temp_attention)
542
+
543
+ if need_refer_emb:
544
+ refer_emb_attns.append(
545
+ ReferEmbFuseAttention(
546
+ query_dim=out_channels,
547
+ heads=attn_num_head_channels,
548
+ dim_head=out_channels // attn_num_head_channels,
549
+ dropout=0,
550
+ bias=False,
551
+ cross_attention_dim=None,
552
+ upcast_attention=False,
553
+ )
554
+ )
555
+
556
+ self.resnets = nn.ModuleList(resnets)
557
+ self.temp_convs = nn.ModuleList(temp_convs)
558
+ self.attentions = nn.ModuleList(attentions)
559
+ self.temp_attentions = nn.ModuleList(temp_attentions)
560
+
561
+ if add_downsample:
562
+ self.downsamplers = nn.ModuleList(
563
+ [
564
+ Downsample2D(
565
+ out_channels,
566
+ use_conv=True,
567
+ out_channels=out_channels,
568
+ padding=downsample_padding,
569
+ name="op",
570
+ )
571
+ ]
572
+ )
573
+ if need_refer_emb:
574
+ refer_emb_attns.append(
575
+ ReferEmbFuseAttention(
576
+ query_dim=out_channels,
577
+ heads=attn_num_head_channels,
578
+ dim_head=out_channels // attn_num_head_channels,
579
+ dropout=0,
580
+ bias=False,
581
+ cross_attention_dim=None,
582
+ upcast_attention=False,
583
+ )
584
+ )
585
+ else:
586
+ self.downsamplers = None
587
+
588
+ self.gradient_checkpointing = False
589
+ self.need_adain_temporal_cond = need_adain_temporal_cond
590
+ if need_refer_emb:
591
+ self.refer_emb_attns = nn.ModuleList(refer_emb_attns)
592
+ logger.debug(f"cross attn downblock 3d need_refer_emb, {self.need_refer_emb}")
593
+
594
+ def forward(
595
+ self,
596
+ hidden_states: torch.FloatTensor,
597
+ temb: Optional[torch.FloatTensor] = None,
598
+ femb: Optional[torch.FloatTensor] = None,
599
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
600
+ attention_mask: Optional[torch.FloatTensor] = None,
601
+ num_frames: int = 1,
602
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
603
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
604
+ sample_index: torch.LongTensor = None,
605
+ vision_conditon_frames_sample_index: torch.LongTensor = None,
606
+ spatial_position_emb: torch.Tensor = None,
607
+ refer_embs: Optional[List[torch.Tensor]] = None,
608
+ refer_self_attn_emb: List[torch.Tensor] = None,
609
+ refer_self_attn_emb_mode: Literal["read", "write"] = "read",
610
+ ):
611
+ # TODO(Patrick, William) - attention mask is not used
612
+ output_states = ()
613
+ for i_downblock, (resnet, temp_conv, attn, temp_attn) in enumerate(
614
+ zip(self.resnets, self.temp_convs, self.attentions, self.temp_attentions)
615
+ ):
616
+ # print("crossattndownblock3d, attn,", type(attn), cross_attention_kwargs)
617
+ if self.training and self.gradient_checkpointing:
618
+ if self.print_idx == 0:
619
+ logger.debug(
620
+ f"self.training and self.gradient_checkpointing={self.training and self.gradient_checkpointing}"
621
+ )
622
+
623
+ def create_custom_forward(module, return_dict=None):
624
+ def custom_forward(*inputs):
625
+ if return_dict is not None:
626
+ return module(*inputs, return_dict=return_dict)
627
+ else:
628
+ return module(*inputs)
629
+
630
+ return custom_forward
631
+
632
+ ckpt_kwargs: Dict[str, Any] = (
633
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
634
+ )
635
+ hidden_states = torch.utils.checkpoint.checkpoint(
636
+ create_custom_forward(resnet),
637
+ hidden_states,
638
+ temb,
639
+ **ckpt_kwargs,
640
+ )
641
+ if self.print_idx == 0:
642
+ logger.debug(f"unet3d after resnet {hidden_states.mean()}")
643
+ if temp_conv is not None:
644
+ hidden_states = torch.utils.checkpoint.checkpoint(
645
+ create_custom_forward(temp_conv),
646
+ hidden_states,
647
+ num_frames,
648
+ sample_index,
649
+ vision_conditon_frames_sample_index,
650
+ femb,
651
+ **ckpt_kwargs,
652
+ )
653
+ hidden_states = torch.utils.checkpoint.checkpoint(
654
+ create_custom_forward(attn, return_dict=False),
655
+ hidden_states,
656
+ encoder_hidden_states,
657
+ None, # timestep
658
+ None, # added_cond_kwargs
659
+ None, # class_labels
660
+ cross_attention_kwargs,
661
+ attention_mask,
662
+ encoder_attention_mask,
663
+ refer_self_attn_emb,
664
+ refer_self_attn_emb_mode,
665
+ **ckpt_kwargs,
666
+ )[0]
667
+ if temp_attn is not None:
668
+ hidden_states = torch.utils.checkpoint.checkpoint(
669
+ create_custom_forward(temp_attn, return_dict=False),
670
+ hidden_states,
671
+ femb,
672
+ # None, # encoder_hidden_states,
673
+ encoder_hidden_states,
674
+ None, # timestep
675
+ None, # class_labels
676
+ num_frames,
677
+ cross_attention_kwargs,
678
+ sample_index,
679
+ vision_conditon_frames_sample_index,
680
+ spatial_position_emb,
681
+ **ckpt_kwargs,
682
+ )[0]
683
+ else:
684
+ hidden_states = resnet(hidden_states, temb)
685
+ if self.print_idx == 0:
686
+ logger.debug(f"unet3d after resnet {hidden_states.mean()}")
687
+ if temp_conv is not None:
688
+ hidden_states = temp_conv(
689
+ hidden_states,
690
+ femb=femb,
691
+ num_frames=num_frames,
692
+ sample_index=sample_index,
693
+ vision_conditon_frames_sample_index=vision_conditon_frames_sample_index,
694
+ )
695
+ hidden_states = attn(
696
+ hidden_states,
697
+ encoder_hidden_states=encoder_hidden_states,
698
+ cross_attention_kwargs=cross_attention_kwargs,
699
+ self_attn_block_embs=refer_self_attn_emb,
700
+ self_attn_block_embs_mode=refer_self_attn_emb_mode,
701
+ ).sample
702
+ if temp_attn is not None:
703
+ hidden_states = temp_attn(
704
+ hidden_states,
705
+ femb=femb,
706
+ num_frames=num_frames,
707
+ encoder_hidden_states=encoder_hidden_states,
708
+ cross_attention_kwargs=cross_attention_kwargs,
709
+ sample_index=sample_index,
710
+ vision_conditon_frames_sample_index=vision_conditon_frames_sample_index,
711
+ spatial_position_emb=spatial_position_emb,
712
+ ).sample
713
+ if (
714
+ self.need_adain_temporal_cond
715
+ and num_frames > 1
716
+ and sample_index is not None
717
+ ):
718
+ if self.print_idx == 0:
719
+ logger.debug(f"adain to vision_condition")
720
+ hidden_states = batch_adain_conditioned_tensor(
721
+ hidden_states,
722
+ num_frames=num_frames,
723
+ need_style_fidelity=False,
724
+ src_index=sample_index,
725
+ dst_index=vision_conditon_frames_sample_index,
726
+ )
727
+ # fuse down_block_refer_emb using attention
728
+ if self.print_idx == 0:
729
+ logger.debug(
730
+ f"downblock, {i_downblock}, self.need_refer_emb={self.need_refer_emb}"
731
+ )
732
+ if self.need_refer_emb and refer_embs is not None:
733
+ if self.print_idx == 0:
734
+ logger.debug(
735
+ f"{i_downblock}, self.refer_emb_attns {refer_embs[i_downblock].shape}"
736
+ )
737
+ hidden_states = self.refer_emb_attns[i_downblock](
738
+ hidden_states, refer_embs[i_downblock], num_frames=num_frames
739
+ )
740
+ else:
741
+ if self.print_idx == 0:
742
+ logger.debug("crossattndownblock refer_emb_attns not applied at this step")
743
+ output_states += (hidden_states,)
744
+
745
+ if self.downsamplers is not None:
746
+ for downsampler in self.downsamplers:
747
+ hidden_states = downsampler(hidden_states)
748
+ if (
749
+ self.need_adain_temporal_cond
750
+ and num_frames > 1
751
+ and sample_index is not None
752
+ ):
753
+ if self.print_idx == 0:
754
+ logger.debug(f"adain to vision_condition")
755
+ hidden_states = batch_adain_conditioned_tensor(
756
+ hidden_states,
757
+ num_frames=num_frames,
758
+ need_style_fidelity=False,
759
+ src_index=sample_index,
760
+ dst_index=vision_conditon_frames_sample_index,
761
+ )
762
+ # fuse down_block_refer_emb using attention
763
+ # TODO: decide the order of adain and refer_emb
764
+ # TODO: decide whether adain should use the first-frame features or refer_emb's
765
+ if self.need_refer_emb and refer_embs is not None:
766
+ i_downblock += 1
767
+ hidden_states = self.refer_emb_attns[i_downblock](
768
+ hidden_states, refer_embs[i_downblock], num_frames=num_frames
769
+ )
770
+ output_states += (hidden_states,)
771
+ self.print_idx += 1
772
+ return hidden_states, output_states
773
+
774
+
775
+ class DownBlock3D(nn.Module):
776
+ print_idx = 0
777
+
778
+ def __init__(
779
+ self,
780
+ in_channels: int,
781
+ out_channels: int,
782
+ temb_channels: int,
783
+ femb_channels: int,
784
+ dropout: float = 0.0,
785
+ num_layers: int = 1,
786
+ resnet_eps: float = 1e-6,
787
+ resnet_time_scale_shift: str = "default",
788
+ resnet_act_fn: str = "swish",
789
+ resnet_groups: int = 32,
790
+ resnet_pre_norm: bool = True,
791
+ output_scale_factor=1.0,
792
+ add_downsample=True,
793
+ downsample_padding=1,
794
+ temporal_conv_block: Union[nn.Module, None] = TemporalConvLayer,
795
+ need_adain_temporal_cond: bool = False,
796
+ resnet_2d_skip_time_act: bool = False,
797
+ need_refer_emb: bool = False,
798
+ attn_num_head_channels: int = 1,
799
+ ):
800
+ super().__init__()
801
+ resnets = []
802
+ temp_convs = []
803
+ self.need_refer_emb = need_refer_emb
804
+ if need_refer_emb:
805
+ refer_emb_attns = []
806
+ self.attn_num_head_channels = attn_num_head_channels
807
+
808
+ for i in range(num_layers):
809
+ in_channels = in_channels if i == 0 else out_channels
810
+ resnets.append(
811
+ ResnetBlock2D(
812
+ in_channels=in_channels,
813
+ out_channels=out_channels,
814
+ temb_channels=temb_channels,
815
+ eps=resnet_eps,
816
+ groups=resnet_groups,
817
+ dropout=dropout,
818
+ time_embedding_norm=resnet_time_scale_shift,
819
+ non_linearity=resnet_act_fn,
820
+ output_scale_factor=output_scale_factor,
821
+ pre_norm=resnet_pre_norm,
822
+ skip_time_act=resnet_2d_skip_time_act,
823
+ )
824
+ )
825
+ if temporal_conv_block is not None:
826
+ temp_convs.append(
827
+ temporal_conv_block(
828
+ out_channels,
829
+ out_channels,
830
+ dropout=0.1,
831
+ femb_channels=femb_channels,
832
+ )
833
+ )
834
+ else:
835
+ temp_convs.append(None)
836
+ if need_refer_emb:
837
+ refer_emb_attns.append(
838
+ ReferEmbFuseAttention(
839
+ query_dim=out_channels,
840
+ heads=attn_num_head_channels,
841
+ dim_head=out_channels // attn_num_head_channels,
842
+ dropout=0,
843
+ bias=False,
844
+ cross_attention_dim=None,
845
+ upcast_attention=False,
846
+ )
847
+ )
848
+
849
+ self.resnets = nn.ModuleList(resnets)
850
+ self.temp_convs = nn.ModuleList(temp_convs)
851
+
852
+ if add_downsample:
853
+ self.downsamplers = nn.ModuleList(
854
+ [
855
+ Downsample2D(
856
+ out_channels,
857
+ use_conv=True,
858
+ out_channels=out_channels,
859
+ padding=downsample_padding,
860
+ name="op",
861
+ )
862
+ ]
863
+ )
864
+ if need_refer_emb:
865
+ refer_emb_attns.append(
866
+ ReferEmbFuseAttention(
867
+ query_dim=out_channels,
868
+ heads=attn_num_head_channels,
869
+ dim_head=out_channels // attn_num_head_channels,
870
+ dropout=0,
871
+ bias=False,
872
+ cross_attention_dim=None,
873
+ upcast_attention=False,
874
+ )
875
+ )
876
+ else:
877
+ self.downsamplers = None
878
+
879
+ self.gradient_checkpointing = False
880
+ self.need_adain_temporal_cond = need_adain_temporal_cond
881
+ if need_refer_emb:
882
+ self.refer_emb_attns = nn.ModuleList(refer_emb_attns)
883
+
884
+ def forward(
885
+ self,
886
+ hidden_states,
887
+ temb=None,
888
+ num_frames=1,
889
+ sample_index: torch.LongTensor = None,
890
+ vision_conditon_frames_sample_index: torch.LongTensor = None,
891
+ spatial_position_emb: torch.Tensor = None,
892
+ femb=None,
893
+ refer_embs: Optional[Tuple[torch.Tensor]] = None,
894
+ refer_self_attn_emb: List[torch.Tensor] = None,
895
+ refer_self_attn_emb_mode: Literal["read", "write"] = "read",
896
+ ):
897
+ output_states = ()
898
+
899
+ for i_downblock, (resnet, temp_conv) in enumerate(
900
+ zip(self.resnets, self.temp_convs)
901
+ ):
902
+ if self.training and self.gradient_checkpointing:
903
+
904
+ def create_custom_forward(module):
905
+ def custom_forward(*inputs):
906
+ return module(*inputs)
907
+
908
+ return custom_forward
909
+
910
+ ckpt_kwargs: Dict[str, Any] = (
911
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
912
+ )
913
+ hidden_states = torch.utils.checkpoint.checkpoint(
914
+ create_custom_forward(resnet),
915
+ hidden_states,
916
+ temb,
917
+ **ckpt_kwargs,
918
+ )
919
+ if temp_conv is not None:
920
+ hidden_states = torch.utils.checkpoint.checkpoint(
921
+ create_custom_forward(temp_conv),
922
+ hidden_states,
923
+ num_frames,
924
+ sample_index,
925
+ vision_conditon_frames_sample_index,
926
+ femb,
927
+ **ckpt_kwargs,
928
+ )
929
+ else:
930
+ hidden_states = resnet(hidden_states, temb)
931
+ if temp_conv is not None:
932
+ hidden_states = temp_conv(
933
+ hidden_states,
934
+ femb=femb,
935
+ num_frames=num_frames,
936
+ sample_index=sample_index,
937
+ vision_conditon_frames_sample_index=vision_conditon_frames_sample_index,
938
+ )
939
+ if (
940
+ self.need_adain_temporal_cond
941
+ and num_frames > 1
942
+ and sample_index is not None
943
+ ):
944
+ if self.print_idx == 0:
945
+ logger.debug(f"adain to vision_condition")
946
+ hidden_states = batch_adain_conditioned_tensor(
947
+ hidden_states,
948
+ num_frames=num_frames,
949
+ need_style_fidelity=False,
950
+ src_index=sample_index,
951
+ dst_index=vision_conditon_frames_sample_index,
952
+ )
953
+ if self.need_refer_emb and refer_embs is not None:
954
+ hidden_states = self.refer_emb_attns[i_downblock](
955
+ hidden_states, refer_embs[i_downblock], num_frames=num_frames
956
+ )
957
+ output_states += (hidden_states,)
958
+
959
+ if self.downsamplers is not None:
960
+ for downsampler in self.downsamplers:
961
+ hidden_states = downsampler(hidden_states)
962
+ if (
963
+ self.need_adain_temporal_cond
964
+ and num_frames > 1
965
+ and sample_index is not None
966
+ ):
967
+ if self.print_idx == 0:
968
+ logger.debug(f"adain to vision_condition")
969
+ hidden_states = batch_adain_conditioned_tensor(
970
+ hidden_states,
971
+ num_frames=num_frames,
972
+ need_style_fidelity=False,
973
+ src_index=sample_index,
974
+ dst_index=vision_conditon_frames_sample_index,
975
+ )
976
+ if self.need_refer_emb and refer_embs is not None:
977
+ i_downblock += 1
978
+ hidden_states = self.refer_emb_attns[i_downblock](
979
+ hidden_states, refer_embs[i_downblock], num_frames=num_frames
980
+ )
981
+ output_states += (hidden_states,)
982
+ self.print_idx += 1
983
+ return hidden_states, output_states
984
+
985
+
986
+ class CrossAttnUpBlock3D(nn.Module):
987
+ print_idx = 0
988
+
989
+ def __init__(
990
+ self,
991
+ in_channels: int,
992
+ out_channels: int,
993
+ prev_output_channel: int,
994
+ temb_channels: int,
995
+ femb_channels: int,
996
+ dropout: float = 0.0,
997
+ num_layers: int = 1,
998
+ resnet_eps: float = 1e-6,
999
+ resnet_time_scale_shift: str = "default",
1000
+ resnet_act_fn: str = "swish",
1001
+ resnet_groups: int = 32,
1002
+ resnet_pre_norm: bool = True,
1003
+ attn_num_head_channels=1,
1004
+ cross_attention_dim=1280,
1005
+ output_scale_factor=1.0,
1006
+ add_upsample=True,
1007
+ dual_cross_attention=False,
1008
+ use_linear_projection=False,
1009
+ only_cross_attention=False,
1010
+ upcast_attention=False,
1011
+ temporal_conv_block: Union[nn.Module, None] = TemporalConvLayer,
1012
+ temporal_transformer: Union[nn.Module, None] = TransformerTemporalModel,
1013
+ need_spatial_position_emb: bool = False,
1014
+ need_t2i_ip_adapter: bool = False,
1015
+ ip_adapter_cross_attn: bool = False,
1016
+ need_t2i_facein: bool = False,
1017
+ need_t2i_ip_adapter_face: bool = False,
1018
+ need_adain_temporal_cond: bool = False,
1019
+ resnet_2d_skip_time_act: bool = False,
1020
+ ):
1021
+ super().__init__()
1022
+ resnets = []
1023
+ temp_convs = []
1024
+ attentions = []
1025
+ temp_attentions = []
1026
+
1027
+ self.has_cross_attention = True
1028
+ self.attn_num_head_channels = attn_num_head_channels
1029
+
1030
+ for i in range(num_layers):
1031
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1032
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1033
+
1034
+ resnets.append(
1035
+ ResnetBlock2D(
1036
+ in_channels=resnet_in_channels + res_skip_channels,
1037
+ out_channels=out_channels,
1038
+ temb_channels=temb_channels,
1039
+ eps=resnet_eps,
1040
+ groups=resnet_groups,
1041
+ dropout=dropout,
1042
+ time_embedding_norm=resnet_time_scale_shift,
1043
+ non_linearity=resnet_act_fn,
1044
+ output_scale_factor=output_scale_factor,
1045
+ pre_norm=resnet_pre_norm,
1046
+ skip_time_act=resnet_2d_skip_time_act,
1047
+ )
1048
+ )
1049
+ if temporal_conv_block is not None:
1050
+ temp_convs.append(
1051
+ temporal_conv_block(
1052
+ out_channels,
1053
+ out_channels,
1054
+ dropout=0.1,
1055
+ femb_channels=femb_channels,
1056
+ )
1057
+ )
1058
+ else:
1059
+ temp_convs.append(None)
1060
+ attentions.append(
1061
+ Transformer2DModel(
1062
+ attn_num_head_channels,
1063
+ out_channels // attn_num_head_channels,
1064
+ in_channels=out_channels,
1065
+ num_layers=1,
1066
+ cross_attention_dim=cross_attention_dim,
1067
+ norm_num_groups=resnet_groups,
1068
+ use_linear_projection=use_linear_projection,
1069
+ only_cross_attention=only_cross_attention,
1070
+ upcast_attention=upcast_attention,
1071
+ cross_attn_temporal_cond=need_t2i_ip_adapter,
1072
+ ip_adapter_cross_attn=ip_adapter_cross_attn,
1073
+ need_t2i_facein=need_t2i_facein,
1074
+ need_t2i_ip_adapter_face=need_t2i_ip_adapter_face,
1075
+ )
1076
+ )
1077
+ if temporal_transformer is not None:
1078
+ temp_attention = temporal_transformer(
1079
+ attn_num_head_channels,
1080
+ out_channels // attn_num_head_channels,
1081
+ in_channels=out_channels,
1082
+ num_layers=1,
1083
+ femb_channels=femb_channels,
1084
+ cross_attention_dim=cross_attention_dim,
1085
+ norm_num_groups=resnet_groups,
1086
+ need_spatial_position_emb=need_spatial_position_emb,
1087
+ )
1088
+ else:
1089
+ temp_attention = None
1090
+ temp_attentions.append(temp_attention)
1091
+ self.resnets = nn.ModuleList(resnets)
1092
+ self.temp_convs = nn.ModuleList(temp_convs)
1093
+ self.attentions = nn.ModuleList(attentions)
1094
+ self.temp_attentions = nn.ModuleList(temp_attentions)
1095
+
1096
+ if add_upsample:
1097
+ self.upsamplers = nn.ModuleList(
1098
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
1099
+ )
1100
+ else:
1101
+ self.upsamplers = None
1102
+
1103
+ self.gradient_checkpointing = False
1104
+ self.need_adain_temporal_cond = need_adain_temporal_cond
1105
+
1106
+ def forward(
1107
+ self,
1108
+ hidden_states: torch.FloatTensor,
1109
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
1110
+ temb: Optional[torch.FloatTensor] = None,
1111
+ femb: Optional[torch.FloatTensor] = None,
1112
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1113
+ num_frames: int = 1,
1114
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1115
+ upsample_size: Optional[int] = None,
1116
+ attention_mask: Optional[torch.FloatTensor] = None,
1117
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1118
+ sample_index: torch.LongTensor = None,
1119
+ vision_conditon_frames_sample_index: torch.LongTensor = None,
1120
+ spatial_position_emb: torch.Tensor = None,
1121
+ refer_self_attn_emb: List[torch.Tensor] = None,
1122
+ refer_self_attn_emb_mode: Literal["read", "write"] = "read",
1123
+ ):
1124
+ for resnet, temp_conv, attn, temp_attn in zip(
1125
+ self.resnets, self.temp_convs, self.attentions, self.temp_attentions
1126
+ ):
1127
+ # pop res hidden states
1128
+ res_hidden_states = res_hidden_states_tuple[-1]
1129
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1130
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1131
+ if self.training and self.gradient_checkpointing:
1132
+
1133
+ def create_custom_forward(module, return_dict=None):
1134
+ def custom_forward(*inputs):
1135
+ if return_dict is not None:
1136
+ return module(*inputs, return_dict=return_dict)
1137
+ else:
1138
+ return module(*inputs)
1139
+
1140
+ return custom_forward
1141
+
1142
+ ckpt_kwargs: Dict[str, Any] = (
1143
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1144
+ )
1145
+ hidden_states = torch.utils.checkpoint.checkpoint(
1146
+ create_custom_forward(resnet),
1147
+ hidden_states,
1148
+ temb,
1149
+ **ckpt_kwargs,
1150
+ )
1151
+ if temp_conv is not None:
1152
+ hidden_states = torch.utils.checkpoint.checkpoint(
1153
+ create_custom_forward(temp_conv),
1154
+ hidden_states,
1155
+ num_frames,
1156
+ sample_index,
1157
+ vision_conditon_frames_sample_index,
1158
+ femb,
1159
+ **ckpt_kwargs,
1160
+ )
1161
+ hidden_states = torch.utils.checkpoint.checkpoint(
1162
+ create_custom_forward(attn, return_dict=False),
1163
+ hidden_states,
1164
+ encoder_hidden_states,
1165
+ None, # timestep
1166
+ None, # added_cond_kwargs
1167
+ None, # class_labels
1168
+ cross_attention_kwargs,
1169
+ attention_mask,
1170
+ encoder_attention_mask,
1171
+ refer_self_attn_emb,
1172
+ refer_self_attn_emb_mode,
1173
+ **ckpt_kwargs,
1174
+ )[0]
1175
+ if temp_attn is not None:
1176
+ hidden_states = torch.utils.checkpoint.checkpoint(
1177
+ create_custom_forward(temp_attn, return_dict=False),
1178
+ hidden_states,
1179
+ femb,
1180
+ # None, # encoder_hidden_states,
1181
+ encoder_hidden_states,
1182
+ None, # timestep
1183
+ None, # class_labels
1184
+ num_frames,
1185
+ cross_attention_kwargs,
1186
+ sample_index,
1187
+ vision_conditon_frames_sample_index,
1188
+ spatial_position_emb,
1189
+ **ckpt_kwargs,
1190
+ )[0]
1191
+ else:
1192
+ hidden_states = resnet(hidden_states, temb)
1193
+ if temp_conv is not None:
1194
+ hidden_states = temp_conv(
1195
+ hidden_states,
1196
+ num_frames=num_frames,
1197
+ femb=femb,
1198
+ sample_index=sample_index,
1199
+ vision_conditon_frames_sample_index=vision_conditon_frames_sample_index,
1200
+ )
1201
+ hidden_states = attn(
1202
+ hidden_states,
1203
+ encoder_hidden_states=encoder_hidden_states,
1204
+ cross_attention_kwargs=cross_attention_kwargs,
1205
+ self_attn_block_embs=refer_self_attn_emb,
1206
+ self_attn_block_embs_mode=refer_self_attn_emb_mode,
1207
+ ).sample
1208
+ if temp_attn is not None:
1209
+ hidden_states = temp_attn(
1210
+ hidden_states,
1211
+ femb=femb,
1212
+ num_frames=num_frames,
1213
+ cross_attention_kwargs=cross_attention_kwargs,
1214
+ encoder_hidden_states=encoder_hidden_states,
1215
+ sample_index=sample_index,
1216
+ vision_conditon_frames_sample_index=vision_conditon_frames_sample_index,
1217
+ spatial_position_emb=spatial_position_emb,
1218
+ ).sample
1219
+ if (
1220
+ self.need_adain_temporal_cond
1221
+ and num_frames > 1
1222
+ and sample_index is not None
1223
+ ):
1224
+ if self.print_idx == 0:
1225
+ logger.debug(f"adain to vision_condition")
1226
+ hidden_states = batch_adain_conditioned_tensor(
1227
+ hidden_states,
1228
+ num_frames=num_frames,
1229
+ need_style_fidelity=False,
1230
+ src_index=sample_index,
1231
+ dst_index=vision_conditon_frames_sample_index,
1232
+ )
1233
+ if self.upsamplers is not None:
1234
+ for upsampler in self.upsamplers:
1235
+ hidden_states = upsampler(hidden_states, upsample_size)
1236
+ if (
1237
+ self.need_adain_temporal_cond
1238
+ and num_frames > 1
1239
+ and sample_index is not None
1240
+ ):
1241
+ if self.print_idx == 0:
1242
+ logger.debug(f"adain to vision_condition")
1243
+ hidden_states = batch_adain_conditioned_tensor(
1244
+ hidden_states,
1245
+ num_frames=num_frames,
1246
+ need_style_fidelity=False,
1247
+ src_index=sample_index,
1248
+ dst_index=vision_conditon_frames_sample_index,
1249
+ )
1250
+ self.print_idx += 1
1251
+ return hidden_states
1252
+
1253
+
1254
+ class UpBlock3D(nn.Module):
1255
+ print_idx = 0
1256
+
1257
+ def __init__(
1258
+ self,
1259
+ in_channels: int,
1260
+ prev_output_channel: int,
1261
+ out_channels: int,
1262
+ temb_channels: int,
1263
+ femb_channels: int,
1264
+ dropout: float = 0.0,
1265
+ num_layers: int = 1,
1266
+ resnet_eps: float = 1e-6,
1267
+ resnet_time_scale_shift: str = "default",
1268
+ resnet_act_fn: str = "swish",
1269
+ resnet_groups: int = 32,
1270
+ resnet_pre_norm: bool = True,
1271
+ output_scale_factor=1.0,
1272
+ add_upsample=True,
1273
+ temporal_conv_block: Union[nn.Module, None] = TemporalConvLayer,
1274
+ need_adain_temporal_cond: bool = False,
1275
+ resnet_2d_skip_time_act: bool = False,
1276
+ ):
1277
+ super().__init__()
1278
+ resnets = []
1279
+ temp_convs = []
1280
+
1281
+ for i in range(num_layers):
1282
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1283
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1284
+
1285
+ resnets.append(
1286
+ ResnetBlock2D(
1287
+ in_channels=resnet_in_channels + res_skip_channels,
1288
+ out_channels=out_channels,
1289
+ temb_channels=temb_channels,
1290
+ eps=resnet_eps,
1291
+ groups=resnet_groups,
1292
+ dropout=dropout,
1293
+ time_embedding_norm=resnet_time_scale_shift,
1294
+ non_linearity=resnet_act_fn,
1295
+ output_scale_factor=output_scale_factor,
1296
+ pre_norm=resnet_pre_norm,
1297
+ skip_time_act=resnet_2d_skip_time_act,
1298
+ )
1299
+ )
1300
+ if temporal_conv_block is not None:
1301
+ temp_convs.append(
1302
+ temporal_conv_block(
1303
+ out_channels,
1304
+ out_channels,
1305
+ dropout=0.1,
1306
+ femb_channels=femb_channels,
1307
+ )
1308
+ )
1309
+ else:
1310
+ temp_convs.append(None)
1311
+ self.resnets = nn.ModuleList(resnets)
1312
+ self.temp_convs = nn.ModuleList(temp_convs)
1313
+
1314
+ if add_upsample:
1315
+ self.upsamplers = nn.ModuleList(
1316
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
1317
+ )
1318
+ else:
1319
+ self.upsamplers = None
1320
+
1321
+ self.gradient_checkpointing = False
1322
+ self.need_adain_temporal_cond = need_adain_temporal_cond
1323
+
1324
+ def forward(
1325
+ self,
1326
+ hidden_states,
1327
+ res_hidden_states_tuple,
1328
+ temb=None,
1329
+ upsample_size=None,
1330
+ num_frames=1,
1331
+ sample_index: torch.LongTensor = None,
1332
+ vision_conditon_frames_sample_index: torch.LongTensor = None,
1333
+ spatial_position_emb: torch.Tensor = None,
1334
+ femb=None,
1335
+ refer_self_attn_emb: List[torch.Tensor] = None,
1336
+ refer_self_attn_emb_mode: Literal["read", "write"] = "read",
1337
+ ):
1338
+ for resnet, temp_conv in zip(self.resnets, self.temp_convs):
1339
+ # pop res hidden states
1340
+ res_hidden_states = res_hidden_states_tuple[-1]
1341
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1342
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1343
+
1344
+ if self.training and self.gradient_checkpointing:
1345
+
1346
+ def create_custom_forward(module):
1347
+ def custom_forward(*inputs):
1348
+ return module(*inputs)
1349
+
1350
+ return custom_forward
1351
+
1352
+ ckpt_kwargs: Dict[str, Any] = (
1353
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1354
+ )
1355
+ hidden_states = torch.utils.checkpoint.checkpoint(
1356
+ create_custom_forward(resnet),
1357
+ hidden_states,
1358
+ temb,
1359
+ **ckpt_kwargs,
1360
+ )
1361
+ if temp_conv is not None:
1362
+ hidden_states = torch.utils.checkpoint.checkpoint(
1363
+ create_custom_forward(temp_conv),
1364
+ hidden_states,
1365
+ num_frames,
1366
+ sample_index,
1367
+ vision_conditon_frames_sample_index,
1368
+ femb,
1369
+ **ckpt_kwargs,
1370
+ )
1371
+ else:
1372
+ hidden_states = resnet(hidden_states, temb)
1373
+ if temp_conv is not None:
1374
+ hidden_states = temp_conv(
1375
+ hidden_states,
1376
+ num_frames=num_frames,
1377
+ femb=femb,
1378
+ sample_index=sample_index,
1379
+ vision_conditon_frames_sample_index=vision_conditon_frames_sample_index,
1380
+ )
1381
+ if (
1382
+ self.need_adain_temporal_cond
1383
+ and num_frames > 1
1384
+ and sample_index is not None
1385
+ ):
1386
+ if self.print_idx == 0:
1387
+ logger.debug(f"adain to vision_condition")
1388
+ hidden_states = batch_adain_conditioned_tensor(
1389
+ hidden_states,
1390
+ num_frames=num_frames,
1391
+ need_style_fidelity=False,
1392
+ src_index=sample_index,
1393
+ dst_index=vision_conditon_frames_sample_index,
1394
+ )
1395
+ if self.upsamplers is not None:
1396
+ for upsampler in self.upsamplers:
1397
+ hidden_states = upsampler(hidden_states, upsample_size)
1398
+ if (
1399
+ self.need_adain_temporal_cond
1400
+ and num_frames > 1
1401
+ and sample_index is not None
1402
+ ):
1403
+ if self.print_idx == 0:
1404
+ logger.debug(f"adain to vision_condition")
1405
+ hidden_states = batch_adain_conditioned_tensor(
1406
+ hidden_states,
1407
+ num_frames=num_frames,
1408
+ need_style_fidelity=False,
1409
+ src_index=sample_index,
1410
+ dst_index=vision_conditon_frames_sample_index,
1411
+ )
1412
+ self.print_idx += 1
1413
+ return hidden_states
musev/models/unet_3d_condition.py ADDED
@@ -0,0 +1,1740 @@
1
+ # Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
2
+ # Copyright 2023 The ModelScope Team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Adapted from https://github.com/huggingface/diffusers/blob/v0.16.1/src/diffusers/models/unet_3d_condition.py
17
+
18
+ # 1. Added from_pretrained, converting the model from 2D blocks to 3D blocks.
20
+
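+ # Rough sketch of the intended usage (path is hypothetical; assumes the overridden
+ # from_pretrained keeps the standard diffusers signature):
+ #     unet = UNet3DConditionModel.from_pretrained("path/to/stable-diffusion/unet")
+ # The 2D block names in the loaded config (e.g. "CrossAttnDownBlock2D") are then
+ # presumably mapped to their 3D counterparts via convert_2D_to_3D defined below.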
21
+ from copy import deepcopy
22
+ from dataclasses import dataclass
23
+ import inspect
24
+ from pprint import pprint, pformat
25
+ from typing import Any, Dict, List, Optional, Tuple, Union, Literal
26
+ import os
27
+ import logging
28
+
29
+ import torch
30
+ import torch.nn as nn
31
+ import torch.utils.checkpoint
32
+ from einops import rearrange, repeat
33
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
34
+ from diffusers.loaders import UNet2DConditionLoadersMixin
35
+ from diffusers.utils import BaseOutput
36
+
37
+ # from diffusers.utils import logging
38
+ from diffusers.models.embeddings import (
39
+ TimestepEmbedding,
40
+ Timesteps,
41
+ )
42
+ from diffusers.models.modeling_utils import ModelMixin, load_state_dict
43
+ from diffusers import __version__
44
+ from diffusers.utils import (
45
+ CONFIG_NAME,
46
+ DIFFUSERS_CACHE,
47
+ FLAX_WEIGHTS_NAME,
48
+ HF_HUB_OFFLINE,
49
+ SAFETENSORS_WEIGHTS_NAME,
50
+ WEIGHTS_NAME,
51
+ _add_variant,
52
+ _get_model_file,
53
+ is_accelerate_available,
54
+ is_torch_version,
55
+ )
56
+ from diffusers.utils.import_utils import _safetensors_available
57
+ from diffusers.models.unet_3d_condition import (
58
+ UNet3DConditionOutput,
59
+ UNet3DConditionModel as DiffusersUNet3DConditionModel,
60
+ )
61
+ from diffusers.models.attention_processor import (
62
+ Attention,
63
+ AttentionProcessor,
64
+ AttnProcessor,
65
+ AttnProcessor2_0,
66
+ XFormersAttnProcessor,
67
+ )
68
+
69
+ from ..models import Model_Register
70
+
71
+ from .resnet import TemporalConvLayer
72
+ from .temporal_transformer import (
73
+ TransformerTemporalModel,
74
+ )
75
+ from .embeddings import get_2d_sincos_pos_embed, resize_spatial_position_emb
76
+ from .unet_3d_blocks import (
77
+ CrossAttnDownBlock3D,
78
+ CrossAttnUpBlock3D,
79
+ DownBlock3D,
80
+ UNetMidBlock3DCrossAttn,
81
+ UpBlock3D,
82
+ get_down_block,
83
+ get_up_block,
84
+ )
85
+ from ..data.data_util import (
86
+ adaptive_instance_normalization,
87
+ align_repeat_tensor_single_dim,
88
+ batch_adain_conditioned_tensor,
89
+ batch_concat_two_tensor_with_index,
90
+ concat_two_tensor,
91
+ concat_two_tensor_with_index,
92
+ )
93
+ from .attention_processor import BaseIPAttnProcessor
94
+ from .attention_processor import ReferEmbFuseAttention
95
+ from .transformer_2d import Transformer2DModel
96
+ from .attention import BasicTransformerBlock
97
+
98
+
99
+ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
100
+
101
+ # if is_torch_version(">=", "1.9.0"):
102
+ # _LOW_CPU_MEM_USAGE_DEFAULT = True
103
+ # else:
104
+ # _LOW_CPU_MEM_USAGE_DEFAULT = False
105
+ _LOW_CPU_MEM_USAGE_DEFAULT = False
106
+
107
+ if is_accelerate_available():
108
+ import accelerate
109
+ from accelerate.utils import set_module_tensor_to_device
110
+ from accelerate.utils.versions import is_torch_version
111
+
112
+
113
+ import safetensors
114
+
115
+
116
+ def hack_t2i_sd_layer_attn_with_ip(
117
+ unet: nn.Module,
118
+ self_attn_class: BaseIPAttnProcessor = None,
119
+ cross_attn_class: BaseIPAttnProcessor = None,
120
+ ):
121
+ attn_procs = {}
122
+ for name in unet.attn_processors.keys():
123
+ if "temp_attentions" in name or "transformer_in" in name:
124
+ continue
125
+ if name.endswith("attn1.processor") and self_attn_class is not None:
126
+ attn_procs[name] = self_attn_class()
127
+ if unet.print_idx == 0:
128
+ logger.debug(
129
+ f"hack attn_processor of {name} to {attn_procs[name].__class__.__name__}"
130
+ )
131
+ elif name.endswith("attn2.processor") and cross_attn_class is not None:
132
+ attn_procs[name] = cross_attn_class()
133
+ if unet.print_idx == 0:
134
+ logger.debug(
135
+ f"hack attn_processor of {name} to {attn_procs[name].__class__.__name__}"
136
+ )
137
+ unet.set_attn_processor(attn_procs, strict=False)
138
+
139
+
140
+ def convert_2D_to_3D(
141
+ module_names,
142
+ valid_modules=(
143
+ "CrossAttnDownBlock2D",
144
+ "CrossAttnUpBlock2D",
145
+ "DownBlock2D",
146
+ "UNetMidBlock2DCrossAttn",
147
+ "UpBlock2D",
148
+ ),
149
+ ):
150
+ if not isinstance(module_names, list):
151
+ return module_names.replace("2D", "3D")
152
+
153
+ return_modules = []
154
+ for module_name in module_names:
155
+ if module_name in valid_modules:
156
+ return_modules.append(module_name.replace("2D", "3D"))
157
+ else:
158
+ return_modules.append(module_name)
159
+ return return_modules
160
+
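+ # For example, following the two branches of convert_2D_to_3D above
+ # ("SomeCustomBlock2D" is an illustrative name, not a real block):
+ #     convert_2D_to_3D(["CrossAttnDownBlock2D", "DownBlock2D", "SomeCustomBlock2D"])
+ #     -> ["CrossAttnDownBlock3D", "DownBlock3D", "SomeCustomBlock2D"]  # only valid_modules renamed
+ #     convert_2D_to_3D("UpBlock2D") -> "UpBlock3D"  # plain strings renamed directly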
161
+
162
+ def insert_spatial_self_attn_idx(unet):
163
+ pass
164
+
165
+
166
+ @dataclass
167
+ class UNet3DConditionOutput(BaseOutput):
168
+ """
169
+ The output of [`UNet3DConditionModel`].
170
+
171
+ Args:
172
+ sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
173
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
174
+ """
175
+
176
+ sample: torch.FloatTensor
177
+
178
+
179
+ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
180
+ r"""
181
+ UNet3DConditionModel is a conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep
182
+ and returns a sample-shaped output.
183
+
184
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
185
+ implements for all the models (such as downloading or saving, etc.)
186
+
187
+ Parameters:
188
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
189
+ Height and width of input/output sample.
190
+ in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
191
+ out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
192
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D")`):
193
+ The tuple of downsample blocks to use.
194
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D")`):
195
+ The tuple of upsample blocks to use.
196
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
197
+ The tuple of output channels for each block.
198
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
199
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
200
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
201
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
202
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
203
+ If `None`, it will skip the normalization and activation layers in post-processing
204
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
205
+ cross_attention_dim (`int`, *optional*, defaults to 1024): The dimension of the cross attention features.
206
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
207
+ """
208
+
209
+ _supports_gradient_checkpointing = True
210
+ print_idx = 0
211
+
212
+ @register_to_config
213
+ def __init__(
214
+ self,
215
+ sample_size: Optional[int] = None,
216
+ in_channels: int = 4,
217
+ out_channels: int = 4,
218
+ down_block_types: Tuple[str] = (
219
+ "CrossAttnDownBlock3D",
220
+ "CrossAttnDownBlock3D",
221
+ "CrossAttnDownBlock3D",
222
+ "DownBlock3D",
223
+ ),
224
+ up_block_types: Tuple[str] = (
225
+ "UpBlock3D",
226
+ "CrossAttnUpBlock3D",
227
+ "CrossAttnUpBlock3D",
228
+ "CrossAttnUpBlock3D",
229
+ ),
230
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
231
+ layers_per_block: int = 2,
232
+ downsample_padding: int = 1,
233
+ mid_block_scale_factor: float = 1,
234
+ act_fn: str = "silu",
235
+ norm_num_groups: Optional[int] = 32,
236
+ norm_eps: float = 1e-5,
237
+ cross_attention_dim: int = 1024,
238
+ attention_head_dim: Union[int, Tuple[int]] = 8,
239
+ temporal_conv_block: str = "TemporalConvLayer",
240
+ temporal_transformer: str = "TransformerTemporalModel",
241
+ need_spatial_position_emb: bool = False,
242
+ need_transformer_in: bool = True,
243
+ need_t2i_ip_adapter: bool = False, # self_attn, t2i.attn1
244
+ need_adain_temporal_cond: bool = False,
245
+ t2i_ip_adapter_attn_processor: str = "NonParamT2ISelfReferenceXFormersAttnProcessor",
246
+ keep_vision_condtion: bool = False,
247
+ use_anivv1_cfg: bool = False,
248
+ resnet_2d_skip_time_act: bool = False,
249
+ need_zero_vis_cond_temb: bool = True,
250
+ norm_spatial_length: bool = False,
251
+ spatial_max_length: int = 2048,
252
+ need_refer_emb: bool = False,
253
+ ip_adapter_cross_attn: bool = False, # cross_attn, t2i.attn2
254
+ t2i_crossattn_ip_adapter_attn_processor: str = "T2IReferencenetIPAdapterXFormersAttnProcessor",
255
+ need_t2i_facein: bool = False,
256
+ need_t2i_ip_adapter_face: bool = False,
257
+ need_vis_cond_mask: bool = False,
258
+ ):
259
+ """_summary_
260
+
261
+ Args:
262
+ sample_size (Optional[int], optional): _description_. Defaults to None.
263
+ in_channels (int, optional): _description_. Defaults to 4.
264
+ out_channels (int, optional): _description_. Defaults to 4.
265
+ down_block_types (Tuple[str], optional): _description_. Defaults to ( "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D", ).
266
+ up_block_types (Tuple[str], optional): _description_. Defaults to ( "UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", ).
267
+ block_out_channels (Tuple[int], optional): _description_. Defaults to (320, 640, 1280, 1280).
268
+ layers_per_block (int, optional): _description_. Defaults to 2.
269
+ downsample_padding (int, optional): _description_. Defaults to 1.
270
+ mid_block_scale_factor (float, optional): _description_. Defaults to 1.
271
+ act_fn (str, optional): _description_. Defaults to "silu".
272
+ norm_num_groups (Optional[int], optional): _description_. Defaults to 32.
273
+ norm_eps (float, optional): _description_. Defaults to 1e-5.
274
+ cross_attention_dim (int, optional): _description_. Defaults to 1024.
275
+ attention_head_dim (Union[int, Tuple[int]], optional): _description_. Defaults to 8.
276
+ temporal_conv_block (str, optional): name of the temporal 3D-conv block class; must be registered in Model_Register. Defaults to "TemporalConvLayer".
277
+ temporal_transformer (str, optional): name of the temporal Transformer block class; must be defined in Model_Register. Defaults to "TransformerTemporalModel".
278
+ need_spatial_position_emb (bool, optional): whether to use a spatial (h, w) positional emb; intended to be used together with thw attention. Defaults to False.
279
+ need_transformer_in (bool, optional): whether to use the first temporal_transformer_block. Defaults to True.
280
+ need_t2i_ip_adapter (bool, optional): whether the T2I modules need attention over the vision-condition frames. Defaults to False.
281
+ need_adain_temporal_cond (bool, optional): whether to apply AdaIN conditioned on the first frame. Defaults to False.
282
+ t2i_ip_adapter_attn_processor (str, optional):
283
+ optimized version of the T2I attn_processor; use together with need_t2i_ip_adapter.
284
+ a "NonParam" prefix denotes the parameter-free ReferenceOnly attention; otherwise the processor carries IpAdapter parameters.
285
+ Defaults to "NonParamT2ISelfReferenceXFormersAttnProcessor".
286
+ keep_vision_condtion (bool, optional): whether to leave the vision-condition frames without a timestep emb. Defaults to False.
287
+ use_anivv1_cfg (bool, optional): whether some basic configuration follows the AnivV design. Defaults to False.
288
+ resnet_2d_skip_time_act (bool, optional): modifies the transformer 2d block; used together with use_anivv1_cfg. Defaults to False.
289
+ need_zero_vis_cond_temb (bool, optional): currently an unused parameter. Defaults to True.
290
+ norm_spatial_length (bool, optional): whether to normalize the spatial length; only effective when need_spatial_position_emb=True. Defaults to False.
291
+ spatial_max_length (int, optional): normalization length. Defaults to 2048.
292
+
293
+ Raises:
294
+ ValueError: _description_
295
+ ValueError: _description_
296
+ ValueError: _description_
297
+ """
298
+ super(UNet3DConditionModel, self).__init__()
299
+ self.keep_vision_condtion = keep_vision_condtion
300
+ self.use_anivv1_cfg = use_anivv1_cfg
301
+ self.sample_size = sample_size
302
+ self.resnet_2d_skip_time_act = resnet_2d_skip_time_act
303
+ self.need_zero_vis_cond_temb = need_zero_vis_cond_temb
304
+ self.norm_spatial_length = norm_spatial_length
305
+ self.spatial_max_length = spatial_max_length
306
+ self.need_refer_emb = need_refer_emb
307
+ self.ip_adapter_cross_attn = ip_adapter_cross_attn
308
+ self.need_t2i_facein = need_t2i_facein
309
+ self.need_t2i_ip_adapter_face = need_t2i_ip_adapter_face
310
+
311
+ logger.debug(f"need_t2i_ip_adapter_face={need_t2i_ip_adapter_face}")
312
+ # Check inputs
313
+ if len(down_block_types) != len(up_block_types):
314
+ raise ValueError(
315
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
316
+ )
317
+
318
+ if len(block_out_channels) != len(down_block_types):
319
+ raise ValueError(
320
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
321
+ )
322
+
323
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(
324
+ down_block_types
325
+ ):
326
+ raise ValueError(
327
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
328
+ )
329
+
330
+ # input
331
+ conv_in_kernel = 3
332
+ conv_out_kernel = 3
333
+ conv_in_padding = (conv_in_kernel - 1) // 2
334
+ self.conv_in = nn.Conv2d(
335
+ in_channels,
336
+ block_out_channels[0],
337
+ kernel_size=conv_in_kernel,
338
+ padding=conv_in_padding,
339
+ )
340
+
341
+ # time
342
+ time_embed_dim = block_out_channels[0] * 4
343
+ self.time_proj = Timesteps(block_out_channels[0], True, 0)
344
+ timestep_input_dim = block_out_channels[0]
345
+
346
+ self.time_embedding = TimestepEmbedding(
347
+ timestep_input_dim,
348
+ time_embed_dim,
349
+ act_fn=act_fn,
350
+ )
351
+ if use_anivv1_cfg:
352
+ self.time_nonlinearity = nn.SiLU()
353
+
354
+ # frame
355
+ frame_embed_dim = block_out_channels[0] * 4
356
+ self.frame_proj = Timesteps(block_out_channels[0], True, 0)
357
+ frame_input_dim = block_out_channels[0]
358
+ if temporal_transformer is not None:
359
+ self.frame_embedding = TimestepEmbedding(
360
+ frame_input_dim,
361
+ frame_embed_dim,
362
+ act_fn=act_fn,
363
+ )
364
+ else:
365
+ self.frame_embedding = None
366
+ if use_anivv1_cfg:
367
+ self.femb_nonlinearity = nn.SiLU()
368
+
369
+ # spatial_position_emb
370
+ self.need_spatial_position_emb = need_spatial_position_emb
371
+ if need_spatial_position_emb:
372
+ self.spatial_position_input_dim = block_out_channels[0] * 2
373
+ self.spatial_position_embed_dim = block_out_channels[0] * 4
374
+
375
+ self.spatial_position_embedding = TimestepEmbedding(
376
+ self.spatial_position_input_dim,
377
+ self.spatial_position_embed_dim,
378
+ act_fn=act_fn,
379
+ )
380
+
381
+ # look up the model classes in the model registry (Model_Register)
382
+ temporal_conv_block = (
383
+ Model_Register[temporal_conv_block]
384
+ if isinstance(temporal_conv_block, str)
385
+ and temporal_conv_block.lower() != "none"
386
+ else None
387
+ )
388
+ self.need_transformer_in = need_transformer_in
389
+
390
+ temporal_transformer = (
391
+ Model_Register[temporal_transformer]
392
+ if isinstance(temporal_transformer, str)
393
+ and temporal_transformer.lower() != "none"
394
+ else None
395
+ )
396
+ self.need_vis_cond_mask = need_vis_cond_mask
397
+
398
+ if need_transformer_in and temporal_transformer is not None:
399
+ self.transformer_in = temporal_transformer(
400
+ num_attention_heads=attention_head_dim,
401
+ attention_head_dim=block_out_channels[0] // attention_head_dim,
402
+ in_channels=block_out_channels[0],
403
+ num_layers=1,
404
+ femb_channels=frame_embed_dim,
405
+ need_spatial_position_emb=need_spatial_position_emb,
406
+ cross_attention_dim=cross_attention_dim,
407
+ )
408
+
409
+ # class embedding
410
+ self.down_blocks = nn.ModuleList([])
411
+ self.up_blocks = nn.ModuleList([])
412
+
413
+ if isinstance(attention_head_dim, int):
414
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
415
+
416
+ self.need_t2i_ip_adapter = need_t2i_ip_adapter
417
+ # decide whether the T2I attention adds the ReferenceOnly or IpAdapter mechanism
418
+ # TODO: a cleaner implementation is still needed
419
+ need_t2i_ip_adapter_param = (
420
+ t2i_ip_adapter_attn_processor is not None
421
+ and "NonParam" not in t2i_ip_adapter_attn_processor
422
+ and need_t2i_ip_adapter
423
+ )
424
+ self.need_adain_temporal_cond = need_adain_temporal_cond
425
+ self.t2i_ip_adapter_attn_processor = t2i_ip_adapter_attn_processor
426
+
427
+ if need_refer_emb:
428
+ self.first_refer_emb_attns = ReferEmbFuseAttention(
429
+ query_dim=block_out_channels[0],
430
+ heads=attention_head_dim[0],
431
+ dim_head=block_out_channels[0] // attention_head_dim[0],
432
+ dropout=0,
433
+ bias=False,
434
+ cross_attention_dim=None,
435
+ upcast_attention=False,
436
+ )
437
+ self.mid_block_refer_emb_attns = ReferEmbFuseAttention(
438
+ query_dim=block_out_channels[-1],
439
+ heads=attention_head_dim[-1],
440
+ dim_head=block_out_channels[-1] // attention_head_dim[-1],
441
+ dropout=0,
442
+ bias=False,
443
+ cross_attention_dim=None,
444
+ upcast_attention=False,
445
+ )
446
+ else:
447
+ self.first_refer_emb_attns = None
448
+ self.mid_block_refer_emb_attns = None
449
+ # down
450
+ output_channel = block_out_channels[0]
451
+ self.layers_per_block = layers_per_block
452
+ self.block_out_channels = block_out_channels
453
+ for i, down_block_type in enumerate(down_block_types):
454
+ input_channel = output_channel
455
+ output_channel = block_out_channels[i]
456
+ is_final_block = i == len(block_out_channels) - 1
457
+
458
+ down_block = get_down_block(
459
+ down_block_type,
460
+ num_layers=layers_per_block,
461
+ in_channels=input_channel,
462
+ out_channels=output_channel,
463
+ temb_channels=time_embed_dim,
464
+ femb_channels=frame_embed_dim,
465
+ add_downsample=not is_final_block,
466
+ resnet_eps=norm_eps,
467
+ resnet_act_fn=act_fn,
468
+ resnet_groups=norm_num_groups,
469
+ cross_attention_dim=cross_attention_dim,
470
+ attn_num_head_channels=attention_head_dim[i],
471
+ downsample_padding=downsample_padding,
472
+ dual_cross_attention=False,
473
+ temporal_conv_block=temporal_conv_block,
474
+ temporal_transformer=temporal_transformer,
475
+ need_spatial_position_emb=need_spatial_position_emb,
476
+ need_t2i_ip_adapter=need_t2i_ip_adapter_param,
477
+ ip_adapter_cross_attn=ip_adapter_cross_attn,
478
+ need_t2i_facein=need_t2i_facein,
479
+ need_t2i_ip_adapter_face=need_t2i_ip_adapter_face,
480
+ need_adain_temporal_cond=need_adain_temporal_cond,
481
+ resnet_2d_skip_time_act=resnet_2d_skip_time_act,
482
+ need_refer_emb=need_refer_emb,
483
+ )
484
+ self.down_blocks.append(down_block)
485
+ # mid
486
+ self.mid_block = UNetMidBlock3DCrossAttn(
487
+ in_channels=block_out_channels[-1],
488
+ temb_channels=time_embed_dim,
489
+ femb_channels=frame_embed_dim,
490
+ resnet_eps=norm_eps,
491
+ resnet_act_fn=act_fn,
492
+ output_scale_factor=mid_block_scale_factor,
493
+ cross_attention_dim=cross_attention_dim,
494
+ attn_num_head_channels=attention_head_dim[-1],
495
+ resnet_groups=norm_num_groups,
496
+ dual_cross_attention=False,
497
+ temporal_conv_block=temporal_conv_block,
498
+ temporal_transformer=temporal_transformer,
499
+ need_spatial_position_emb=need_spatial_position_emb,
500
+ need_t2i_ip_adapter=need_t2i_ip_adapter_param,
501
+ ip_adapter_cross_attn=ip_adapter_cross_attn,
502
+ need_t2i_facein=need_t2i_facein,
503
+ need_t2i_ip_adapter_face=need_t2i_ip_adapter_face,
504
+ need_adain_temporal_cond=need_adain_temporal_cond,
505
+ resnet_2d_skip_time_act=resnet_2d_skip_time_act,
506
+ )
507
+
508
+ # count how many layers upsample the images
509
+ self.num_upsamplers = 0
510
+
511
+ # up
512
+ reversed_block_out_channels = list(reversed(block_out_channels))
513
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
514
+
515
+ output_channel = reversed_block_out_channels[0]
516
+ for i, up_block_type in enumerate(up_block_types):
517
+ is_final_block = i == len(block_out_channels) - 1
518
+
519
+ prev_output_channel = output_channel
520
+ output_channel = reversed_block_out_channels[i]
521
+ input_channel = reversed_block_out_channels[
522
+ min(i + 1, len(block_out_channels) - 1)
523
+ ]
524
+
525
+ # add upsample block for all BUT final layer
526
+ if not is_final_block:
527
+ add_upsample = True
528
+ self.num_upsamplers += 1
529
+ else:
530
+ add_upsample = False
531
+
532
+ up_block = get_up_block(
533
+ up_block_type,
534
+ num_layers=layers_per_block + 1,
535
+ in_channels=input_channel,
536
+ out_channels=output_channel,
537
+ prev_output_channel=prev_output_channel,
538
+ temb_channels=time_embed_dim,
539
+ femb_channels=frame_embed_dim,
540
+ add_upsample=add_upsample,
541
+ resnet_eps=norm_eps,
542
+ resnet_act_fn=act_fn,
543
+ resnet_groups=norm_num_groups,
544
+ cross_attention_dim=cross_attention_dim,
545
+ attn_num_head_channels=reversed_attention_head_dim[i],
546
+ dual_cross_attention=False,
547
+ temporal_conv_block=temporal_conv_block,
548
+ temporal_transformer=temporal_transformer,
549
+ need_spatial_position_emb=need_spatial_position_emb,
550
+ need_t2i_ip_adapter=need_t2i_ip_adapter_param,
551
+ ip_adapter_cross_attn=ip_adapter_cross_attn,
552
+ need_t2i_facein=need_t2i_facein,
553
+ need_t2i_ip_adapter_face=need_t2i_ip_adapter_face,
554
+ need_adain_temporal_cond=need_adain_temporal_cond,
555
+ resnet_2d_skip_time_act=resnet_2d_skip_time_act,
556
+ )
557
+ self.up_blocks.append(up_block)
558
+ prev_output_channel = output_channel
559
+
560
+ # out
561
+ if norm_num_groups is not None:
562
+ self.conv_norm_out = nn.GroupNorm(
563
+ num_channels=block_out_channels[0],
564
+ num_groups=norm_num_groups,
565
+ eps=norm_eps,
566
+ )
567
+ self.conv_act = nn.SiLU()
568
+ else:
569
+ self.conv_norm_out = None
570
+ self.conv_act = None
571
+
572
+ conv_out_padding = (conv_out_kernel - 1) // 2
573
+ self.conv_out = nn.Conv2d(
574
+ block_out_channels[0],
575
+ out_channels,
576
+ kernel_size=conv_out_kernel,
577
+ padding=conv_out_padding,
578
+ )
579
+ self.insert_spatial_self_attn_idx()
580
+
581
+ # hack the attn_processors as needed to enable ip_adapter and related features
582
+ if need_t2i_ip_adapter or ip_adapter_cross_attn:
583
+ hack_t2i_sd_layer_attn_with_ip(
584
+ self,
585
+ self_attn_class=Model_Register[t2i_ip_adapter_attn_processor]
586
+ if t2i_ip_adapter_attn_processor is not None and need_t2i_ip_adapter
587
+ else None,
588
+ cross_attn_class=Model_Register[t2i_crossattn_ip_adapter_attn_processor]
589
+ if t2i_crossattn_ip_adapter_attn_processor is not None
590
+ and (
591
+ ip_adapter_cross_attn or need_t2i_facein or need_t2i_ip_adapter_face
592
+ )
593
+ else None,
594
+ )
595
+ # logger.debug(pformat(self.attn_processors))
596
+
597
+ # a parameter-free AttnProcessor does not need the to_k_ip / to_v_ip weights
598
+ if (
599
+ t2i_ip_adapter_attn_processor is None
600
+ or "NonParam" in t2i_ip_adapter_attn_processor
601
+ ):
602
+ need_t2i_ip_adapter = False
603
+
604
+ if self.print_idx == 0:
605
+ logger.debug("Unet3Model Parameters")
606
+ # logger.debug(pformat(self.__dict__))
607
+
608
+ # skip_refer_downblock_emb is set in set_skip_temporal_layers;
609
+ # when True, the influence of referencenet_block_emb is skipped, mainly for first-frame generation
610
+ self.skip_refer_downblock_emb = False
611
+
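+ # A minimal instantiation sketch (argument values are illustrative; real
+ # configs are usually produced by from_pretrained_2d from an SD checkpoint):
+ #     unet = UNet3DConditionModel(
+ #         sample_size=64,
+ #         cross_attention_dim=1024,
+ #         temporal_transformer="TransformerTemporalModel",
+ #         need_transformer_in=True,
+ #     )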
612
+ @property
613
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
614
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
615
+ r"""
616
+ Returns:
617
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
618
+ indexed by their weight names.
619
+ """
620
+ # set recursively
621
+ processors = {}
622
+
623
+ def fn_recursive_add_processors(
624
+ name: str,
625
+ module: torch.nn.Module,
626
+ processors: Dict[str, AttentionProcessor],
627
+ ):
628
+ if hasattr(module, "set_processor"):
629
+ processors[f"{name}.processor"] = module.processor
630
+
631
+ for sub_name, child in module.named_children():
632
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
633
+
634
+ return processors
635
+
636
+ for name, module in self.named_children():
637
+ fn_recursive_add_processors(name, module, processors)
638
+
639
+ return processors
640
+
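+ # Keys follow module paths, e.g. (illustrative)
+ # "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor"; the
+ # attn1/attn2 suffixes are what hack_t2i_sd_layer_attn_with_ip matches on.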
641
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
642
+ def set_attention_slice(self, slice_size):
643
+ r"""
644
+ Enable sliced attention computation.
645
+
646
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
647
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
648
+
649
+ Args:
650
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
651
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
652
+ `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
653
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
654
+ must be a multiple of `slice_size`.
655
+ """
656
+ sliceable_head_dims = []
657
+
658
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
659
+ if hasattr(module, "set_attention_slice"):
660
+ sliceable_head_dims.append(module.sliceable_head_dim)
661
+
662
+ for child in module.children():
663
+ fn_recursive_retrieve_sliceable_dims(child)
664
+
665
+ # retrieve number of attention layers
666
+ for module in self.children():
667
+ fn_recursive_retrieve_sliceable_dims(module)
668
+
669
+ num_sliceable_layers = len(sliceable_head_dims)
670
+
671
+ if slice_size == "auto":
672
+ # half the attention head size is usually a good trade-off between
673
+ # speed and memory
674
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
675
+ elif slice_size == "max":
676
+ # make smallest slice possible
677
+ slice_size = num_sliceable_layers * [1]
678
+
679
+ slice_size = (
680
+ num_sliceable_layers * [slice_size]
681
+ if not isinstance(slice_size, list)
682
+ else slice_size
683
+ )
684
+
685
+ if len(slice_size) != len(sliceable_head_dims):
686
+ raise ValueError(
687
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
688
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
689
+ )
690
+
691
+ for i in range(len(slice_size)):
692
+ size = slice_size[i]
693
+ dim = sliceable_head_dims[i]
694
+ if size is not None and size > dim:
695
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
696
+
697
+ # Recursively walk through all the children.
698
+ # Any children which exposes the set_attention_slice method
699
+ # gets the message
700
+ def fn_recursive_set_attention_slice(
701
+ module: torch.nn.Module, slice_size: List[int]
702
+ ):
703
+ if hasattr(module, "set_attention_slice"):
704
+ module.set_attention_slice(slice_size.pop())
705
+
706
+ for child in module.children():
707
+ fn_recursive_set_attention_slice(child, slice_size)
708
+
709
+ reversed_slice_size = list(reversed(slice_size))
710
+ for module in self.children():
711
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
712
+
713
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
714
+ def set_attn_processor(
715
+ self,
716
+ processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]],
717
+ strict: bool = True,
718
+ ):
719
+ r"""
720
+ Parameters:
721
+ processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
722
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
723
+ of **all** `Attention` layers.
724
+ In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.
725
+
726
+ """
727
+ count = len(self.attn_processors.keys())
728
+
729
+ if isinstance(processor, dict) and len(processor) != count and strict:
730
+ raise ValueError(
731
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
732
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
733
+ )
734
+
735
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
736
+ if hasattr(module, "set_processor"):
737
+ if not isinstance(processor, dict):
738
+ logger.debug(
739
+ f"module {name} set attn processor {processor.__class__.__name__}"
740
+ )
741
+ module.set_processor(processor)
742
+ else:
743
+ if f"{name}.processor" in processor:
744
+ logger.debug(
745
+ "module {} set attn processor {}".format(
746
+ name, processor[f"{name}.processor"].__class__.__name__
747
+ )
748
+ )
749
+ module.set_processor(processor.pop(f"{name}.processor"))
750
+ else:
751
+ logger.debug(
752
+ f"module {name} has no new target attn_processor, still use {module.processor.__class__.__name__} "
753
+ )
754
+ for sub_name, child in module.named_children():
755
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
756
+
757
+ for name, module in self.named_children():
758
+ fn_recursive_attn_processor(name, module, processor)
759
+
760
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
761
+ def set_default_attn_processor(self):
762
+ """
763
+ Disables custom attention processors and sets the default attention implementation.
764
+ """
765
+ self.set_attn_processor(AttnProcessor())
766
+
767
+ def _set_gradient_checkpointing(self, module, value=False):
768
+ if isinstance(
769
+ module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)
770
+ ):
771
+ module.gradient_checkpointing = value
772
+
773
+ def forward(
774
+ self,
775
+ sample: torch.FloatTensor,
776
+ timestep: Union[torch.Tensor, float, int],
777
+ encoder_hidden_states: torch.Tensor,
778
+ class_labels: Optional[torch.Tensor] = None,
779
+ timestep_cond: Optional[torch.Tensor] = None,
780
+ attention_mask: Optional[torch.Tensor] = None,
781
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
782
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
783
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
784
+ return_dict: bool = True,
785
+ sample_index: torch.LongTensor = None,
786
+ vision_condition_frames_sample: torch.Tensor = None,
787
+ vision_conditon_frames_sample_index: torch.LongTensor = None,
788
+ sample_frame_rate: int = 10,
789
+ skip_temporal_layers: bool = None,
790
+ frame_index: torch.LongTensor = None,
791
+ down_block_refer_embs: Optional[Tuple[torch.Tensor]] = None,
792
+ mid_block_refer_emb: Optional[torch.Tensor] = None,
793
+ refer_self_attn_emb: Optional[List[torch.Tensor]] = None,
794
+ refer_self_attn_emb_mode: Literal["read", "write"] = "read",
795
+ vision_clip_emb: torch.Tensor = None,
796
+ ip_adapter_scale: float = 1.0,
797
+ face_emb: torch.Tensor = None,
798
+ facein_scale: float = 1.0,
799
+ ip_adapter_face_emb: torch.Tensor = None,
800
+ ip_adapter_face_scale: float = 1.0,
801
+ do_classifier_free_guidance: bool = False,
802
+ pose_guider_emb: torch.Tensor = None,
803
+ ) -> Union[UNet3DConditionOutput, Tuple]:
804
+ """_summary_
805
+
806
+ Args:
807
+ sample (torch.FloatTensor): _description_
808
+ timestep (Union[torch.Tensor, float, int]): _description_
809
+ encoder_hidden_states (torch.Tensor): _description_
810
+ class_labels (Optional[torch.Tensor], optional): _description_. Defaults to None.
811
+ timestep_cond (Optional[torch.Tensor], optional): _description_. Defaults to None.
812
+ attention_mask (Optional[torch.Tensor], optional): _description_. Defaults to None.
813
+ cross_attention_kwargs (Optional[Dict[str, Any]], optional): _description_. Defaults to None.
814
+ down_block_additional_residuals (Optional[Tuple[torch.Tensor]], optional): _description_. Defaults to None.
815
+ mid_block_additional_residual (Optional[torch.Tensor], optional): _description_. Defaults to None.
816
+ return_dict (bool, optional): _description_. Defaults to True.
817
+ sample_index (torch.LongTensor, optional): _description_. Defaults to None.
818
+ vision_condition_frames_sample (torch.Tensor, optional): _description_. Defaults to None.
819
+ vision_conditon_frames_sample_index (torch.LongTensor, optional): _description_. Defaults to None.
820
+ sample_frame_rate (int, optional): _description_. Defaults to 10.
821
+ skip_temporal_layers (bool, optional): _description_. Defaults to None.
822
+ frame_index (torch.LongTensor, optional): _description_. Defaults to None.
823
+ up_block_additional_residual (Optional[torch.Tensor], optional): reference latents used by the up blocks. Defaults to None.
824
+ down_block_refer_embs (Optional[torch.Tensor], optional): reference latents used by the down blocks. Defaults to None.
825
+ how_fuse_referencenet_emb (Literal, optional): how to fuse the reference latents. Defaults to ["add", "attn"]="add".
826
+ add: requires additional_latent and latent to share the same h/w; the hw of additional_latent should match that of latent
827
+ attn: concat bt*h1w1*c and bt*h2w2*c into bt*(h1w1+h2w2)*c, then use the result as key/value in attention
828
+ Raises:
829
+ ValueError: _description_
830
+
831
+ Returns:
832
+ Union[UNet3DConditionOutput, Tuple]: _description_
833
+ """
834
+
835
+ if skip_temporal_layers is not None:
836
+ self.set_skip_temporal_layers(skip_temporal_layers)
837
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
838
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
839
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
840
+ # on the fly if necessary.
841
+ default_overall_up_factor = 2**self.num_upsamplers
842
+
843
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
844
+ forward_upsample_size = False
845
+ upsample_size = None
846
+
847
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
848
+ # logger.debug("Forward upsample size to force interpolation output size.")
849
+ forward_upsample_size = True
850
+
851
+ # prepare attention_mask
852
+ if attention_mask is not None:
853
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
854
+ attention_mask = attention_mask.unsqueeze(1)
855
+
856
+ # 1. time
857
+ timesteps = timestep
858
+ if not torch.is_tensor(timesteps):
859
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
860
+ # This would be a good case for the `match` statement (Python 3.10+)
861
+ is_mps = sample.device.type == "mps"
862
+ if isinstance(timestep, float):
863
+ dtype = torch.float32 if is_mps else torch.float64
864
+ else:
865
+ dtype = torch.int32 if is_mps else torch.int64
866
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
867
+ elif len(timesteps.shape) == 0:
868
+ timesteps = timesteps[None].to(sample.device)
869
+
870
+ batch_size = sample.shape[0]
871
+
872
+ # when vision_condition_frames_sample is not None and vision_conditon_frames_sample_index is not None
873
+ # if not None, b c t h w -> b c (t + n_content ) h w
874
+
875
+ if vision_condition_frames_sample is not None:
876
+ sample = batch_concat_two_tensor_with_index(
877
+ sample,
878
+ sample_index,
879
+ vision_condition_frames_sample,
880
+ vision_conditon_frames_sample_index,
881
+ dim=2,
882
+ )
883
+
884
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
885
+ batch_size, channel, num_frames, height, width = sample.shape
886
+
887
+ # prepare the timestep emb
888
+ timesteps = timesteps.expand(sample.shape[0])
889
+ temb = self.time_proj(timesteps)
890
+ temb = temb.to(dtype=self.dtype)
891
+ emb = self.time_embedding(temb, timestep_cond)
892
+ if self.use_anivv1_cfg:
893
+ emb = self.time_nonlinearity(emb)
894
+ emb = emb.repeat_interleave(repeats=num_frames, dim=0)
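+ # (b, d) -> (b*t, d): every frame of a clip shares that clip's timestep emb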
895
+
896
+ # for consistency, zero the timestep emb of the vision-condition frames so they are not affected by it
897
+ # keep consistent with the first frame of vision condition frames
898
+ if (
899
+ self.keep_vision_condtion
900
+ and num_frames > 1
901
+ and sample_index is not None
902
+ and vision_conditon_frames_sample_index is not None
903
+ ):
904
+ emb = rearrange(emb, "(b t) d -> b t d", t=num_frames)
905
+ emb[:, vision_conditon_frames_sample_index, :] = 0
906
+ emb = rearrange(emb, "b t d->(b t) d")
907
+
908
+ # temporal positional embedding
909
+ femb = None
910
+ if self.temporal_transformer is not None:
911
+ if frame_index is None:
912
+ frame_index = torch.arange(
913
+ num_frames, dtype=torch.long, device=sample.device
914
+ )
915
+ if self.use_anivv1_cfg:
916
+ frame_index = (frame_index * sample_frame_rate).to(dtype=torch.long)
917
+ femb = self.frame_proj(frame_index)
918
+ if self.print_idx == 0:
919
+ logger.debug(
920
+ f"unet prepare frame_index, {femb.shape}, {batch_size}"
921
+ )
922
+ femb = repeat(femb, "t d-> b t d", b=batch_size)
923
+ else:
924
+ # b t -> b t d
925
+ assert frame_index.ndim == 2, ValueError(
926
+ "ndim of given frame_index should be 2, but {frame_index.ndim}"
927
+ )
928
+ femb = torch.stack(
929
+ [self.frame_proj(frame_index[i]) for i in range(batch_size)], dim=0
930
+ )
931
+ if self.temporal_transformer is not None:
932
+ femb = femb.to(dtype=self.dtype)
933
+ femb = self.frame_embedding(
934
+ femb,
935
+ )
936
+ if self.use_anivv1_cfg:
937
+ femb = self.femb_nonlinearity(femb)
938
+ if encoder_hidden_states.ndim == 3:
939
+ encoder_hidden_states = align_repeat_tensor_single_dim(
940
+ encoder_hidden_states, target_length=emb.shape[0], dim=0
941
+ )
942
+ elif encoder_hidden_states.ndim == 4:
943
+ encoder_hidden_states = rearrange(
944
+ encoder_hidden_states, "b t n q-> (b t) n q"
945
+ )
946
+ else:
947
+ raise ValueError(
948
+ f"only support ndim in [3, 4], but given {encoder_hidden_states.ndim}"
949
+ )
950
+ if vision_clip_emb is not None:
951
+ if vision_clip_emb.ndim == 4:
952
+ vision_clip_emb = rearrange(vision_clip_emb, "b t n q-> (b t) n q")
953
+ # prepare the spatial (h, w) positional embedding
954
+ # prepare spatial_position_emb
955
+ if self.need_spatial_position_emb:
956
+ # height * width, self.spatial_position_input_dim
957
+ spatial_position_emb = get_2d_sincos_pos_embed(
958
+ embed_dim=self.spatial_position_input_dim,
959
+ grid_size_w=width,
960
+ grid_size_h=height,
961
+ cls_token=False,
962
+ norm_length=self.norm_spatial_length,
963
+ max_length=self.spatial_max_length,
964
+ )
965
+ spatial_position_emb = torch.from_numpy(spatial_position_emb).to(
966
+ device=sample.device, dtype=self.dtype
967
+ )
968
+ # height * width, self.spatial_position_embed_dim
969
+ spatial_position_emb = self.spatial_position_embedding(spatial_position_emb)
970
+ else:
971
+ spatial_position_emb = None
972
+
973
+ # prepare cross_attention_kwargs; the ReferenceOnly/IpAdapter attn_processors need these parameters to split the computation between latents and viscond_latents
974
+ if (
975
+ self.need_t2i_ip_adapter
976
+ or self.ip_adapter_cross_attn
977
+ or self.need_t2i_facein
978
+ or self.need_t2i_ip_adapter_face
979
+ ):
980
+ if cross_attention_kwargs is None:
981
+ cross_attention_kwargs = {}
982
+ cross_attention_kwargs["num_frames"] = num_frames
983
+ cross_attention_kwargs[
984
+ "do_classifier_free_guidance"
985
+ ] = do_classifier_free_guidance
986
+ cross_attention_kwargs["sample_index"] = sample_index
987
+ cross_attention_kwargs[
988
+ "vision_conditon_frames_sample_index"
989
+ ] = vision_conditon_frames_sample_index
990
+ if self.ip_adapter_cross_attn:
991
+ cross_attention_kwargs["vision_clip_emb"] = vision_clip_emb
992
+ cross_attention_kwargs["ip_adapter_scale"] = ip_adapter_scale
993
+ if self.need_t2i_facein:
994
+ if self.print_idx == 0:
995
+ logger.debug(
996
+ f"face_emb={type(face_emb)}, facein_scale={facein_scale}"
997
+ )
998
+ cross_attention_kwargs["face_emb"] = face_emb
999
+ cross_attention_kwargs["facein_scale"] = facein_scale
1000
+ if self.need_t2i_ip_adapter_face:
1001
+ if self.print_idx == 0:
1002
+ logger.debug(
1003
+ f"ip_adapter_face_emb={type(ip_adapter_face_emb)}, ip_adapter_face_scale={ip_adapter_face_scale}"
1004
+ )
1005
+ cross_attention_kwargs["ip_adapter_face_emb"] = ip_adapter_face_emb
1006
+ cross_attention_kwargs["ip_adapter_face_scale"] = ip_adapter_face_scale
1007
+ # 2. pre-process
1008
+ sample = rearrange(sample, "b c t h w -> (b t) c h w")
1009
+ sample = self.conv_in(sample)
1010
+
1011
+ if pose_guider_emb is not None:
1012
+ if self.print_idx == 0:
1013
+ logger.debug(
1014
+ f"sample={sample.shape}, pose_guider_emb={pose_guider_emb.shape}"
1015
+ )
1016
+ sample = sample + pose_guider_emb
1017
+
1018
+ if self.print_idx == 0:
1019
+ logger.debug(f"after conv in sample={sample.mean()}")
1020
+ if spatial_position_emb is not None:
1021
+ if self.print_idx == 0:
1022
+ logger.debug(
1023
+ f"unet3d, transformer_in, spatial_position_emb={spatial_position_emb.shape}"
1024
+ )
1025
+ if self.print_idx == 0:
1026
+ logger.debug(
1027
+ f"unet vision_conditon_frames_sample_index, {type(vision_conditon_frames_sample_index)}",
1028
+ )
1029
+ if vision_conditon_frames_sample_index is not None:
1030
+ if self.print_idx == 0:
1031
+ logger.debug(
1032
+ f"vision_conditon_frames_sample_index shape {vision_conditon_frames_sample_index.shape}",
1033
+ )
1034
+ if self.print_idx == 0:
1035
+ logger.debug(f"unet sample_index {type(sample_index)}")
1036
+ if sample_index is not None:
1037
+ if self.print_idx == 0:
1038
+ logger.debug(f"sample_index shape {sample_index.shape}")
1039
+ if self.need_transformer_in:
1040
+ if self.print_idx == 0:
1041
+ logger.debug(f"unet3d, transformer_in, sample={sample.shape}")
1042
+ sample = self.transformer_in(
1043
+ sample,
1044
+ femb=femb,
1045
+ num_frames=num_frames,
1046
+ cross_attention_kwargs=cross_attention_kwargs,
1047
+ encoder_hidden_states=encoder_hidden_states,
1048
+ sample_index=sample_index,
1049
+ vision_conditon_frames_sample_index=vision_conditon_frames_sample_index,
1050
+ spatial_position_emb=spatial_position_emb,
1051
+ ).sample
1052
+ if (
1053
+ self.need_refer_emb
1054
+ and down_block_refer_embs is not None
1055
+ and not self.skip_refer_downblock_emb
1056
+ ):
1057
+ if self.print_idx == 0:
1058
+ logger.debug(
1059
+ f"self.first_refer_emb_attns, {self.first_refer_emb_attns.__class__.__name__} {down_block_refer_embs[0].shape}"
1060
+ )
1061
+ sample = self.first_refer_emb_attns(
1062
+ sample, down_block_refer_embs[0], num_frames=num_frames
1063
+ )
1064
+ if self.print_idx == 0:
1065
+ logger.debug(
1066
+ f"first_refer_emb_attns, sample is_leaf={sample.is_leaf}, requires_grad={sample.requires_grad}, down_block_refer_embs, {down_block_refer_embs[0].is_leaf}, {down_block_refer_embs[0].requires_grad},"
1067
+ )
1068
+ else:
1069
+ if self.print_idx == 0:
1070
+ logger.debug(f"first_refer_emb_attns, no this step")
1071
+ # convert refer_self_attn_emb to a dict, adding the current index to record
1072
+ # which block each embedding corresponds to
1073
+
1074
+ # 3. down
1075
+ down_block_res_samples = (sample,)
1076
+ for i_down_block, downsample_block in enumerate(self.down_blocks):
1077
+ # fuse refer_emb via attention; here we prepare the refer_emb that
1078
+ # corresponds to this downblock
1079
+ if (
1080
+ not self.need_refer_emb
1081
+ or down_block_refer_embs is None
1082
+ or self.skip_refer_downblock_emb
1083
+ ):
1084
+ this_down_block_refer_embs = None
1085
+ if self.print_idx == 0:
1086
+ logger.debug(
1087
+ f"{i_down_block}, prepare this_down_block_refer_embs, is None"
1088
+ )
1089
+ else:
1090
+ is_final_block = i_down_block == len(self.block_out_channels) - 1
1091
+ num_block = self.layers_per_block + int(not is_final_block)
1092
+ this_downblock_start_idx = 1 + num_block * i_down_block
1093
+ this_down_block_refer_embs = down_block_refer_embs[
1094
+ this_downblock_start_idx : this_downblock_start_idx + num_block
1095
+ ]
1096
+ if self.print_idx == 0:
1097
+ logger.debug(
1098
+ f"prepare this_down_block_refer_embs, {len(this_down_block_refer_embs)}, {this_down_block_refer_embs[0].shape}"
1099
+ )
1100
+ if self.print_idx == 0:
1101
+ logger.debug(f"downsample_block {i_down_block}, sample={sample.mean()}")
1102
+ if (
1103
+ hasattr(downsample_block, "has_cross_attention")
1104
+ and downsample_block.has_cross_attention
1105
+ ):
1106
+ sample, res_samples = downsample_block(
1107
+ hidden_states=sample,
1108
+ temb=emb,
1109
+ femb=femb,
1110
+ encoder_hidden_states=encoder_hidden_states,
1111
+ attention_mask=attention_mask,
1112
+ num_frames=num_frames,
1113
+ cross_attention_kwargs=cross_attention_kwargs,
1114
+ sample_index=sample_index,
1115
+ vision_conditon_frames_sample_index=vision_conditon_frames_sample_index,
1116
+ spatial_position_emb=spatial_position_emb,
1117
+ refer_embs=this_down_block_refer_embs,
1118
+ refer_self_attn_emb=refer_self_attn_emb,
1119
+ refer_self_attn_emb_mode=refer_self_attn_emb_mode,
1120
+ )
1121
+ else:
1122
+ sample, res_samples = downsample_block(
1123
+ hidden_states=sample,
1124
+ temb=emb,
1125
+ femb=femb,
1126
+ num_frames=num_frames,
1127
+ sample_index=sample_index,
1128
+ vision_conditon_frames_sample_index=vision_conditon_frames_sample_index,
1129
+ spatial_position_emb=spatial_position_emb,
1130
+ refer_embs=this_down_block_refer_embs,
1131
+ refer_self_attn_emb=refer_self_attn_emb,
1132
+ refer_self_attn_emb_mode=refer_self_attn_emb_mode,
1133
+ )
1134
+
1135
+ # resize spatial_position_emb
1136
+ if self.need_spatial_position_emb:
1137
+ has_downblock = i_down_block < len(self.down_blocks) - 1
1138
+ if has_downblock:
1139
+ spatial_position_emb = resize_spatial_position_emb(
1140
+ spatial_position_emb,
1141
+ scale=0.5,
1142
+ height=sample.shape[2] * 2,
1143
+ width=sample.shape[3] * 2,
1144
+ )
1145
+ down_block_res_samples += res_samples
1146
+ if down_block_additional_residuals is not None:
1147
+ new_down_block_res_samples = ()
1148
+ for down_block_res_sample, down_block_additional_residual in zip(
1149
+ down_block_res_samples, down_block_additional_residuals
1150
+ ):
1151
+ down_block_res_sample = (
1152
+ down_block_res_sample + down_block_additional_residual
1153
+ )
1154
+ new_down_block_res_samples += (down_block_res_sample,)
1155
+
1156
+ down_block_res_samples = new_down_block_res_samples
1157
+
1158
+ # 4. mid
1159
+ if self.mid_block is not None:
1160
+ sample = self.mid_block(
1161
+ hidden_states=sample,
1162
+ temb=emb,
1163
+ femb=femb,
1164
+ encoder_hidden_states=encoder_hidden_states,
1165
+ attention_mask=attention_mask,
1166
+ num_frames=num_frames,
1167
+ cross_attention_kwargs=cross_attention_kwargs,
1168
+ sample_index=sample_index,
1169
+ vision_conditon_frames_sample_index=vision_conditon_frames_sample_index,
1170
+ spatial_position_emb=spatial_position_emb,
1171
+ refer_self_attn_emb=refer_self_attn_emb,
1172
+ refer_self_attn_emb_mode=refer_self_attn_emb_mode,
1173
+ )
1174
+ # fuse mid_block_refer_emb into the hidden states via attention
1175
+ # (mid_block_refer_emb_attns)
1176
+ if (
1177
+ self.mid_block_refer_emb_attns is not None
1178
+ and mid_block_refer_emb is not None
1179
+ and not self.skip_refer_downblock_emb
1180
+ ):
1181
+ if self.print_idx == 0:
1182
+ logger.debug(
1183
+ f"self.mid_block_refer_emb_attns={self.mid_block_refer_emb_attns}, mid_block_refer_emb={mid_block_refer_emb.shape}"
1184
+ )
1185
+ sample = self.mid_block_refer_emb_attns(
1186
+ sample, mid_block_refer_emb, num_frames=num_frames
1187
+ )
1188
+ if self.print_idx == 0:
1189
+ logger.debug(
1190
+ f"mid_block_refer_emb_attns, sample is_leaf={sample.is_leaf}, requires_grad={sample.requires_grad}, mid_block_refer_emb, {mid_block_refer_emb[0].is_leaf}, {mid_block_refer_emb[0].requires_grad},"
1191
+ )
1192
+ else:
1193
+ if self.print_idx == 0:
1194
+ logger.debug(f"mid_block_refer_emb_attns, no this step")
1195
+ if mid_block_additional_residual is not None:
1196
+ sample = sample + mid_block_additional_residual
1197
+
1198
+ # 5. up
1199
+ for i_up_block, upsample_block in enumerate(self.up_blocks):
1200
+ is_final_block = i_up_block == len(self.up_blocks) - 1
1201
+
1202
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1203
+ down_block_res_samples = down_block_res_samples[
1204
+ : -len(upsample_block.resnets)
1205
+ ]
1206
+
1207
+ # if we have not reached the final block and need to forward the
1208
+ # upsample size, we do it here
1209
+ if not is_final_block and forward_upsample_size:
1210
+ upsample_size = down_block_res_samples[-1].shape[2:]
1211
+
1212
+ if (
1213
+ hasattr(upsample_block, "has_cross_attention")
1214
+ and upsample_block.has_cross_attention
1215
+ ):
1216
+ sample = upsample_block(
1217
+ hidden_states=sample,
1218
+ temb=emb,
1219
+ femb=femb,
1220
+ res_hidden_states_tuple=res_samples,
1221
+ encoder_hidden_states=encoder_hidden_states,
1222
+ upsample_size=upsample_size,
1223
+ attention_mask=attention_mask,
1224
+ num_frames=num_frames,
1225
+ cross_attention_kwargs=cross_attention_kwargs,
1226
+ sample_index=sample_index,
1227
+ vision_conditon_frames_sample_index=vision_conditon_frames_sample_index,
1228
+ spatial_position_emb=spatial_position_emb,
1229
+ refer_self_attn_emb=refer_self_attn_emb,
1230
+ refer_self_attn_emb_mode=refer_self_attn_emb_mode,
1231
+ )
1232
+ else:
1233
+ sample = upsample_block(
1234
+ hidden_states=sample,
1235
+ temb=emb,
1236
+ femb=femb,
1237
+ res_hidden_states_tuple=res_samples,
1238
+ upsample_size=upsample_size,
1239
+ num_frames=num_frames,
1240
+ sample_index=sample_index,
1241
+ vision_conditon_frames_sample_index=vision_conditon_frames_sample_index,
1242
+ spatial_position_emb=spatial_position_emb,
1243
+ refer_self_attn_emb=refer_self_attn_emb,
1244
+ refer_self_attn_emb_mode=refer_self_attn_emb_mode,
1245
+ )
1246
+ # resize spatial_position_emb
1247
+ if self.need_spatial_position_emb:
1248
+ has_upblock = i_up_block < len(self.up_blocks) - 1
1249
+ if has_upblock:
1250
+ spatial_position_emb = resize_spatial_position_emb(
1251
+ spatial_position_emb,
1252
+ scale=2,
1253
+ height=int(sample.shape[2] / 2),
1254
+ width=int(sample.shape[3] / 2),
1255
+ )
1256
+
1257
+ # 6. post-process
1258
+ if self.conv_norm_out:
1259
+ sample = self.conv_norm_out(sample)
1260
+ sample = self.conv_act(sample)
1261
+
1262
+ sample = self.conv_out(sample)
1263
+ sample = rearrange(sample, "(b t) c h w -> b c t h w", t=num_frames)
1264
+
1265
+ # if self.need_adain_temporal_cond and num_frames > 1:
1266
+ # sample = batch_adain_conditioned_tensor(
1267
+ # sample,
1268
+ # num_frames=num_frames,
1269
+ # need_style_fidelity=False,
1270
+ # src_index=sample_index,
1271
+ # dst_index=vision_conditon_frames_sample_index,
1272
+ # )
1273
+ self.print_idx += 1
1274
+
1275
+ if skip_temporal_layers is not None:
1276
+ self.set_skip_temporal_layers(not skip_temporal_layers)
1277
+ if not return_dict:
1278
+ return (sample,)
1279
+ else:
1280
+ return UNet3DConditionOutput(sample=sample)
1281
+
1282
+ # from https://github.com/huggingface/diffusers/blob/v0.16.1/src/diffusers/models/modeling_utils.py#L328
1283
+ @classmethod
1284
+ def from_pretrained_2d(
1285
+ cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs
1286
+ ):
1287
+ r"""
1288
+ Instantiate a pretrained pytorch model from a pre-trained model configuration.
1289
+
1290
+ The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
1291
+ the model, you should first set it back in training mode with `model.train()`.
1292
+
1293
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
1294
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
1295
+ task.
1296
+
1297
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
1298
+ weights are discarded.
1299
+
1300
+ Parameters:
1301
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
1302
+ Can be either:
1303
+
1304
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
1305
+ Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
1306
+ - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
1307
+ `./my_model_directory/`.
1308
+
1309
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
1310
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
1311
+ standard cache should not be used.
1312
+ torch_dtype (`str` or `torch.dtype`, *optional*):
1313
+ Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
1314
+ will be automatically derived from the model's weights.
1315
+ force_download (`bool`, *optional*, defaults to `False`):
1316
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
1317
+ cached versions if they exist.
1318
+ resume_download (`bool`, *optional*, defaults to `False`):
1319
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
1320
+ file exists.
1321
+ proxies (`Dict[str, str]`, *optional*):
1322
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
1323
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
1324
+ output_loading_info(`bool`, *optional*, defaults to `False`):
1325
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
1326
+ local_files_only(`bool`, *optional*, defaults to `False`):
1327
+ Whether or not to only look at local files (i.e., do not try to download the model).
1328
+ use_auth_token (`str` or *bool*, *optional*):
1329
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
1330
+ when running `diffusers-cli login` (stored in `~/.huggingface`).
1331
+ revision (`str`, *optional*, defaults to `"main"`):
1332
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
1333
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
1334
+ identifier allowed by git.
1335
+ from_flax (`bool`, *optional*, defaults to `False`):
1336
+ Load the model weights from a Flax checkpoint save file.
1337
+ subfolder (`str`, *optional*, defaults to `""`):
1338
+ In case the relevant files are located inside a subfolder of the model repo (either remote in
1339
+ huggingface.co or downloaded locally), you can specify the folder name here.
1340
+
1341
+ mirror (`str`, *optional*):
1342
+ Mirror source to accelerate downloads in China. If you are from China and have an accessibility
1343
+ problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
1344
+ Please refer to the mirror site for more information.
1345
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
1346
+ A map that specifies where each submodule should go. It doesn't need to be refined to each
1347
+ parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
1348
+ same device.
1349
+
1350
+ To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
1351
+ more information about each option see [designing a device
1352
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
1353
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
1354
+ Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
1355
+ also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
1356
+ model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
1357
+ setting this argument to `True` will raise an error.
1358
+ variant (`str`, *optional*):
1359
+ If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
1360
+ ignored when using `from_flax`.
1361
+ use_safetensors (`bool`, *optional* ):
1362
+ If set to `True`, the pipeline will forcibly load the models from `safetensors` weights. If set to
1363
+ `None` (the default), the pipeline will load using `safetensors` if safetensors weights are available
1364
+ *and* if `safetensors` is installed. If set to `False`, the pipeline will *not* use `safetensors`.
1365
+
1366
+ <Tip>
1367
+
1368
+ It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
1369
+ models](https://huggingface.co/docs/hub/models-gated#gated-models).
1370
+
1371
+ </Tip>
1372
+
1373
+ <Tip>
1374
+
1375
+ Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
1376
+ this method in a firewalled environment.
1377
+
1378
+ </Tip>
1379
+
1380
+ """
1381
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
1382
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
1383
+ force_download = kwargs.pop("force_download", False)
1384
+ from_flax = kwargs.pop("from_flax", False)
1385
+ resume_download = kwargs.pop("resume_download", False)
1386
+ proxies = kwargs.pop("proxies", None)
1387
+ output_loading_info = kwargs.pop("output_loading_info", False)
1388
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
1389
+ use_auth_token = kwargs.pop("use_auth_token", None)
1390
+ revision = kwargs.pop("revision", None)
1391
+ torch_dtype = kwargs.pop("torch_dtype", None)
1392
+ subfolder = kwargs.pop("subfolder", None)
1393
+ device_map = kwargs.pop("device_map", None)
1394
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
1395
+ variant = kwargs.pop("variant", None)
1396
+ use_safetensors = kwargs.pop("use_safetensors", None)
1397
+ strict = kwargs.pop("strict", True)
1398
+
1399
+ allow_pickle = False
1400
+ if use_safetensors is None:
1401
+ allow_pickle = True
1402
+
1403
+ if low_cpu_mem_usage and not is_accelerate_available():
1404
+ low_cpu_mem_usage = False
1405
+ logger.warning(
1406
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
1407
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
1408
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
1409
+ " install accelerate\n```\n."
1410
+ )
1411
+
1412
+ if device_map is not None and not is_accelerate_available():
1413
+ raise NotImplementedError(
1414
+ "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
1415
+ " `device_map=None`. You can install accelerate with `pip install accelerate`."
1416
+ )
1417
+
1418
+ # Check if we can handle device_map and dispatching the weights
1419
+ if device_map is not None and not is_torch_version(">=", "1.9.0"):
1420
+ raise NotImplementedError(
1421
+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
1422
+ " `device_map=None`."
1423
+ )
1424
+
1425
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
1426
+ raise NotImplementedError(
1427
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
1428
+ " `low_cpu_mem_usage=False`."
1429
+ )
1430
+
1431
+ if low_cpu_mem_usage is False and device_map is not None:
1432
+ raise ValueError(
1433
+ f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
1434
+ " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
1435
+ )
1436
+
1437
+ # Load config if we don't provide a configuration
1438
+ config_path = pretrained_model_name_or_path
1439
+
1440
+ user_agent = {
1441
+ "diffusers": __version__,
1442
+ "file_type": "model",
1443
+ "framework": "pytorch",
1444
+ }
1445
+
1446
+ # load config
1447
+ config, unused_kwargs, commit_hash = cls.load_config(
1448
+ config_path,
1449
+ cache_dir=cache_dir,
1450
+ return_unused_kwargs=True,
1451
+ return_commit_hash=True,
1452
+ force_download=force_download,
1453
+ resume_download=resume_download,
1454
+ proxies=proxies,
1455
+ local_files_only=local_files_only,
1456
+ use_auth_token=use_auth_token,
1457
+ revision=revision,
1458
+ subfolder=subfolder,
1459
+ device_map=device_map,
1460
+ user_agent=user_agent,
1461
+ **kwargs,
1462
+ )
1463
+
1464
+ config["_class_name"] = cls.__name__
1465
+ config["down_block_types"] = convert_2D_to_3D(config["down_block_types"])
1466
+ if "mid_block_type" in config:
1467
+ config["mid_block_type"] = convert_2D_to_3D(config["mid_block_type"])
1468
+ else:
1469
+ config["mid_block_type"] = "UNetMidBlock3DCrossAttn"
1470
+ config["up_block_types"] = convert_2D_to_3D(config["up_block_types"])
1471
+
1472
+ # load model
1473
+ model_file = None
1474
+ if from_flax:
1475
+ model_file = _get_model_file(
1476
+ pretrained_model_name_or_path,
1477
+ weights_name=FLAX_WEIGHTS_NAME,
1478
+ cache_dir=cache_dir,
1479
+ force_download=force_download,
1480
+ resume_download=resume_download,
1481
+ proxies=proxies,
1482
+ local_files_only=local_files_only,
1483
+ use_auth_token=use_auth_token,
1484
+ revision=revision,
1485
+ subfolder=subfolder,
1486
+ user_agent=user_agent,
1487
+ commit_hash=commit_hash,
1488
+ )
1489
+ model = cls.from_config(config, **unused_kwargs)
1490
+
1491
+ # Convert the weights
1492
+ from diffusers.models.modeling_pytorch_flax_utils import (
1493
+ load_flax_checkpoint_in_pytorch_model,
1494
+ )
1495
+
1496
+ model = load_flax_checkpoint_in_pytorch_model(model, model_file)
1497
+ else:
1498
+ try:
1499
+ model_file = _get_model_file(
1500
+ pretrained_model_name_or_path,
1501
+ weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
1502
+ cache_dir=cache_dir,
1503
+ force_download=force_download,
1504
+ resume_download=resume_download,
1505
+ proxies=proxies,
1506
+ local_files_only=local_files_only,
1507
+ use_auth_token=use_auth_token,
1508
+ revision=revision,
1509
+ subfolder=subfolder,
1510
+ user_agent=user_agent,
1511
+ commit_hash=commit_hash,
1512
+ )
1513
+ except IOError as e:
1514
+ if not allow_pickle:
1515
+ raise e
1516
+ pass
1517
+ if model_file is None:
1518
+ model_file = _get_model_file(
1519
+ pretrained_model_name_or_path,
1520
+ weights_name=_add_variant(WEIGHTS_NAME, variant),
1521
+ cache_dir=cache_dir,
1522
+ force_download=force_download,
1523
+ resume_download=resume_download,
1524
+ proxies=proxies,
1525
+ local_files_only=local_files_only,
1526
+ use_auth_token=use_auth_token,
1527
+ revision=revision,
1528
+ subfolder=subfolder,
1529
+ user_agent=user_agent,
1530
+ commit_hash=commit_hash,
1531
+ )
1532
+
1533
+ if low_cpu_mem_usage:
1534
+ # Instantiate model with empty weights
1535
+ with accelerate.init_empty_weights():
1536
+ model = cls.from_config(config, **unused_kwargs)
1537
+
1538
+ # if device_map is None, load the state dict and move the params from meta device to the cpu
1539
+ if device_map is None:
1540
+ param_device = "cpu"
1541
+ state_dict = load_state_dict(model_file, variant=variant)
1542
+ # move the params from meta device to cpu
1543
+ missing_keys = set(model.state_dict().keys()) - set(
1544
+ state_dict.keys()
1545
+ )
1546
+ if len(missing_keys) > 0:
1547
+ if strict:
1548
+ raise ValueError(
1549
+ f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
1550
+ f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
1551
+ " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
1552
+ " those weights or else make sure your checkpoint file is correct."
1553
+ )
1554
+ else:
+ logger.warning(
+ f"model {cls} has no matching pretrained parameter from {pretrained_model_name_or_path} for: {', '.join(missing_keys)}"
+ )
1558
+
1559
+ empty_state_dict = model.state_dict()
1560
+ for param_name, param in state_dict.items():
1561
+ accepts_dtype = "dtype" in set(
1562
+ inspect.signature(
1563
+ set_module_tensor_to_device
1564
+ ).parameters.keys()
1565
+ )
1566
+
1567
+ if empty_state_dict[param_name].shape != param.shape:
+ raise ValueError(
+ f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name].shape}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
+ )
1571
+
1572
+ if accepts_dtype:
1573
+ set_module_tensor_to_device(
1574
+ model,
1575
+ param_name,
1576
+ param_device,
1577
+ value=param,
1578
+ dtype=torch_dtype,
1579
+ )
1580
+ else:
1581
+ set_module_tensor_to_device(
1582
+ model, param_name, param_device, value=param
1583
+ )
1584
+ else: # else let accelerate handle loading and dispatching.
1585
+ # Load weights and dispatch according to the device_map
1586
+ # by default the device_map is None and the weights are loaded on the CPU
1587
+ accelerate.load_checkpoint_and_dispatch(
1588
+ model, model_file, device_map, dtype=torch_dtype
1589
+ )
1590
+
1591
+ loading_info = {
1592
+ "missing_keys": [],
1593
+ "unexpected_keys": [],
1594
+ "mismatched_keys": [],
1595
+ "error_msgs": [],
1596
+ }
1597
+ else:
1598
+ model = cls.from_config(config, **unused_kwargs)
1599
+
1600
+ state_dict = load_state_dict(model_file, variant=variant)
1601
+
1602
+ (
1603
+ model,
1604
+ missing_keys,
1605
+ unexpected_keys,
1606
+ mismatched_keys,
1607
+ error_msgs,
1608
+ ) = cls._load_pretrained_model(
1609
+ model,
1610
+ state_dict,
1611
+ model_file,
1612
+ pretrained_model_name_or_path,
1613
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
1614
+ )
1615
+
1616
+ loading_info = {
1617
+ "missing_keys": missing_keys,
1618
+ "unexpected_keys": unexpected_keys,
1619
+ "mismatched_keys": mismatched_keys,
1620
+ "error_msgs": error_msgs,
1621
+ }
1622
+
1623
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
1624
+ raise ValueError(
1625
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
1626
+ )
1627
+ elif torch_dtype is not None:
1628
+ model = model.to(torch_dtype)
1629
+
1630
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
1631
+
1632
+ # Set model in evaluation mode to deactivate DropOut modules by default
1633
+ model.eval()
1634
+ if output_loading_info:
1635
+ return model, loading_info
1636
+
1637
+ return model
1638
+
1639
+ def set_skip_temporal_layers(
1640
+ self,
1641
+ valid: bool,
1642
+ ) -> None:  # turn the 3D UNet into a 2D UNet
+ # Recursively walk through all the children.
+ # Any child which exposes the skip_temporal_layers parameter gets the message
+
+ # At inference time, refer_image and ip_adapter_image are controlled via call parameters, so this is no longer needed here
1647
+ # if hasattr(self, "skip_refer_downblock_emb"):
1648
+ # self.skip_refer_downblock_emb = valid
1649
+
1650
+ def fn_recursive_set_mem_eff(module: torch.nn.Module):
1651
+ if hasattr(module, "skip_temporal_layers"):
1652
+ module.skip_temporal_layers = valid
1653
+ # if hasattr(module, "skip_refer_downblock_emb"):
1654
+ # module.skip_refer_downblock_emb = valid
1655
+
1656
+ for child in module.children():
1657
+ fn_recursive_set_mem_eff(child)
1658
+
1659
+ for module in self.children():
1660
+ if isinstance(module, torch.nn.Module):
1661
+ fn_recursive_set_mem_eff(module)
1662
+
1663
+ def insert_spatial_self_attn_idx(self):
1664
+ attns, basic_transformers = self.spatial_self_attns
1665
+ self.self_attn_num = len(attns)
1666
+ for i, (name, layer) in enumerate(attns):
1667
+ logger.debug(
1668
+ f"{self.__class__.__name__}, {i}, {name}, {layer.__class__.__name__}"
1669
+ )
1670
+ layer.spatial_self_attn_idx = i
1671
+ for i, (name, layer) in enumerate(basic_transformers):
1672
+ logger.debug(
1673
+ f"{self.__class__.__name__}, {i}, {name}, {layer.__class__.__name__}"
1674
+ )
1675
+ layer.spatial_self_attn_idx = i
1676
+
1677
+ @property
1678
+ def spatial_self_attns(
1679
+ self,
1680
+ ) -> List[Tuple[str, Attention]]:
1681
+ attns, spatial_transformers = self.get_attns(
1682
+ include="attentions", exclude="temp_attentions", attn_name="attn1"
1683
+ )
1684
+ attns = sorted(attns)
1685
+ spatial_transformers = sorted(spatial_transformers)
1686
+ return attns, spatial_transformers
1687
+
1688
+ @property
1689
+ def spatial_cross_attns(
1690
+ self,
1691
+ ) -> List[Tuple[str, Attention]]:
1692
+ attns, spatial_transformers = self.get_attns(
1693
+ include="attentions", exclude="temp_attentions", attn_name="attn2"
1694
+ )
1695
+ attns = sorted(attns)
1696
+ spatial_transformers = sorted(spatial_transformers)
1697
+ return attns, spatial_transformers
1698
+
1699
+ def get_attns(
1700
+ self,
1701
+ attn_name: str,
1702
+ include: str = None,
1703
+ exclude: str = None,
1704
+ ) -> List[Tuple[str, Attention]]:
1705
+ r"""
1706
+ Returns:
1707
+ `dict` of attention attns: A dictionary containing all attention attns used in the model with
1708
+ indexed by its weight name.
1709
+ """
1710
+ # set recursively
1711
+ attns = []
1712
+ spatial_transformers = []
1713
+
1714
+ def fn_recursive_add_attns(
1715
+ name: str,
1716
+ module: torch.nn.Module,
1717
+ attns: List[Tuple[str, Attention]],
1718
+ spatial_transformers: List[Tuple[str, BasicTransformerBlock]],
1719
+ ):
1720
+ is_target = False
1721
+ if isinstance(module, BasicTransformerBlock) and hasattr(module, attn_name):
1722
+ is_target = True
1723
+ if include is not None:
1724
+ is_target = include in name
1725
+ if exclude is not None:
1726
+ is_target = exclude not in name
1727
+ if is_target:
1728
+ attns.append([f"{name}.{attn_name}", getattr(module, attn_name)])
1729
+ spatial_transformers.append([f"{name}", module])
1730
+ for sub_name, child in module.named_children():
1731
+ fn_recursive_add_attns(
1732
+ f"{name}.{sub_name}", child, attns, spatial_transformers
1733
+ )
1734
+
1735
+ return attns
1736
+
1737
+ for name, module in self.named_children():
1738
+ fn_recursive_add_attns(name, module, attns, spatial_transformers)
1739
+
1740
+ return attns, spatial_transformers
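The loader and helper methods above are easiest to see end to end in a small usage sketch. The snippet below is illustrative only (it is not part of the committed files); the checkpoint path is a placeholder, and the call pattern simply mirrors how load_unet in unet_loader.py invokes from_pretrained_2d and how the pipeline toggles set_skip_temporal_layers.

    import torch
    from musev.models.unet_3d_condition import UNet3DConditionModel

    # build a 3D UNet from a 2D SD checkpoint (placeholder path)
    unet = UNet3DConditionModel.from_pretrained_2d(
        "path/to/sd-v1-5",
        subfolder="unet",
        torch_dtype=torch.float16,
    )

    # run it as a plain 2D UNet, e.g. to redraw a single condition frame ...
    unet.set_skip_temporal_layers(True)
    # ... then restore the temporal layers for video denoising
    unet.set_skip_temporal_layers(False)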
musev/models/unet_loader.py ADDED
@@ -0,0 +1,273 @@
1
+ import copy
2
+ from typing import Any, Callable, Dict, Iterable, Union
3
+ import PIL
4
+ import cv2
5
+ import torch
6
+ import argparse
7
+ import datetime
8
+ import logging
9
+ import inspect
10
+ import math
11
+ import os
12
+ import shutil
13
+ from typing import Dict, List, Optional, Tuple
14
+ from pprint import pprint
15
+ from collections import OrderedDict
16
+ from dataclasses import dataclass
17
+ import gc
18
+ import time
19
+
20
+ import numpy as np
21
+ from omegaconf import OmegaConf
22
+ from omegaconf import SCMode
23
+ import torch
24
+ from torch import nn
25
+ import torch.nn.functional as F
26
+ import torch.utils.checkpoint
27
+ from einops import rearrange, repeat
28
+ import pandas as pd
29
+ import h5py
30
+ from diffusers.models.modeling_utils import load_state_dict
31
+ from diffusers.utils import (
32
+ logging,
33
+ )
34
+ from diffusers.utils.import_utils import is_xformers_available
35
+
36
+ from ..models.unet_3d_condition import UNet3DConditionModel
37
+
38
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
+
40
+
41
+ def update_unet_with_sd(
42
+ unet: nn.Module, sd_model: Tuple[str, nn.Module], subfolder: str = "unet"
43
+ ):
44
+ """更新T2V模型中的T2I参数. update t2i parameters in t2v model
45
+
46
+ Args:
47
+ unet (nn.Module): _description_
48
+ sd_model (Tuple[str, nn.Module]): _description_
49
+
50
+ Returns:
51
+ _type_: _description_
52
+ """
53
+ # dtype = unet.dtype
54
+ # TODO: in this way, sd_model_path must be absolute path, to be more dynamic
55
+ if isinstance(sd_model, str):
56
+ if os.path.isdir(sd_model):
57
+ unet_state_dict = load_state_dict(
58
+ os.path.join(sd_model, subfolder, "diffusion_pytorch_model.bin"),
59
+ )
60
+ elif os.path.isfile(sd_model):
61
+ if sd_model.endswith("pth"):
62
+ unet_state_dict = torch.load(sd_model, map_location="cpu")
63
+ print(f"referencenet successful load ={sd_model} with torch.load")
64
+ else:
65
+ try:
66
+ unet_state_dict = load_state_dict(sd_model)
67
+ print(
68
+ f"referencenet successful load with {sd_model} with load_state_dict"
69
+ )
70
+ except Exception as e:
71
+ print(e)
72
+
73
+ elif isinstance(sd_model, nn.Module):
74
+ unet_state_dict = sd_model.state_dict()
75
+ else:
76
+ raise ValueError(f"given {type(sd_model)}, but only support nn.Module or str")
77
+ missing, unexpected = unet.load_state_dict(unet_state_dict, strict=False)
78
+ assert len(unexpected) == 0, f"unet load_state_dict error, unexpected={unexpected}"
79
+ # unet.to(dtype=dtype)
80
+ return unet
81
+
82
+
83
+ def load_unet(
84
+ sd_unet_model: Tuple[str, nn.Module],
85
+ sd_model: Tuple[str, nn.Module] = None,
86
+ cross_attention_dim: int = 768,
87
+ temporal_transformer: str = "TransformerTemporalModel",
88
+ temporal_conv_block: str = "TemporalConvLayer",
89
+ need_spatial_position_emb: bool = False,
90
+ need_transformer_in: bool = True,
91
+ need_t2i_ip_adapter: bool = False,
92
+ need_adain_temporal_cond: bool = False,
93
+ t2i_ip_adapter_attn_processor: str = "IPXFormersAttnProcessor",
94
+ keep_vision_condtion: bool = False,
95
+ use_anivv1_cfg: bool = False,
96
+ resnet_2d_skip_time_act: bool = False,
97
+ dtype: torch.dtype = torch.float16,
98
+ need_zero_vis_cond_temb: bool = True,
99
+ norm_spatial_length: bool = True,
100
+ spatial_max_length: int = 2048,
101
+ need_refer_emb: bool = False,
102
+ ip_adapter_cross_attn=False,
103
+ t2i_crossattn_ip_adapter_attn_processor="T2IReferencenetIPAdapterXFormersAttnProcessor",
104
+ need_t2i_facein: bool = False,
105
+ need_t2i_ip_adapter_face: bool = False,
106
+ strict: bool = True,
107
+ ):
108
+ """通过模型名字 初始化Unet,载入预训练参数. init unet with model_name.
109
+ 该部分都是通过 models.unet_3d_condition.py:UNet3DConditionModel 定义、训练的模型
110
+ model is defined and trained in models.unet_3d_condition.py:UNet3DConditionModel
111
+
112
+ Args:
113
+ sd_unet_model (Tuple[str, nn.Module]): _description_
114
+ sd_model (Tuple[str, nn.Module]): _description_
115
+ cross_attention_dim (int, optional): _description_. Defaults to 768.
116
+ temporal_transformer (str, optional): _description_. Defaults to "TransformerTemporalModel".
117
+ temporal_conv_block (str, optional): _description_. Defaults to "TemporalConvLayer".
118
+ need_spatial_position_emb (bool, optional): _description_. Defaults to False.
119
+ need_transformer_in (bool, optional): _description_. Defaults to True.
120
+ need_t2i_ip_adapter (bool, optional): _description_. Defaults to False.
121
+ need_adain_temporal_cond (bool, optional): _description_. Defaults to False.
122
+ t2i_ip_adapter_attn_processor (str, optional): _description_. Defaults to "IPXFormersAttnProcessor".
123
+ keep_vision_condtion (bool, optional): _description_. Defaults to False.
124
+ use_anivv1_cfg (bool, optional): _description_. Defaults to False.
125
+ resnet_2d_skip_time_act (bool, optional): _description_. Defaults to False.
126
+ dtype (torch.dtype, optional): _description_. Defaults to torch.float16.
127
+ need_zero_vis_cond_temb (bool, optional): _description_. Defaults to True.
128
+ norm_spatial_length (bool, optional): _description_. Defaults to True.
129
+ spatial_max_length (int, optional): _description_. Defaults to 2048.
130
+
131
+ Returns:
132
+ _type_: _description_
133
+ """
134
+ if isinstance(sd_unet_model, str):
135
+ unet = UNet3DConditionModel.from_pretrained_2d(
136
+ sd_unet_model,
137
+ subfolder="unet",
138
+ temporal_transformer=temporal_transformer,
139
+ temporal_conv_block=temporal_conv_block,
140
+ cross_attention_dim=cross_attention_dim,
141
+ need_spatial_position_emb=need_spatial_position_emb,
142
+ need_transformer_in=need_transformer_in,
143
+ need_t2i_ip_adapter=need_t2i_ip_adapter,
144
+ need_adain_temporal_cond=need_adain_temporal_cond,
145
+ t2i_ip_adapter_attn_processor=t2i_ip_adapter_attn_processor,
146
+ keep_vision_condtion=keep_vision_condtion,
147
+ use_anivv1_cfg=use_anivv1_cfg,
148
+ resnet_2d_skip_time_act=resnet_2d_skip_time_act,
149
+ torch_dtype=dtype,
150
+ need_zero_vis_cond_temb=need_zero_vis_cond_temb,
151
+ norm_spatial_length=norm_spatial_length,
152
+ spatial_max_length=spatial_max_length,
153
+ need_refer_emb=need_refer_emb,
154
+ ip_adapter_cross_attn=ip_adapter_cross_attn,
155
+ t2i_crossattn_ip_adapter_attn_processor=t2i_crossattn_ip_adapter_attn_processor,
156
+ need_t2i_facein=need_t2i_facein,
157
+ strict=strict,
158
+ need_t2i_ip_adapter_face=need_t2i_ip_adapter_face,
159
+ )
160
+ elif isinstance(sd_unet_model, nn.Module):
161
+ unet = sd_unet_model
162
+ if sd_model is not None:
163
+ unet = update_unet_with_sd(unet, sd_model)
164
+ return unet
165
+
166
+
167
+ def load_unet_custom_unet(
168
+ sd_unet_model: Tuple[str, nn.Module],
169
+ sd_model: Tuple[str, nn.Module],
170
+ unet_class: nn.Module,
171
+ ):
172
+ """
173
+ 通过模型名字 初始化Unet,载入预训练参数. init unet with model_name.
174
+ 该部分都是通过 不通过models.unet_3d_condition.py:UNet3DConditionModel 定义、训练的模型
175
+ model is not defined in models.unet_3d_condition.py:UNet3DConditionModel
176
+ Args:
177
+ sd_unet_model (Tuple[str, nn.Module]): _description_
178
+ sd_model (Tuple[str, nn.Module]): _description_
179
+ unet_class (nn.Module): _description_
180
+
181
+ Returns:
182
+ _type_: _description_
183
+ """
184
+ if isinstance(sd_unet_model, str):
185
+ unet = unet_class.from_pretrained(
186
+ sd_unet_model,
187
+ subfolder="unet",
188
+ )
189
+ elif isinstance(sd_unet_model, nn.Module):
190
+ unet = sd_unet_model
191
+
192
+ # TODO: in this way, sd_model_path must be absolute path, to be more dynamic
193
+ if isinstance(sd_model, str):
194
+ unet_state_dict = load_state_dict(
195
+ os.path.join(sd_model, "unet/diffusion_pytorch_model.bin"),
196
+ )
197
+ elif isinstance(sd_model, nn.Module):
198
+ unet_state_dict = sd_model.state_dict()
199
+ missing, unexpected = unet.load_state_dict(unet_state_dict, strict=False)
200
+ assert (
+ len(unexpected) == 0
+ ), f"unet load_state_dict error, unexpected={unexpected}"
203
+ return unet
204
+
205
+
206
+ def load_unet_by_name(
207
+ model_name: str,
208
+ sd_unet_model: Tuple[str, nn.Module],
209
+ sd_model: Tuple[str, nn.Module] = None,
210
+ cross_attention_dim: int = 768,
211
+ dtype: torch.dtype = torch.float16,
212
+ need_t2i_facein: bool = False,
213
+ need_t2i_ip_adapter_face: bool = False,
214
+ strict: bool = True,
215
+ ) -> nn.Module:
216
+ """通过模型名字 初始化Unet,载入预训练参数. init unet with model_name.
217
+ 如希望后续通过简单名字就可以使用预训练模型,需要在这里完成定义
218
+ if you want to use pretrained model with simple name, you need to define it here.
219
+ Args:
220
+ model_name (str): _description_
221
+ sd_unet_model (Tuple[str, nn.Module]): _description_
222
+ sd_model (Tuple[str, nn.Module]): _description_
223
+ cross_attention_dim (int, optional): _description_. Defaults to 768.
224
+ dtype (torch.dtype, optional): _description_. Defaults to torch.float16.
225
+
226
+ Raises:
227
+ ValueError: _description_
228
+
229
+ Returns:
230
+ nn.Module: _description_
231
+ """
232
+ if model_name in ["musev"]:
233
+ unet = load_unet(
234
+ sd_unet_model=sd_unet_model,
235
+ sd_model=sd_model,
236
+ need_spatial_position_emb=False,
237
+ cross_attention_dim=cross_attention_dim,
238
+ need_t2i_ip_adapter=True,
239
+ need_adain_temporal_cond=True,
240
+ t2i_ip_adapter_attn_processor="NonParamReferenceIPXFormersAttnProcessor",
241
+ dtype=dtype,
242
+ )
243
+ elif model_name in [
244
+ "musev_referencenet",
245
+ "musev_referencenet_pose",
246
+ ]:
247
+ unet = load_unet(
248
+ sd_unet_model=sd_unet_model,
249
+ sd_model=sd_model,
250
+ cross_attention_dim=cross_attention_dim,
251
+ temporal_conv_block="TemporalConvLayer",
252
+ need_transformer_in=False,
253
+ temporal_transformer="TransformerTemporalModel",
254
+ use_anivv1_cfg=True,
255
+ resnet_2d_skip_time_act=True,
256
+ need_t2i_ip_adapter=True,
257
+ need_adain_temporal_cond=True,
258
+ keep_vision_condtion=True,
259
+ t2i_ip_adapter_attn_processor="NonParamReferenceIPXFormersAttnProcessor",
260
+ dtype=dtype,
261
+ need_refer_emb=True,
262
+ need_zero_vis_cond_temb=True,
263
+ ip_adapter_cross_attn=True,
264
+ t2i_crossattn_ip_adapter_attn_processor="T2IReferencenetIPAdapterXFormersAttnProcessor",
265
+ need_t2i_facein=need_t2i_facein,
266
+ strict=strict,
267
+ need_t2i_ip_adapter_face=need_t2i_ip_adapter_face,
268
+ )
269
+ else:
270
+ raise ValueError(
271
+ f"unsupport model_name={model_name}, only support musev, musev_referencenet, musev_referencenet_pose"
272
+ )
273
+ return unet
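As a quick illustration of how the named configurations above are meant to be consumed, a hypothetical call (not part of the committed files; the checkpoint path is a placeholder) might look like this:

    import torch
    from musev.models.unet_loader import load_unet_by_name

    # "musev_referencenet" selects the referencenet + IP-Adapter configuration defined above
    unet = load_unet_by_name(
        model_name="musev_referencenet",
        sd_unet_model="path/to/musev_referencenet",  # directory expected to contain an "unet" subfolder
        sd_model=None,
        cross_attention_dim=768,
        dtype=torch.float16,
    )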
musev/pipelines/__init__.py ADDED
File without changes
musev/pipelines/context.py ADDED
@@ -0,0 +1,149 @@
1
+ # TODO: Adapted from cli
2
+ import math
3
+ from typing import Callable, List, Optional
4
+
5
+ import numpy as np
6
+
7
+ from mmcm.utils.itertools_util import generate_sample_idxs
8
+
9
+ # copy from https://github.com/MooreThreads/Moore-AnimateAnyone/blob/master/src/pipelines/context.py
10
+
11
+
12
+ def ordered_halving(val):
13
+ bin_str = f"{val:064b}"
14
+ bin_flip = bin_str[::-1]
15
+ as_int = int(bin_flip, 2)
16
+
17
+ return as_int / (1 << 64)
18
+
19
+
20
+ # TODO: closed_loop not work, to fix it
21
+ def uniform(
22
+ step: int = ...,
23
+ num_steps: Optional[int] = None,
24
+ num_frames: int = ...,
25
+ context_size: Optional[int] = None,
26
+ context_stride: int = 3,
27
+ context_overlap: int = 4,
28
+ closed_loop: bool = True,
29
+ ):
30
+ if num_frames <= context_size:
31
+ yield list(range(num_frames))
32
+ return
33
+
34
+ context_stride = min(
35
+ context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1
36
+ )
37
+
38
+ for context_step in 1 << np.arange(context_stride):
39
+ pad = int(round(num_frames * ordered_halving(step)))
40
+ for j in range(
41
+ int(ordered_halving(step) * context_step) + pad,
42
+ num_frames + pad + (0 if closed_loop else -context_overlap),
43
+ (context_size * context_step - context_overlap),
44
+ ):
45
+ yield [
46
+ e % num_frames
47
+ for e in range(j, j + context_size * context_step, context_step)
48
+ ]
49
+
50
+
51
+ def uniform_v2(
52
+ step: int = ...,
53
+ num_steps: Optional[int] = None,
54
+ num_frames: int = ...,
55
+ context_size: Optional[int] = None,
56
+ context_stride: int = 3,
57
+ context_overlap: int = 4,
58
+ closed_loop: bool = True,
59
+ ):
60
+ return generate_sample_idxs(
61
+ total=num_frames,
62
+ window_size=context_size,
63
+ step=context_size - context_overlap,
64
+ sample_rate=1,
65
+ drop_last=False,
66
+ )
67
+
68
+
69
+ def get_context_scheduler(name: str) -> Callable:
70
+ if name == "uniform":
71
+ return uniform
72
+ elif name == "uniform_v2":
73
+ return uniform_v2
74
+ else:
75
+ raise ValueError(f"Unknown context_overlap policy {name}")
76
+
77
+
78
+ def get_total_steps(
79
+ scheduler,
80
+ timesteps: List[int],
81
+ num_steps: Optional[int] = None,
82
+ num_frames: int = ...,
83
+ context_size: Optional[int] = None,
84
+ context_stride: int = 3,
85
+ context_overlap: int = 4,
86
+ closed_loop: bool = True,
87
+ ):
88
+ return sum(
89
+ len(
90
+ list(
91
+ scheduler(
92
+ i,
93
+ num_steps,
94
+ num_frames,
95
+ context_size,
96
+ context_stride,
97
+ context_overlap,
98
+ )
99
+ )
100
+ )
101
+ for i in range(len(timesteps))
102
+ )
103
+
104
+
105
+ def drop_last_repeat_context(contexts: List[List[int]]) -> List[List[int]]:
106
+ """if len(contexts)>=2 and the max value the oenultimate list same as of the last list
107
+
108
+ Args:
109
+ List (_type_): _description_
110
+
111
+ Returns:
112
+ List[List[int]]: _description_
113
+ """
114
+ if len(contexts) >= 2 and contexts[-1][-1] == contexts[-2][-1]:
115
+ return contexts[:-1]
116
+ else:
117
+ return contexts
118
+
119
+
120
+ def prepare_global_context(
121
+ context_schedule: str,
122
+ num_inference_steps: int,
123
+ time_size: int,
124
+ context_frames: int,
125
+ context_stride: int,
126
+ context_overlap: int,
127
+ context_batch_size: int,
128
+ ):
129
+ context_scheduler = get_context_scheduler(context_schedule)
130
+ context_queue = list(
131
+ context_scheduler(
132
+ step=0,
133
+ num_steps=num_inference_steps,
134
+ num_frames=time_size,
135
+ context_size=context_frames,
136
+ context_stride=context_stride,
137
+ context_overlap=context_overlap,
138
+ )
139
+ )
140
+ # remove the last context if the max index of the last context equals the max index of the second-to-last context
+ # (such a window is a redundant item introduced by the step and can be dropped)
142
+ context_queue = drop_last_repeat_context(context_queue)
143
+ num_context_batches = math.ceil(len(context_queue) / context_batch_size)
144
+ global_context = []
145
+ for i_tmp in range(num_context_batches):
146
+ global_context.append(
147
+ context_queue[i_tmp * context_batch_size : (i_tmp + 1) * context_batch_size]
148
+ )
149
+ return global_context
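To make the sliding-window logic above concrete, a hypothetical call (illustrative only, not part of the committed files) could look like this; the exact window boundaries produced by "uniform_v2" depend on mmcm's generate_sample_idxs:

    from musev.pipelines.context import prepare_global_context

    global_context = prepare_global_context(
        context_schedule="uniform_v2",
        num_inference_steps=30,
        time_size=24,        # total frames denoised in this pass
        context_frames=12,   # window size
        context_stride=1,
        context_overlap=4,   # frames shared between neighboring windows
        context_batch_size=1,
    )
    # global_context is a list of batches, each batch a list of frame-index windows,
    # e.g. windows starting roughly at frames 0, 8, 16 for the settings above.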
musev/pipelines/pipeline_controlnet.py ADDED
The diff for this file is too large to render. See raw diff
 
musev/pipelines/pipeline_controlnet_predictor.py ADDED
@@ -0,0 +1,1290 @@
1
+ import copy
2
+ from typing import Any, Callable, Dict, Iterable, Union
3
+ import PIL
4
+ import cv2
5
+ import torch
6
+ import argparse
7
+ import datetime
8
+ import logging
9
+ import inspect
10
+ import math
11
+ import os
12
+ import shutil
13
+ from typing import Dict, List, Optional, Tuple
14
+ from pprint import pformat, pprint
15
+ from collections import OrderedDict
16
+ from dataclasses import dataclass
17
+ import gc
18
+ import time
19
+
20
+ import numpy as np
21
+ from omegaconf import OmegaConf
22
+ from omegaconf import SCMode
23
+ import torch
24
+ from torch import nn
25
+ import torch.nn.functional as F
26
+ import torch.utils.checkpoint
27
+ from einops import rearrange, repeat
28
+ import pandas as pd
29
+ import h5py
30
+ from diffusers.models.autoencoder_kl import AutoencoderKL
31
+
32
+ from diffusers.models.modeling_utils import load_state_dict
33
+ from diffusers.utils import (
+ logging,
+ BaseOutput,
+ )
38
+ from diffusers.utils.dummy_pt_objects import ConsistencyDecoderVAE
39
+ from diffusers.utils.import_utils import is_xformers_available
40
+
41
+ from mmcm.utils.seed_util import set_all_seed
42
+ from mmcm.vision.data.video_dataset import DecordVideoDataset
43
+ from mmcm.vision.process.correct_color import hist_match_video_bcthw
44
+ from mmcm.vision.process.image_process import (
45
+ batch_dynamic_crop_resize_images,
46
+ batch_dynamic_crop_resize_images_v2,
47
+ )
48
+ from mmcm.vision.utils.data_type_util import is_video
49
+ from mmcm.vision.feature_extractor.controlnet import load_controlnet_model
50
+
51
+ from ..schedulers import (
52
+ EulerDiscreteScheduler,
53
+ LCMScheduler,
54
+ DDIMScheduler,
55
+ DDPMScheduler,
56
+ )
57
+ from ..models.unet_3d_condition import UNet3DConditionModel
58
+ from .pipeline_controlnet import (
59
+ MusevControlNetPipeline,
60
+ VideoPipelineOutput as PipelineVideoPipelineOutput,
61
+ )
62
+ from ..utils.util import save_videos_grid_with_opencv
63
+ from ..utils.model_util import (
64
+ update_pipeline_basemodel,
65
+ update_pipeline_lora_model,
66
+ update_pipeline_lora_models,
67
+ update_pipeline_model_parameters,
68
+ )
69
+
70
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
71
+
72
+
73
+ @dataclass
74
+ class VideoPipelineOutput(BaseOutput):
75
+ videos: Union[torch.Tensor, np.ndarray]
76
+ latents: Union[torch.Tensor, np.ndarray]
77
+ videos_mid: Union[torch.Tensor, np.ndarray]
78
+ controlnet_cond: Union[torch.Tensor, np.ndarray]
79
+ generated_videos: Union[torch.Tensor, np.ndarray]
80
+
81
+
82
+ def update_controlnet_processor_params(
83
+ src: Union[Dict, List[Dict]], dst: Union[Dict, List[Dict]]
84
+ ):
85
+ """merge dst into src"""
86
+ if isinstance(src, list) and not isinstance(dst, List):
87
+ dst = [dst] * len(src)
88
+ if isinstance(src, list) and isinstance(dst, list):
89
+ return [
90
+ update_controlnet_processor_params(src[i], dst[i]) for i in range(len(src))
91
+ ]
92
+ if src is None:
93
+ dct = {}
94
+ else:
95
+ dct = copy.deepcopy(src)
96
+ if dst is None:
97
+ dst = {}
98
+ dct.update(dst)
99
+ return dct
100
+
101
+
102
+ class DiffusersPipelinePredictor(object):
103
+ """wraper of diffusers pipeline, support generation function interface. support
104
+ 1. text2video: inputs include text, image(optional), refer_image(optional)
105
+ 2. video2video:
106
+ 1. use controlnet to control spatial
107
+ 2. or use video fuse noise to denoise
108
+ """
109
+
110
+ def __init__(
111
+ self,
112
+ sd_model_path: str,
113
+ unet: nn.Module,
114
+ controlnet_name: Union[str, List[str]] = None,
115
+ controlnet: nn.Module = None,
116
+ lora_dict: Dict[str, Dict] = None,
117
+ requires_safety_checker: bool = False,
118
+ device: str = "cuda",
119
+ dtype: torch.dtype = torch.float16,
120
+ # controlnet parameters start
121
+ need_controlnet_processor: bool = True,
122
+ need_controlnet: bool = True,
123
+ image_resolution: int = 512,
124
+ detect_resolution: int = 512,
125
+ include_body: bool = True,
126
+ hand_and_face: bool = None,
127
+ include_face: bool = False,
128
+ include_hand: bool = True,
129
+ negative_embedding: List = None,
130
+ # controlnet parameters end
131
+ enable_xformers_memory_efficient_attention: bool = True,
132
+ lcm_lora_dct: Dict = None,
133
+ referencenet: nn.Module = None,
134
+ ip_adapter_image_proj: nn.Module = None,
135
+ vision_clip_extractor: nn.Module = None,
136
+ face_emb_extractor: nn.Module = None,
137
+ facein_image_proj: nn.Module = None,
138
+ ip_adapter_face_emb_extractor: nn.Module = None,
139
+ ip_adapter_face_image_proj: nn.Module = None,
140
+ vae_model: Optional[Tuple[nn.Module, str]] = None,
141
+ pose_guider: Optional[nn.Module] = None,
142
+ enable_zero_snr: bool = False,
143
+ ) -> None:
144
+ self.sd_model_path = sd_model_path
145
+ self.unet = unet
146
+ self.controlnet_name = controlnet_name
147
+ self.controlnet = controlnet
148
+ self.requires_safety_checker = requires_safety_checker
149
+ self.device = device
150
+ self.dtype = dtype
151
+ self.need_controlnet_processor = need_controlnet_processor
152
+ self.need_controlnet = need_controlnet
153
+ self.need_controlnet_processor = need_controlnet_processor
154
+ self.image_resolution = image_resolution
155
+ self.detect_resolution = detect_resolution
156
+ self.include_body = include_body
157
+ self.hand_and_face = hand_and_face
158
+ self.include_face = include_face
159
+ self.include_hand = include_hand
160
+ self.negative_embedding = negative_embedding
161
+ self.device = device
162
+ self.dtype = dtype
163
+ self.lcm_lora_dct = lcm_lora_dct
164
+ if controlnet is None and controlnet_name is not None:
165
+ controlnet, controlnet_processor, processor_params = load_controlnet_model(
166
+ controlnet_name,
167
+ device=device,
168
+ dtype=dtype,
169
+ need_controlnet_processor=need_controlnet_processor,
170
+ need_controlnet=need_controlnet,
171
+ image_resolution=image_resolution,
172
+ detect_resolution=detect_resolution,
173
+ include_body=include_body,
174
+ include_face=include_face,
175
+ hand_and_face=hand_and_face,
176
+ include_hand=include_hand,
177
+ )
178
+ self.controlnet_processor = controlnet_processor
179
+ self.controlnet_processor_params = processor_params
180
+ logger.debug(f"init controlnet controlnet_name={controlnet_name}")
181
+
182
+ if controlnet is not None:
183
+ controlnet = controlnet.to(device=device, dtype=dtype)
184
+ controlnet.eval()
185
+ if pose_guider is not None:
186
+ pose_guider = pose_guider.to(device=device, dtype=dtype)
187
+ pose_guider.eval()
188
+ unet.to(device=device, dtype=dtype)
189
+ unet.eval()
190
+ if referencenet is not None:
191
+ referencenet.to(device=device, dtype=dtype)
192
+ referencenet.eval()
193
+ if ip_adapter_image_proj is not None:
194
+ ip_adapter_image_proj.to(device=device, dtype=dtype)
195
+ ip_adapter_image_proj.eval()
196
+ if vision_clip_extractor is not None:
197
+ vision_clip_extractor.to(device=device, dtype=dtype)
198
+ vision_clip_extractor.eval()
199
+ if face_emb_extractor is not None:
200
+ face_emb_extractor.to(device=device, dtype=dtype)
201
+ face_emb_extractor.eval()
202
+ if facein_image_proj is not None:
203
+ facein_image_proj.to(device=device, dtype=dtype)
204
+ facein_image_proj.eval()
205
+
206
+ if isinstance(vae_model, str):
207
+ # TODO: poor implementation, to improve
208
+ if "consistency" in vae_model:
209
+ vae = ConsistencyDecoderVAE.from_pretrained(vae_model)
210
+ else:
211
+ vae = AutoencoderKL.from_pretrained(vae_model)
212
+ elif isinstance(vae_model, nn.Module):
213
+ vae = vae_model
214
+ else:
215
+ vae = None
216
+ if vae is not None:
217
+ vae.to(device=device, dtype=dtype)
218
+ vae.eval()
219
+ if ip_adapter_face_emb_extractor is not None:
220
+ ip_adapter_face_emb_extractor.to(device=device, dtype=dtype)
221
+ ip_adapter_face_emb_extractor.eval()
222
+ if ip_adapter_face_image_proj is not None:
223
+ ip_adapter_face_image_proj.to(device=device, dtype=dtype)
224
+ ip_adapter_face_image_proj.eval()
225
+ params = {
226
+ "pretrained_model_name_or_path": sd_model_path,
227
+ "controlnet": controlnet,
228
+ "unet": unet,
229
+ "requires_safety_checker": requires_safety_checker,
230
+ "torch_dtype": dtype,
231
+ "torch_device": device,
232
+ "referencenet": referencenet,
233
+ "ip_adapter_image_proj": ip_adapter_image_proj,
234
+ "vision_clip_extractor": vision_clip_extractor,
235
+ "facein_image_proj": facein_image_proj,
236
+ "face_emb_extractor": face_emb_extractor,
237
+ "ip_adapter_face_emb_extractor": ip_adapter_face_emb_extractor,
238
+ "ip_adapter_face_image_proj": ip_adapter_face_image_proj,
239
+ "pose_guider": pose_guider,
240
+ }
241
+ if vae is not None:
242
+ params["vae"] = vae
243
+ pipeline = MusevControlNetPipeline.from_pretrained(**params)
244
+ pipeline = pipeline.to(torch_device=device, torch_dtype=dtype)
245
+ logger.debug(
246
+ f"init pipeline from sd_model_path={sd_model_path}, device={device}, dtype={dtype}"
247
+ )
248
+ if (
249
+ negative_embedding is not None
250
+ and pipeline.text_encoder is not None
251
+ and pipeline.tokenizer is not None
252
+ ):
253
+ for neg_emb_path, neg_token in negative_embedding:
254
+ pipeline.load_textual_inversion(neg_emb_path, token=neg_token)
255
+
256
+ # pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
257
+ # pipe.enable_model_cpu_offload()
258
+ if not enable_zero_snr:
259
+ pipeline.scheduler = EulerDiscreteScheduler.from_config(
260
+ pipeline.scheduler.config
261
+ )
262
+ # pipeline.scheduler = DDIMScheduler.from_config(
263
+ # pipeline.scheduler.config,
264
+ # this part affects the brightness of the generated video; not suitable for generation conditioned on a given first frame
+ # rescale_betas_zero_snr affects the brightness of the generated video, not suitable for vision condition images mode
267
+ # # rescale_betas_zero_snr=True,
268
+ # )
269
+ # pipeline.scheduler = DDPMScheduler.from_config(pipeline.scheduler.config)
270
+ else:
271
+ # moore scheduler, just for codetest
272
+ pipeline.scheduler = DDIMScheduler(
273
+ beta_start=0.00085,
274
+ beta_end=0.012,
275
+ beta_schedule="linear",
276
+ clip_sample=False,
277
+ steps_offset=1,
278
+ ### Zero-SNR params
279
+ prediction_type="v_prediction",
280
+ rescale_betas_zero_snr=True,
281
+ timestep_spacing="trailing",
282
+ )
283
+
284
+ pipeline.enable_vae_slicing()
285
+ self.enable_xformers_memory_efficient_attention = (
286
+ enable_xformers_memory_efficient_attention
287
+ )
288
+ if enable_xformers_memory_efficient_attention:
289
+ if is_xformers_available():
290
+ pipeline.enable_xformers_memory_efficient_attention()
291
+ else:
292
+ raise ValueError(
293
+ "xformers is not available. Make sure it is installed correctly"
294
+ )
295
+ self.pipeline = pipeline
296
+ self.unload_dict = [] # keep lora state
297
+ if lora_dict is not None:
298
+ self.load_lora(lora_dict=lora_dict)
299
+ logger.debug("load lora {}".format(" ".join(list(lora_dict.keys()))))
300
+
301
+ if lcm_lora_dct is not None:
302
+ self.pipeline.scheduler = LCMScheduler.from_config(
303
+ self.pipeline.scheduler.config
304
+ )
305
+ self.load_lora(lora_dict=lcm_lora_dct)
306
+ logger.debug("load lcm lora {}".format(" ".join(list(lcm_lora_dct.keys()))))
307
+
308
+ # logger.debug("Unet3Model Parameters")
309
+ # logger.debug(pformat(self.__dict__))
310
+
311
+ def load_lora(
312
+ self,
313
+ lora_dict: Dict[str, Dict],
314
+ ):
315
+ self.pipeline, unload_dict = update_pipeline_lora_models(
316
+ self.pipeline, lora_dict, device=self.device
317
+ )
318
+ self.unload_dict += unload_dict
319
+
320
+ def unload_lora(self):
321
+ for layer_data in self.unload_dict:
322
+ layer = layer_data["layer"]
323
+ added_weight = layer_data["added_weight"]
324
+ layer.weight.data -= added_weight
325
+ self.unload_dict = []
326
+ gc.collect()
327
+ torch.cuda.empty_cache()
328
+
329
+ def update_unet(self, unet: nn.Module):
330
+ self.pipeline.unet = unet.to(device=self.device, dtype=self.dtype)
331
+
332
+ def update_sd_model(self, model_path: str, text_model_path: str):
333
+ self.pipeline = update_pipeline_basemodel(
334
+ self.pipeline,
335
+ model_path,
336
+ text_sd_model_path=text_model_path,
337
+ device=self.device,
338
+ )
339
+
340
+ def update_sd_model_and_unet(
341
+ self, lora_sd_path: str, lora_path: str, sd_model_path: str = None
342
+ ):
343
+ self.pipeline = update_pipeline_model_parameters(
344
+ self.pipeline,
345
+ model_path=lora_sd_path,
346
+ lora_path=lora_path,
347
+ text_model_path=sd_model_path,
348
+ device=self.device,
349
+ )
350
+
351
+ def update_controlnet(self, controlnet_name=Union[str, List[str]]):
352
+ self.pipeline.controlnet = load_controlnet_model(controlnet_name).to(
353
+ device=self.device, dtype=self.dtype
354
+ )
355
+
356
+ def run_pipe_text2video(
357
+ self,
358
+ video_length: int,
359
+ prompt: Union[str, List[str]] = None,
360
+ # b c t h w
361
+ height: Optional[int] = None,
362
+ width: Optional[int] = None,
363
+ video_num_inference_steps: int = 50,
364
+ video_guidance_scale: float = 7.5,
365
+ video_guidance_scale_end: float = 3.5,
366
+ video_guidance_scale_method: str = "linear",
367
+ strength: float = 0.8,
368
+ video_negative_prompt: Optional[Union[str, List[str]]] = None,
369
+ negative_prompt: Optional[Union[str, List[str]]] = None,
370
+ num_videos_per_prompt: Optional[int] = 1,
371
+ eta: float = 0.0,
372
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
373
+ same_seed: Optional[Union[int, List[int]]] = None,
374
+ # b c t(1) ho wo
375
+ condition_latents: Optional[torch.FloatTensor] = None,
376
+ latents: Optional[torch.FloatTensor] = None,
377
+ prompt_embeds: Optional[torch.FloatTensor] = None,
378
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
379
+ guidance_scale: float = 7.5,
380
+ num_inference_steps: int = 50,
381
+ output_type: Optional[str] = "tensor",
382
+ return_dict: bool = True,
383
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
384
+ callback_steps: int = 1,
385
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
386
+ need_middle_latents: bool = False,
387
+ w_ind_noise: float = 0.5,
388
+ initial_common_latent: Optional[torch.FloatTensor] = None,
389
+ latent_index: torch.LongTensor = None,
390
+ vision_condition_latent_index: torch.LongTensor = None,
391
+ n_vision_condition: int = 1,
392
+ noise_type: str = "random",
393
+ max_batch_num: int = 30,
394
+ need_img_based_video_noise: bool = False,
395
+ condition_images: torch.Tensor = None,
396
+ fix_condition_images: bool = False,
397
+ redraw_condition_image: bool = False,
398
+ img_weight: float = 1e-3,
399
+ motion_speed: float = 8.0,
400
+ need_hist_match: bool = False,
401
+ refer_image: Optional[
402
+ Tuple[np.ndarray, torch.Tensor, List[str], List[np.ndarray]]
403
+ ] = None,
404
+ ip_adapter_image: Optional[Tuple[torch.Tensor, np.array]] = None,
405
+ fixed_refer_image: bool = True,
406
+ fixed_ip_adapter_image: bool = True,
407
+ redraw_condition_image_with_ipdapter: bool = True,
408
+ redraw_condition_image_with_referencenet: bool = True,
409
+ refer_face_image: Optional[Tuple[torch.Tensor, np.array]] = None,
410
+ fixed_refer_face_image: bool = True,
411
+ redraw_condition_image_with_facein: bool = True,
412
+ ip_adapter_scale: float = 1.0,
413
+ redraw_condition_image_with_ip_adapter_face: bool = True,
414
+ facein_scale: float = 1.0,
415
+ ip_adapter_face_scale: float = 1.0,
416
+ prompt_only_use_image_prompt: bool = False,
417
+ # serial_denoise parameter start
418
+ record_mid_video_noises: bool = False,
419
+ record_mid_video_latents: bool = False,
420
+ video_overlap: int = 1,
421
+ # serial_denoise parameter end
422
+ # parallel_denoise parameter start
423
+ context_schedule="uniform",
424
+ context_frames=12,
425
+ context_stride=1,
426
+ context_overlap=4,
427
+ context_batch_size=1,
428
+ interpolation_factor=1,
429
+ # parallel_denoise parameter end
430
+ ):
431
+ """
432
+ Generate a long video in end-to-end mode:
+ 1. prepare the vision condition image by assigning it, redrawing it, or generating it with the text2image module (skip_temporal_layer=True);
+ 2. use the image or latent of the vision condition image to generate the first shot;
+ 3. use the last n (1) images or the last latent of the previous shot as the new vision condition latent to generate the next shot;
+ 4. repeat steps 2 and 3 n_batch times.
+
+ Similar to the diffusers img2img pipeline. Three ways to prepare refer_image and ip_adapter_image:
+ 1. given as input parameters;
+ 2. when the input parameter is None, generate the first frame with pure text2video and use it as refer_image and ip_adapter_image;
+ 3. given as input parameters but redrawn: update them with the redrawn vision condition image.
+
+ Role of refer_image and ip_adapter_image:
+ 1. generate the first frame when no first-frame image is given;
+ 2. condition the whole video generation.
+ """
455
+ # crop resize images
456
+ if condition_images is not None:
457
+ logger.debug(
458
+ f"center crop resize condition_images={condition_images.shape}, to height={height}, width={width}"
459
+ )
460
+ condition_images = batch_dynamic_crop_resize_images_v2(
461
+ condition_images,
462
+ target_height=height,
463
+ target_width=width,
464
+ )
465
+ if refer_image is not None:
466
+ logger.debug(
467
+ f"center crop resize refer_image to height={height}, width={width}"
468
+ )
469
+ refer_image = batch_dynamic_crop_resize_images_v2(
470
+ refer_image,
471
+ target_height=height,
472
+ target_width=width,
473
+ )
474
+ if ip_adapter_image is not None:
475
+ logger.debug(
476
+ f"center crop resize ip_adapter_image to height={height}, width={width}"
477
+ )
478
+ ip_adapter_image = batch_dynamic_crop_resize_images_v2(
479
+ ip_adapter_image,
480
+ target_height=height,
481
+ target_width=width,
482
+ )
483
+ if refer_face_image is not None:
484
+ logger.debug(
485
+ f"center crop resize refer_face_image to height={height}, width={width}"
486
+ )
487
+ refer_face_image = batch_dynamic_crop_resize_images_v2(
488
+ refer_face_image,
489
+ target_height=height,
490
+ target_width=width,
491
+ )
492
+ run_video_length = video_length
493
+ # generate vision condition frame start
494
+ # if condition_images is None, generate with refer_image, ip_adapter_image
495
+ # if condition_images not None and need redraw, according to redraw_condition_image_with_ipdapter, redraw_condition_image_with_referencenet, refer_image, ip_adapter_image
496
+ if n_vision_condition > 0:
497
+ if condition_images is None and condition_latents is None:
498
+ logger.debug("run_pipe_text2video, generate first_image")
499
+ (
500
+ condition_images,
501
+ condition_latents,
502
+ _,
503
+ _,
504
+ _,
505
+ ) = self.pipeline(
506
+ prompt=prompt,
507
+ num_inference_steps=num_inference_steps,
508
+ guidance_scale=guidance_scale,
509
+ negative_prompt=negative_prompt,
510
+ video_length=1,
511
+ height=height,
512
+ width=width,
513
+ return_dict=False,
514
+ skip_temporal_layer=True,
515
+ output_type="np",
516
+ generator=generator,
517
+ w_ind_noise=w_ind_noise,
518
+ need_img_based_video_noise=need_img_based_video_noise,
519
+ refer_image=refer_image
520
+ if redraw_condition_image_with_referencenet
521
+ else None,
522
+ ip_adapter_image=ip_adapter_image
523
+ if redraw_condition_image_with_ipdapter
524
+ else None,
525
+ refer_face_image=refer_face_image
526
+ if redraw_condition_image_with_facein
527
+ else None,
528
+ ip_adapter_scale=ip_adapter_scale,
529
+ facein_scale=facein_scale,
530
+ ip_adapter_face_scale=ip_adapter_face_scale,
531
+ ip_adapter_face_image=refer_face_image
532
+ if redraw_condition_image_with_ip_adapter_face
533
+ else None,
534
+ prompt_only_use_image_prompt=prompt_only_use_image_prompt,
535
+ )
536
+ run_video_length = video_length - 1
537
+ elif (
538
+ condition_images is not None
539
+ and redraw_condition_image
540
+ and condition_latents is None
541
+ ):
542
+ logger.debug("run_pipe_text2video, redraw first_image")
543
+
544
+ (
545
+ condition_images,
546
+ condition_latents,
547
+ _,
548
+ _,
549
+ _,
550
+ ) = self.pipeline(
551
+ prompt=prompt,
552
+ image=condition_images,
553
+ num_inference_steps=num_inference_steps,
554
+ guidance_scale=guidance_scale,
555
+ negative_prompt=negative_prompt,
556
+ strength=strength,
557
+ video_length=condition_images.shape[2],
558
+ height=height,
559
+ width=width,
560
+ return_dict=False,
561
+ skip_temporal_layer=True,
562
+ output_type="np",
563
+ generator=generator,
564
+ w_ind_noise=w_ind_noise,
565
+ need_img_based_video_noise=need_img_based_video_noise,
566
+ refer_image=refer_image
567
+ if redraw_condition_image_with_referencenet
568
+ else None,
569
+ ip_adapter_image=ip_adapter_image
570
+ if redraw_condition_image_with_ipdapter
571
+ else None,
572
+ refer_face_image=refer_face_image
573
+ if redraw_condition_image_with_facein
574
+ else None,
575
+ ip_adapter_scale=ip_adapter_scale,
576
+ facein_scale=facein_scale,
577
+ ip_adapter_face_scale=ip_adapter_face_scale,
578
+ ip_adapter_face_image=refer_face_image
579
+ if redraw_condition_image_with_ip_adapter_face
580
+ else None,
581
+ prompt_only_use_image_prompt=prompt_only_use_image_prompt,
582
+ )
583
+ else:
584
+ condition_images = None
585
+ condition_latents = None
586
+ # generate vision condition frame end
587
+
588
+ # refer_image and ip_adapter_image, update mode from 2 and 3 as mentioned above start
589
+ if (
590
+ refer_image is not None
591
+ and redraw_condition_image
592
+ and condition_images is not None
593
+ ):
594
+ refer_image = condition_images * 255.0
595
+ logger.debug(f"update refer_image because of redraw_condition_image")
596
+ elif (
597
+ refer_image is None
598
+ and self.pipeline.referencenet is not None
599
+ and condition_images is not None
600
+ ):
601
+ refer_image = condition_images * 255.0
602
+ logger.debug(f"update refer_image because of generate first_image")
603
+
604
+ # ipadapter_image
605
+ if (
606
+ ip_adapter_image is not None
607
+ and redraw_condition_image
608
+ and condition_images is not None
609
+ ):
610
+ ip_adapter_image = condition_images * 255.0
611
+ logger.debug(f"update ip_adapter_image because of redraw_condition_image")
612
+ elif (
613
+ ip_adapter_image is None
614
+ and self.pipeline.ip_adapter_image_proj is not None
615
+ and condition_images is not None
616
+ ):
617
+ ip_adapter_image = condition_images * 255.0
618
+ logger.debug(f"update ip_adapter_image because of generate first_image")
619
+ # refer_image and ip_adapter_image, update mode from 2 and 3 as mentioned above end
620
+
621
+ # refer_face_image, update mode from 2 and 3 as mentioned above start
622
+ if (
623
+ refer_face_image is not None
624
+ and redraw_condition_image
625
+ and condition_images is not None
626
+ ):
627
+ refer_face_image = condition_images * 255.0
628
+ logger.debug(f"update refer_face_image because of redraw_condition_image")
629
+ elif (
630
+ refer_face_image is None
631
+ and self.pipeline.facein_image_proj is not None
632
+ and condition_images is not None
633
+ ):
634
+ refer_face_image = condition_images * 255.0
635
+ logger.debug(f"update face_image because of generate first_image")
636
+ # refer_face_image, update mode from 2 and 3 as mentioned above end
637
+
638
+ last_mid_video_noises = None
639
+ last_mid_video_latents = None
640
+ initial_common_latent = None
641
+
642
+ out_videos = []
643
+ for i_batch in range(max_batch_num):
644
+ logger.debug(f"sd_pipeline_predictor, run_pipe_text2video: {i_batch}")
645
+ if max_batch_num is not None and i_batch == max_batch_num:
646
+ break
647
+
648
+ if i_batch == 0:
649
+ result_overlap = 0
650
+ else:
651
+ if n_vision_condition > 0:
652
+ # ignore condition_images if condition_latents is not None in pipeline
653
+ if not fix_condition_images:
654
+ logger.debug(f"{i_batch}, update condition_latents")
655
+ condition_latents = out_latents_batch[
656
+ :, :, -n_vision_condition:, :, :
657
+ ]
658
+ else:
659
+ logger.debug(f"{i_batch}, do not update condition_latents")
660
+ result_overlap = n_vision_condition
661
+
662
+ if not fixed_refer_image and n_vision_condition > 0:
663
+ logger.debug("ref_image use last frame of last generated out video")
664
+ refer_image = out_batch[:, :, -n_vision_condition:, :, :] * 255.0
665
+ else:
666
+ logger.debug("use given fixed ref_image")
667
+
668
+ if not fixed_ip_adapter_image and n_vision_condition > 0:
669
+ logger.debug(
670
+ "ip_adapter_image use last frame of last generated out video"
671
+ )
672
+ ip_adapter_image = (
673
+ out_batch[:, :, -n_vision_condition:, :, :] * 255.0
674
+ )
675
+ else:
676
+ logger.debug("use given fixed ip_adapter_image")
677
+
678
+ if not fixed_refer_face_image and n_vision_condition > 0:
679
+ logger.debug(
680
+ "refer_face_image use last frame of last generated out video"
681
+ )
682
+ refer_face_image = (
683
+ out_batch[:, :, -n_vision_condition:, :, :] * 255.0
684
+ )
685
+ else:
686
+ logger.debug("use given fixed ip_adapter_image")
687
+
688
+ run_video_length = video_length
689
+ if same_seed is not None:
690
+ _, generator = set_all_seed(same_seed)
691
+
692
+ out = self.pipeline(
693
+ video_length=run_video_length, # int
694
+ prompt=prompt,
695
+ num_inference_steps=video_num_inference_steps,
696
+ height=height,
697
+ width=width,
698
+ generator=generator,
699
+ condition_images=condition_images,
700
+ condition_latents=condition_latents, # b co t(1) ho wo
701
+ skip_temporal_layer=False,
702
+ output_type="np",
703
+ noise_type=noise_type,
704
+ negative_prompt=video_negative_prompt,
705
+ guidance_scale=video_guidance_scale,
706
+ guidance_scale_end=video_guidance_scale_end,
707
+ guidance_scale_method=video_guidance_scale_method,
708
+ w_ind_noise=w_ind_noise,
709
+ need_img_based_video_noise=need_img_based_video_noise,
710
+ img_weight=img_weight,
711
+ motion_speed=motion_speed,
712
+ vision_condition_latent_index=vision_condition_latent_index,
713
+ refer_image=refer_image,
714
+ ip_adapter_image=ip_adapter_image,
715
+ refer_face_image=refer_face_image,
716
+ ip_adapter_scale=ip_adapter_scale,
717
+ facein_scale=facein_scale,
718
+ ip_adapter_face_scale=ip_adapter_face_scale,
719
+ ip_adapter_face_image=refer_face_image,
720
+ prompt_only_use_image_prompt=prompt_only_use_image_prompt,
721
+ initial_common_latent=initial_common_latent,
722
+ # serial_denoise parameter start
723
+ record_mid_video_noises=record_mid_video_noises,
724
+ last_mid_video_noises=last_mid_video_noises,
725
+ record_mid_video_latents=record_mid_video_latents,
726
+ last_mid_video_latents=last_mid_video_latents,
727
+ video_overlap=video_overlap,
728
+ # serial_denoise parameter end
729
+ # parallel_denoise parameter start
730
+ context_schedule=context_schedule,
731
+ context_frames=context_frames,
732
+ context_stride=context_stride,
733
+ context_overlap=context_overlap,
734
+ context_batch_size=context_batch_size,
735
+ interpolation_factor=interpolation_factor,
736
+ # parallel_denoise parameter end
737
+ )
738
+ logger.debug(
739
+ f"run_pipe_text2video, out.videos.shape, i_batch={i_batch}, videos={out.videos.shape}, result_overlap={result_overlap}"
740
+ )
741
+ out_batch = out.videos[:, :, result_overlap:, :, :]
742
+ out_latents_batch = out.latents[:, :, result_overlap:, :, :]
743
+ out_videos.append(out_batch)
744
+
745
+ out_videos = np.concatenate(out_videos, axis=2)
746
+ if need_hist_match:
747
+ out_videos[:, :, 1:, :, :] = hist_match_video_bcthw(
748
+ out_videos[:, :, 1:, :, :], out_videos[:, :, :1, :, :], value=255.0
749
+ )
750
+ return out_videos
751
+
752
+ def run_pipe_with_latent_input(
753
+ self,
754
+ ):
755
+ pass
756
+
757
+ def run_pipe_middle2video_with_middle(self, middle: Tuple[str, Iterable]):
758
+ pass
759
+
760
+ def run_pipe_video2video(
761
+ self,
762
+ video: Tuple[str, Iterable],
763
+ time_size: int = None,
764
+ sample_rate: int = None,
765
+ overlap: int = None,
766
+ step: int = None,
767
+ prompt: Union[str, List[str]] = None,
768
+ # b c t h w
769
+ height: Optional[int] = None,
770
+ width: Optional[int] = None,
771
+ num_inference_steps: int = 50,
772
+ video_num_inference_steps: int = 50,
773
+ guidance_scale: float = 7.5,
774
+ video_guidance_scale: float = 7.5,
775
+ video_guidance_scale_end: float = 3.5,
776
+ video_guidance_scale_method: str = "linear",
777
+ video_negative_prompt: Optional[Union[str, List[str]]] = None,
778
+ negative_prompt: Optional[Union[str, List[str]]] = None,
779
+ num_videos_per_prompt: Optional[int] = 1,
780
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
781
+ eta: float = 0.0,
782
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
783
+ controlnet_latents: Union[torch.FloatTensor, np.ndarray] = None,
784
+ # b c t(1) hi wi
785
+ controlnet_condition_images: Optional[torch.FloatTensor] = None,
786
+ # b c t(1) ho wo
787
+ controlnet_condition_latents: Optional[torch.FloatTensor] = None,
788
+ # b c t(1) ho wo
789
+ condition_latents: Optional[torch.FloatTensor] = None,
790
+ condition_images: Optional[torch.FloatTensor] = None,
791
+ fix_condition_images: bool = False,
792
+ latents: Optional[torch.FloatTensor] = None,
793
+ prompt_embeds: Optional[torch.FloatTensor] = None,
794
+ output_type: Optional[str] = "tensor",
795
+ return_dict: bool = True,
796
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
797
+ callback_steps: int = 1,
798
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
799
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
800
+ guess_mode: bool = False,
801
+ control_guidance_start: Union[float, List[float]] = 0.0,
802
+ control_guidance_end: Union[float, List[float]] = 1.0,
803
+ need_middle_latents: bool = False,
804
+ w_ind_noise: float = 0.5,
805
+ img_weight: float = 0.001,
806
+ initial_common_latent: Optional[torch.FloatTensor] = None,
807
+ latent_index: torch.LongTensor = None,
808
+ vision_condition_latent_index: torch.LongTensor = None,
809
+ noise_type: str = "random",
810
+ controlnet_processor_params: Dict = None,
811
+ need_return_videos: bool = False,
812
+ need_return_condition: bool = False,
813
+ max_batch_num: int = 30,
814
+ strength: float = 0.8,
815
+ video_strength: float = 0.8,
816
+ need_video2video: bool = False,
817
+ need_img_based_video_noise: bool = False,
818
+ need_hist_match: bool = False,
819
+ end_to_end: bool = True,
820
+ refer_image: Optional[
821
+ Tuple[np.ndarray, torch.Tensor, List[str], List[np.ndarray]]
822
+ ] = None,
823
+ ip_adapter_image: Optional[Tuple[torch.Tensor, np.array]] = None,
824
+ fixed_refer_image: bool = True,
825
+ fixed_ip_adapter_image: bool = True,
826
+ redraw_condition_image: bool = False,
827
+ redraw_condition_image_with_ipdapter: bool = True,
828
+ redraw_condition_image_with_referencenet: bool = True,
829
+ refer_face_image: Optional[Tuple[torch.Tensor, np.array]] = None,
830
+ fixed_refer_face_image: bool = True,
831
+ redraw_condition_image_with_facein: bool = True,
832
+ ip_adapter_scale: float = 1.0,
833
+ facein_scale: float = 1.0,
834
+ ip_adapter_face_scale: float = 1.0,
835
+ redraw_condition_image_with_ip_adapter_face: bool = False,
836
+ n_vision_condition: int = 1,
837
+ prompt_only_use_image_prompt: bool = False,
838
+ motion_speed: float = 8.0,
839
+ # serial_denoise parameter start
840
+ record_mid_video_noises: bool = False,
841
+ record_mid_video_latents: bool = False,
842
+ video_overlap: int = 1,
843
+ # serial_denoise parameter end
844
+ # parallel_denoise parameter start
845
+ context_schedule="uniform",
846
+ context_frames=12,
847
+ context_stride=1,
848
+ context_overlap=4,
849
+ context_batch_size=1,
850
+ interpolation_factor=1,
851
+ # parallel_denoise parameter end
852
+ # multiple kinds of inputs are supported when video_path is used
853
+ # TODO: video_has_condition=False is currently only supported together with video_is_middle=True;
854
+ # to be refactored later.
855
+ video_is_middle: bool = False,
856
+ video_has_condition: bool = True,
857
+ ):
858
+ """
859
+ Similar to the controlnet text2image pipeline: the input video is used to derive the controlnet condition,
860
+ and a new video is generated from that condition.
861
+ Currently the sliding window only supports time_size == step and overlap == 0,
862
+ so the output video has the same length as the input video.
863
+
865
+ """
866
+ if isinstance(video, str):
867
+ video_reader = DecordVideoDataset(
868
+ video,
869
+ time_size=time_size,
870
+ step=step,
871
+ overlap=overlap,
872
+ sample_rate=sample_rate,
873
+ device="cpu",
874
+ data_type="rgb",
875
+ channels_order="c t h w",
876
+ drop_last=True,
877
+ )
878
+ else:
879
+ video_reader = video
880
+ videos = [] if need_return_videos else None
881
+ out_videos = []
882
+ out_condition = (
883
+ []
884
+ if need_return_condition and self.pipeline.controlnet is not None
885
+ else None
886
+ )
887
+ # crop resize images
888
+ if condition_images is not None:
889
+ logger.debug(
890
+ f"center crop resize condition_images={condition_images.shape}, to height={height}, width={width}"
891
+ )
892
+ condition_images = batch_dynamic_crop_resize_images_v2(
893
+ condition_images,
894
+ target_height=height,
895
+ target_width=width,
896
+ )
897
+ if refer_image is not None:
898
+ logger.debug(
899
+ f"center crop resize refer_image to height={height}, width={width}"
900
+ )
901
+ refer_image = batch_dynamic_crop_resize_images_v2(
902
+ refer_image,
903
+ target_height=height,
904
+ target_width=width,
905
+ )
906
+ if ip_adapter_image is not None:
907
+ logger.debug(
908
+ f"center crop resize ip_adapter_image to height={height}, width={width}"
909
+ )
910
+ ip_adapter_image = batch_dynamic_crop_resize_images_v2(
911
+ ip_adapter_image,
912
+ target_height=height,
913
+ target_width=width,
914
+ )
915
+ if refer_face_image is not None:
916
+ logger.debug(
917
+ f"center crop resize refer_face_image to height={height}, width={width}"
918
+ )
919
+ refer_face_image = batch_dynamic_crop_resize_images_v2(
920
+ refer_face_image,
921
+ target_height=height,
922
+ target_width=width,
923
+ )
924
+ first_image = None
925
+ last_mid_video_noises = None
926
+ last_mid_video_latents = None
927
+ initial_common_latent = None
928
+ # initial_common_latent = torch.randn((1, 4, 1, 112, 64)).to(
929
+ # device=self.device, dtype=self.dtype
930
+ # )
931
+
932
+ for i_batch, item in enumerate(video_reader):
933
+ logger.debug(f"\n sd_pipeline_predictor, run_pipe_video2video: {i_batch}")
934
+ if max_batch_num is not None and i_batch == max_batch_num:
935
+ break
936
+ # read and prepare video batch
937
+ batch = item.data
938
+ batch = batch_dynamic_crop_resize_images(
939
+ batch,
940
+ target_height=height,
941
+ target_width=width,
942
+ )
943
+
944
+ batch = batch[np.newaxis, ...]
945
+ batch_size, channel, video_length, video_height, video_width = batch.shape
946
+ # extract controlnet middle
947
+ if self.pipeline.controlnet is not None:
948
+ batch = rearrange(batch, "b c t h w-> (b t) h w c")
949
+ controlnet_processor_params = update_controlnet_processor_params(
950
+ src=self.controlnet_processor_params,
951
+ dst=controlnet_processor_params,
952
+ )
953
+ if not video_is_middle:
954
+ batch_condition = self.controlnet_processor(
955
+ data=batch,
956
+ data_channel_order="b h w c",
957
+ target_height=height,
958
+ target_width=width,
959
+ return_type="np",
960
+ return_data_channel_order="b c h w",
961
+ input_rgb_order="rgb",
962
+ processor_params=controlnet_processor_params,
963
+ )
964
+ else:
965
+ # TODO: temporary branch used when video_path already is a controlnet middle sequence (for visualization);
966
+ # to be split out into a middle2video entry later, or exposed via a parameter.
967
+ batch_condition = rearrange(
968
+ copy.deepcopy(batch), " b h w c-> b c h w"
969
+ )
970
+
971
+ # Only when the input video is already a middle (e.g. pose) sequence and condition_images is not aligned with it:
972
+ # regenerate the pose of condition_images with the controlnet processor
973
+ # and prepend it to middle_batch.
974
+ if (
975
+ i_batch == 0
976
+ and not video_has_condition
977
+ and video_is_middle
978
+ and condition_images is not None
979
+ ):
980
+ condition_images_reshape = rearrange(
981
+ condition_images, "b c t h w-> (b t) h w c"
982
+ )
983
+ condition_images_condition = self.controlnet_processor(
984
+ data=condition_images_reshape,
985
+ data_channel_order="b h w c",
986
+ target_height=height,
987
+ target_width=width,
988
+ return_type="np",
989
+ return_data_channel_order="b c h w",
990
+ input_rgb_order="rgb",
991
+ processor_params=controlnet_processor_params,
992
+ )
993
+ condition_images_condition = rearrange(
994
+ condition_images_condition,
995
+ "(b t) c h w-> b c t h w",
996
+ b=batch_size,
997
+ )
998
+ else:
999
+ condition_images_condition = None
1000
+ if not isinstance(batch_condition, list):
1001
+ batch_condition = rearrange(
1002
+ batch_condition, "(b t) c h w-> b c t h w", b=batch_size
1003
+ )
1004
+ if condition_images_condition is not None:
1005
+ batch_condition = np.concatenate(
1006
+ [
1007
+ condition_images_condition,
1008
+ batch_condition,
1009
+ ],
1010
+ axis=2,
1011
+ )
1012
+ # At this point batch_condition has one more frame than batch; replace batch so that the
1013
+ # final videos can still be concatenated and saved.
1014
+ # This only applies when condition_images_condition is not None.
1015
+ batch = rearrange(batch_condition, "b c t h w ->(b t) h w c")
1016
+ else:
1017
+ batch_condition = [
1018
+ rearrange(x, "(b t) c h w-> b c t h w", b=batch_size)
1019
+ for x in batch_condition
1020
+ ]
1021
+ if condition_images_condition is not None:
1022
+ batch_condition = [
1023
+ np.concatenate(
1024
+ [condition_images_condition, batch_condition_tmp],
1025
+ axis=2,
1026
+ )
1027
+ for batch_condition_tmp in batch_condition
1028
+ ]
1029
+ batch = rearrange(batch, "(b t) h w c -> b c t h w", b=batch_size)
1030
+ else:
1031
+ batch_condition = None
1032
+ # condition [0,255]
1033
+ # latent: [0,1]
1034
+ # Generate multiple video shots as needed.
1035
+ # The first shot is handled specially because its first frame has to be generated.
1036
+ # Each subsequent shot is chained to the previous one:
1037
+ # the last frame of the previous shot becomes the first frame of the current shot.
1038
+ # TODO: the two cases are implemented separately for now; merge them into one implementation later.
1042
+ if n_vision_condition == 0:
1043
+ actual_video_length = video_length
1044
+ control_image = batch_condition
1045
+ first_image_controlnet_condition = None
1046
+ first_image_latents = None
1047
+ if need_video2video:
1048
+ video = batch
1049
+ else:
1050
+ video = None
1051
+ result_overlap = 0
1052
+ else:
1053
+ if i_batch == 0:
1054
+ if self.pipeline.controlnet is not None:
1055
+ if not isinstance(batch_condition, list):
1056
+ first_image_controlnet_condition = batch_condition[
1057
+ :, :, :1, :, :
1058
+ ]
1059
+ else:
1060
+ first_image_controlnet_condition = [
1061
+ x[:, :, :1, :, :] for x in batch_condition
1062
+ ]
1063
+ else:
1064
+ first_image_controlnet_condition = None
1065
+ if need_video2video:
1066
+ if condition_images is None:
1067
+ video = batch[:, :, :1, :, :]
1068
+ else:
1069
+ video = condition_images
1070
+ else:
1071
+ video = None
1072
+ if condition_images is not None and not redraw_condition_image:
1073
+ first_image = condition_images
1074
+ first_image_latents = None
1075
+ else:
1076
+ (
1077
+ first_image,
1078
+ first_image_latents,
1079
+ _,
1080
+ _,
1081
+ _,
1082
+ ) = self.pipeline(
1083
+ prompt=prompt,
1084
+ image=video,
1085
+ control_image=first_image_controlnet_condition,
1086
+ num_inference_steps=num_inference_steps,
1087
+ video_length=1,
1088
+ height=height,
1089
+ width=width,
1090
+ return_dict=False,
1091
+ skip_temporal_layer=True,
1092
+ output_type="np",
1093
+ generator=generator,
1094
+ negative_prompt=negative_prompt,
1095
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
1096
+ control_guidance_start=control_guidance_start,
1097
+ control_guidance_end=control_guidance_end,
1098
+ w_ind_noise=w_ind_noise,
1099
+ strength=strength,
1100
+ refer_image=refer_image
1101
+ if redraw_condition_image_with_referencenet
1102
+ else None,
1103
+ ip_adapter_image=ip_adapter_image
1104
+ if redraw_condition_image_with_ipdapter
1105
+ else None,
1106
+ refer_face_image=refer_face_image
1107
+ if redraw_condition_image_with_facein
1108
+ else None,
1109
+ ip_adapter_scale=ip_adapter_scale,
1110
+ facein_scale=facein_scale,
1111
+ ip_adapter_face_scale=ip_adapter_face_scale,
1112
+ ip_adapter_face_image=refer_face_image
1113
+ if redraw_condition_image_with_ip_adapter_face
1114
+ else None,
1115
+ prompt_only_use_image_prompt=prompt_only_use_image_prompt,
1116
+ )
1117
+ if refer_image is not None:
1118
+ refer_image = first_image * 255.0
1119
+ if ip_adapter_image is not None:
1120
+ ip_adapter_image = first_image * 255.0
1121
+ # for later shots the first frame is consumed via first_image_latents, so first_image itself is no longer needed
1122
+ first_image = None
1123
+ if self.pipeline.controlnet is not None:
1124
+ if not isinstance(batch_condition, list):
1125
+ control_image = batch_condition[:, :, 1:, :, :]
1126
+ logger.debug(f"control_image={control_image.shape}")
1127
+ else:
1128
+ control_image = [x[:, :, 1:, :, :] for x in batch_condition]
1129
+ else:
1130
+ control_image = None
1131
+
1132
+ actual_video_length = time_size - int(video_has_condition)
1133
+ if need_video2video:
1134
+ video = batch[:, :, 1:, :, :]
1135
+ else:
1136
+ video = None
1137
+
1138
+ result_overlap = 0
1139
+ else:
1140
+ actual_video_length = time_size
1141
+ if self.pipeline.controlnet is not None:
1142
+ if not fix_condition_images:
1143
+ logger.debug(
1144
+ f"{i_batch}, update first_image_controlnet_condition"
1145
+ )
1146
+
1147
+ if not isinstance(last_batch_condition, list):
1148
+ first_image_controlnet_condition = last_batch_condition[
1149
+ :, :, -1:, :, :
1150
+ ]
1151
+ else:
1152
+ first_image_controlnet_condition = [
1153
+ x[:, :, -1:, :, :] for x in last_batch_condition
1154
+ ]
1155
+ else:
1156
+ logger.debug(
1157
+ f"{i_batch}, do not update first_image_controlnet_condition"
1158
+ )
1159
+ control_image = batch_condition
1160
+ else:
1161
+ control_image = None
1162
+ first_image_controlnet_condition = None
1163
+ if not fix_condition_images:
1164
+ logger.debug(f"{i_batch}, update condition_images")
1165
+ first_image_latents = out_latents_batch[:, :, -1:, :, :]
1166
+ else:
1167
+ logger.debug(f"{i_batch}, do not update condition_images")
1168
+
1169
+ if need_video2video:
1170
+ video = batch
1171
+ else:
1172
+ video = None
1173
+ result_overlap = 1
1174
+
1175
+ # update refer_image and ip_adapter_image
1176
+ if not fixed_refer_image:
1177
+ logger.debug(
1178
+ "ref_image use last frame of last generated out video"
1179
+ )
1180
+ refer_image = (
1181
+ out_batch[:, :, -n_vision_condition:, :, :] * 255.0
1182
+ )
1183
+ else:
1184
+ logger.debug("use given fixed ref_image")
1185
+
1186
+ if not fixed_ip_adapter_image:
1187
+ logger.debug(
1188
+ "ip_adapter_image use last frame of last generated out video"
1189
+ )
1190
+ ip_adapter_image = (
1191
+ out_batch[:, :, -n_vision_condition:, :, :] * 255.0
1192
+ )
1193
+ else:
1194
+ logger.debug("use given fixed ip_adapter_image")
1195
+
1196
+ # face image
1197
+ if not fixed_refer_face_image:
1198
+ logger.debug(
1199
+ "refer_face_image use last frame of last generated out video"
1200
+ )
1201
+ refer_face_image = (
1202
+ out_batch[:, :, -n_vision_condition:, :, :] * 255.0
1203
+ )
1204
+ else:
1205
+ logger.debug("use given fixed ip_adapter_image")
1206
+
1207
+ out = self.pipeline(
1208
+ video_length=actual_video_length, # int
1209
+ prompt=prompt,
1210
+ num_inference_steps=video_num_inference_steps,
1211
+ height=height,
1212
+ width=width,
1213
+ generator=generator,
1214
+ image=video,
1215
+ control_image=control_image, # b ci(3) t hi wi
1216
+ controlnet_condition_images=first_image_controlnet_condition, # b ci(3) t(1) hi wi
1217
+ # controlnet_condition_images=np.zeros_like(
1218
+ # first_image_controlnet_condition
1219
+ # ), # b ci(3) t(1) hi wi
1220
+ condition_images=first_image,
1221
+ condition_latents=first_image_latents, # b co t(1) ho wo
1222
+ skip_temporal_layer=False,
1223
+ output_type="np",
1224
+ noise_type=noise_type,
1225
+ negative_prompt=video_negative_prompt,
1226
+ need_img_based_video_noise=need_img_based_video_noise,
1227
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
1228
+ control_guidance_start=control_guidance_start,
1229
+ control_guidance_end=control_guidance_end,
1230
+ w_ind_noise=w_ind_noise,
1231
+ img_weight=img_weight,
1232
+ motion_speed=video_reader.sample_rate,
1233
+ guidance_scale=video_guidance_scale,
1234
+ guidance_scale_end=video_guidance_scale_end,
1235
+ guidance_scale_method=video_guidance_scale_method,
1236
+ strength=video_strength,
1237
+ refer_image=refer_image,
1238
+ ip_adapter_image=ip_adapter_image,
1239
+ refer_face_image=refer_face_image,
1240
+ ip_adapter_scale=ip_adapter_scale,
1241
+ facein_scale=facein_scale,
1242
+ ip_adapter_face_scale=ip_adapter_face_scale,
1243
+ ip_adapter_face_image=refer_face_image,
1244
+ prompt_only_use_image_prompt=prompt_only_use_image_prompt,
1245
+ initial_common_latent=initial_common_latent,
1246
+ # serial_denoise parameter start
1247
+ record_mid_video_noises=record_mid_video_noises,
1248
+ last_mid_video_noises=last_mid_video_noises,
1249
+ record_mid_video_latents=record_mid_video_latents,
1250
+ last_mid_video_latents=last_mid_video_latents,
1251
+ video_overlap=video_overlap,
1252
+ # serial_denoise parameter end
1253
+ # parallel_denoise parameter start
1254
+ context_schedule=context_schedule,
1255
+ context_frames=context_frames,
1256
+ context_stride=context_stride,
1257
+ context_overlap=context_overlap,
1258
+ context_batch_size=context_batch_size,
1259
+ interpolation_factor=interpolation_factor,
1260
+ # parallel_denoise parameter end
1261
+ )
1262
+ last_batch = batch
1263
+ last_batch_condition = batch_condition
1264
+ last_mid_video_latents = out.mid_video_latents
1265
+ last_mid_video_noises = out.mid_video_noises
1266
+ out_batch = out.videos[:, :, result_overlap:, :, :]
1267
+ out_latents_batch = out.latents[:, :, result_overlap:, :, :]
1268
+ out_videos.append(out_batch)
1269
+ if need_return_videos:
1270
+ videos.append(batch)
1271
+ if out_condition is not None:
1272
+ out_condition.append(batch_condition)
1273
+
1274
+ out_videos = np.concatenate(out_videos, axis=2)
1275
+ if need_return_videos:
1276
+ videos = np.concatenate(videos, axis=2)
1277
+ if out_condition is not None:
1278
+ if not isinstance(out_condition[0], list):
1279
+ out_condition = np.concatenate(out_condition, axis=2)
1280
+ else:
1281
+ out_condition = [
1282
+ [out_condition[j][i] for j in range(len(out_condition))]
1283
+ for i in range(len(out_condition[0]))
1284
+ ]
1285
+ out_condition = [np.concatenate(x, axis=2) for x in out_condition]
1286
+ if need_hist_match:
1287
+ videos[:, :, 1:, :, :] = hist_match_video_bcthw(
1288
+ videos[:, :, 1:, :, :], videos[:, :, :1, :, :], value=255.0
1289
+ )
1290
+ return out_videos, out_condition, videos
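
The method above boils down to: slice the input video into windows, derive the controlnet condition per window, generate the first shot (including its redrawn first frame) and then chain each following shot off the last frame of the previous one. A minimal call sketch follows; `predictor`, the file name and all numeric values are illustrative assumptions, not part of this diff.

    def video2video_demo(predictor):
        # `predictor` is assumed to be an already-constructed instance of the predictor
        # class defined in this file; its construction is not shown in this diff.
        out_videos, out_condition, videos = predictor.run_pipe_video2video(
            video="input.mp4",            # a str path goes through the DecordVideoDataset branch
            time_size=12,
            step=12,                      # the sliding window currently requires time_size == step
            overlap=0,                    # ... and overlap == 0
            sample_rate=1,
            prompt="a person dancing, high quality",
            video_negative_prompt="blur, low quality",
            height=704,
            width=512,
            num_inference_steps=30,       # for the redrawn first frame
            video_num_inference_steps=10, # for the video shots
            need_return_videos=True,
            need_return_condition=True,
        )
        # out_videos is a numpy array laid out as (batch, channel, time, height, width),
        # concatenated over all processed windows along the time axis.
        return out_videos
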
musev/schedulers/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
2
+ from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
3
+ from .scheduling_euler_discrete import EulerDiscreteScheduler
4
+ from .scheduling_lcm import LCMScheduler
5
+ from .scheduling_ddim import DDIMScheduler
6
+ from .scheduling_ddpm import DDPMScheduler
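
Each scheduler exported here subclasses the diffusers scheduler of the same name, so construction and `set_timesteps` are unchanged; what the MuseV versions add (as the DDIM/DDPM files below show) are the extra `noise_type` and `w_ind_noise` keyword arguments on `step()`. A small sketch, assuming the `musev` package from this commit is importable; the beta values are illustrative only.

    from musev.schedulers import DDIMScheduler

    scheduler = DDIMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
    )
    scheduler.set_timesteps(30)
    # Inside a denoising loop one could then call, e.g.:
    #   scheduler.step(model_output, t, latents, eta=1.0,
    #                  noise_type="video_fusion", w_ind_noise=0.5)
    # For DDIM the noise_type only matters when eta > 0, since that is the
    # only case in which extra noise is injected.
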

musev/schedulers/scheduling_ddim.py ADDED
@@ -0,0 +1,302 @@
1
+ # Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
16
+ # and https://github.com/hojonathanho/diffusion
17
+
18
+ from __future__ import annotations
19
+
20
+ import math
21
+ from dataclasses import dataclass
22
+ from typing import List, Optional, Tuple, Union
23
+
24
+ import numpy as np
25
+ from numpy import ndarray
26
+ import torch
27
+
28
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
29
+ from diffusers.utils import BaseOutput
30
+ from diffusers.utils.torch_utils import randn_tensor
31
+ from diffusers.schedulers.scheduling_utils import (
32
+ KarrasDiffusionSchedulers,
33
+ SchedulerMixin,
34
+ )
35
+ from diffusers.schedulers.scheduling_ddim import (
36
+ DDIMSchedulerOutput,
37
+ rescale_zero_terminal_snr,
38
+ betas_for_alpha_bar,
39
+ DDIMScheduler as DiffusersDDIMScheduler,
40
+ )
41
+ from ..utils.noise_util import video_fusion_noise
42
+
43
+
44
+ class DDIMScheduler(DiffusersDDIMScheduler):
45
+ """
46
+ `DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
47
+ non-Markovian guidance.
48
+
49
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
50
+ methods the library implements for all schedulers such as loading and saving.
51
+
52
+ Args:
53
+ num_train_timesteps (`int`, defaults to 1000):
54
+ The number of diffusion steps to train the model.
55
+ beta_start (`float`, defaults to 0.0001):
56
+ The starting `beta` value of inference.
57
+ beta_end (`float`, defaults to 0.02):
58
+ The final `beta` value.
59
+ beta_schedule (`str`, defaults to `"linear"`):
60
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
61
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
62
+ trained_betas (`np.ndarray`, *optional*):
63
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
64
+ clip_sample (`bool`, defaults to `True`):
65
+ Clip the predicted sample for numerical stability.
66
+ clip_sample_range (`float`, defaults to 1.0):
67
+ The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
68
+ set_alpha_to_one (`bool`, defaults to `True`):
69
+ Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
70
+ there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
71
+ otherwise it uses the alpha value at step 0.
72
+ steps_offset (`int`, defaults to 0):
73
+ An offset added to the inference steps. You can use a combination of `offset=1` and
74
+ `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
75
+ Diffusion.
76
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
77
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
78
+ `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
79
+ Video](https://imagen.research.google/video/paper.pdf) paper).
80
+ thresholding (`bool`, defaults to `False`):
81
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
82
+ as Stable Diffusion.
83
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
84
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
85
+ sample_max_value (`float`, defaults to 1.0):
86
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
87
+ timestep_spacing (`str`, defaults to `"leading"`):
88
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
89
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
90
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
91
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
92
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
93
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
94
+ """
95
+
96
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
97
+ order = 1
98
+
99
+ @register_to_config
100
+ def __init__(
101
+ self,
102
+ num_train_timesteps: int = 1000,
103
+ beta_start: float = 0.0001,
104
+ beta_end: float = 0.02,
105
+ beta_schedule: str = "linear",
106
+ trained_betas: ndarray | List[float] | None = None,
107
+ clip_sample: bool = True,
108
+ set_alpha_to_one: bool = True,
109
+ steps_offset: int = 0,
110
+ prediction_type: str = "epsilon",
111
+ thresholding: bool = False,
112
+ dynamic_thresholding_ratio: float = 0.995,
113
+ clip_sample_range: float = 1,
114
+ sample_max_value: float = 1,
115
+ timestep_spacing: str = "leading",
116
+ rescale_betas_zero_snr: bool = False,
117
+ ):
118
+ super().__init__(
119
+ num_train_timesteps,
120
+ beta_start,
121
+ beta_end,
122
+ beta_schedule,
123
+ trained_betas,
124
+ clip_sample,
125
+ set_alpha_to_one,
126
+ steps_offset,
127
+ prediction_type,
128
+ thresholding,
129
+ dynamic_thresholding_ratio,
130
+ clip_sample_range,
131
+ sample_max_value,
132
+ timestep_spacing,
133
+ rescale_betas_zero_snr,
134
+ )
135
+
136
+ def step(
137
+ self,
138
+ model_output: torch.FloatTensor,
139
+ timestep: int,
140
+ sample: torch.FloatTensor,
141
+ eta: float = 0.0,
142
+ use_clipped_model_output: bool = False,
143
+ generator=None,
144
+ variance_noise: Optional[torch.FloatTensor] = None,
145
+ return_dict: bool = True,
146
+ w_ind_noise: float = 0.5,
147
+ noise_type: str = "random",
148
+ ) -> Union[DDIMSchedulerOutput, Tuple]:
149
+ """
150
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
151
+ process from the learned model outputs (most often the predicted noise).
152
+
153
+ Args:
154
+ model_output (`torch.FloatTensor`):
155
+ The direct output from learned diffusion model.
156
+ timestep (`float`):
157
+ The current discrete timestep in the diffusion chain.
158
+ sample (`torch.FloatTensor`):
159
+ A current instance of a sample created by the diffusion process.
160
+ eta (`float`):
161
+ The weight of noise for added noise in diffusion step.
162
+ use_clipped_model_output (`bool`, defaults to `False`):
163
+ If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
164
+ because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
165
+ clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
166
+ `use_clipped_model_output` has no effect.
167
+ generator (`torch.Generator`, *optional*):
168
+ A random number generator.
169
+ variance_noise (`torch.FloatTensor`):
170
+ Alternative to generating noise with `generator` by directly providing the noise for the variance
171
+ itself. Useful for methods such as [`CycleDiffusion`].
172
+ return_dict (`bool`, *optional*, defaults to `True`):
173
+ Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`.
174
+
175
+ Returns:
176
+ [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
177
+ If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a
178
+ tuple is returned where the first element is the sample tensor.
179
+
180
+ """
181
+ if self.num_inference_steps is None:
182
+ raise ValueError(
183
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
184
+ )
185
+
186
+ # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
187
+ # Ideally, read DDIM paper in-detail understanding
188
+
189
+ # Notation (<variable name> -> <name in paper>
190
+ # - pred_noise_t -> e_theta(x_t, t)
191
+ # - pred_original_sample -> f_theta(x_t, t) or x_0
192
+ # - std_dev_t -> sigma_t
193
+ # - eta -> η
194
+ # - pred_sample_direction -> "direction pointing to x_t"
195
+ # - pred_prev_sample -> "x_t-1"
196
+
197
+ # 1. get previous step value (=t-1)
198
+ prev_timestep = (
199
+ timestep - self.config.num_train_timesteps // self.num_inference_steps
200
+ )
201
+
202
+ # 2. compute alphas, betas
203
+ alpha_prod_t = self.alphas_cumprod[timestep]
204
+ alpha_prod_t_prev = (
205
+ self.alphas_cumprod[prev_timestep]
206
+ if prev_timestep >= 0
207
+ else self.final_alpha_cumprod
208
+ )
209
+
210
+ beta_prod_t = 1 - alpha_prod_t
211
+
212
+ # 3. compute predicted original sample from predicted noise also called
213
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
214
+ if self.config.prediction_type == "epsilon":
215
+ pred_original_sample = (
216
+ sample - beta_prod_t ** (0.5) * model_output
217
+ ) / alpha_prod_t ** (0.5)
218
+ pred_epsilon = model_output
219
+ elif self.config.prediction_type == "sample":
220
+ pred_original_sample = model_output
221
+ pred_epsilon = (
222
+ sample - alpha_prod_t ** (0.5) * pred_original_sample
223
+ ) / beta_prod_t ** (0.5)
224
+ elif self.config.prediction_type == "v_prediction":
225
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (
226
+ beta_prod_t**0.5
227
+ ) * model_output
228
+ pred_epsilon = (alpha_prod_t**0.5) * model_output + (
229
+ beta_prod_t**0.5
230
+ ) * sample
231
+ else:
232
+ raise ValueError(
233
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
234
+ " `v_prediction`"
235
+ )
236
+
237
+ # 4. Clip or threshold "predicted x_0"
238
+ if self.config.thresholding:
239
+ pred_original_sample = self._threshold_sample(pred_original_sample)
240
+ elif self.config.clip_sample:
241
+ pred_original_sample = pred_original_sample.clamp(
242
+ -self.config.clip_sample_range, self.config.clip_sample_range
243
+ )
244
+
245
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
246
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
247
+ variance = self._get_variance(timestep, prev_timestep)
248
+ std_dev_t = eta * variance ** (0.5)
249
+
250
+ if use_clipped_model_output:
251
+ # the pred_epsilon is always re-derived from the clipped x_0 in Glide
252
+ pred_epsilon = (
253
+ sample - alpha_prod_t ** (0.5) * pred_original_sample
254
+ ) / beta_prod_t ** (0.5)
255
+
256
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
257
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (
258
+ 0.5
259
+ ) * pred_epsilon
260
+
261
+ # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
262
+ prev_sample = (
263
+ alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
264
+ )
265
+
266
+ if eta > 0:
267
+ if variance_noise is not None and generator is not None:
268
+ raise ValueError(
269
+ "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
270
+ " `variance_noise` stays `None`."
271
+ )
272
+
273
+ # if variance_noise is None:
274
+ # variance_noise = randn_tensor(
275
+ # model_output.shape,
276
+ # generator=generator,
277
+ # device=model_output.device,
278
+ # dtype=model_output.dtype,
279
+ # )
280
+ device = model_output.device
281
+
282
+ if noise_type == "random":
283
+ variance_noise = randn_tensor(
284
+ model_output.shape,
285
+ dtype=model_output.dtype,
286
+ device=device,
287
+ generator=generator,
288
+ )
289
+ elif noise_type == "video_fusion":
290
+ variance_noise = video_fusion_noise(
291
+ model_output, w_ind_noise=w_ind_noise, generator=generator
292
+ )
293
+ variance = std_dev_t * variance_noise
294
+
295
+ prev_sample = prev_sample + variance
296
+
297
+ if not return_dict:
298
+ return (prev_sample,)
299
+
300
+ return DDIMSchedulerOutput(
301
+ prev_sample=prev_sample, pred_original_sample=pred_original_sample
302
+ )
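
For reference, the epsilon-prediction branch of `step()` above reduces to the standard DDIM update of formula (12): recover x_0 from the noise estimate, add the "direction pointing to x_t", and (only when eta > 0) a noise term scaled by sigma_t. A self-contained numeric sketch with a toy alpha schedule and made-up tensors:

    import torch

    alphas_cumprod = torch.linspace(0.999, 0.01, 1000)   # toy cumulative-alpha schedule
    t, t_prev, eta = 500, 480, 0.0
    alpha_t, alpha_prev = alphas_cumprod[t], alphas_cumprod[t_prev]
    beta_t = 1 - alpha_t

    sample = torch.randn(1, 4, 8, 8)                      # x_t
    model_output = torch.randn_like(sample)               # eps_theta(x_t, t)

    pred_x0 = (sample - beta_t.sqrt() * model_output) / alpha_t.sqrt()
    variance = (1 - alpha_prev) / (1 - alpha_t) * (1 - alpha_t / alpha_prev)
    std_dev_t = eta * variance.sqrt()
    direction = (1 - alpha_prev - std_dev_t**2).sqrt() * model_output
    prev_sample = alpha_prev.sqrt() * pred_x0 + direction
    # with eta > 0, step() additionally adds std_dev_t * variance_noise, where the noise
    # comes from randn_tensor or video_fusion_noise depending on noise_type.
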
musev/schedulers/scheduling_ddpm.py ADDED
@@ -0,0 +1,262 @@
1
+ # Copyright 2023 UC Berkeley Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
16
+
17
+ from __future__ import annotations
18
+
19
+ import math
20
+ from dataclasses import dataclass
21
+ from typing import List, Optional, Tuple, Union
22
+
23
+ import numpy as np
24
+ from numpy import ndarray
25
+ import torch
26
+
27
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
28
+ from diffusers.utils import BaseOutput
29
+ from diffusers.utils.torch_utils import randn_tensor
30
+ from diffusers.schedulers.scheduling_utils import (
31
+ KarrasDiffusionSchedulers,
32
+ SchedulerMixin,
33
+ )
34
+ from diffusers.schedulers.scheduling_ddpm import (
35
+ DDPMSchedulerOutput,
36
+ betas_for_alpha_bar,
37
+ DDPMScheduler as DiffusersDDPMScheduler,
38
+ )
39
+ from ..utils.noise_util import video_fusion_noise
40
+
41
+
42
+ class DDPMScheduler(DiffusersDDPMScheduler):
43
+ """
44
+ `DDPMScheduler` explores the connections between denoising score matching and Langevin dynamics sampling.
45
+
46
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
47
+ methods the library implements for all schedulers such as loading and saving.
48
+
49
+ Args:
50
+ num_train_timesteps (`int`, defaults to 1000):
51
+ The number of diffusion steps to train the model.
52
+ beta_start (`float`, defaults to 0.0001):
53
+ The starting `beta` value of inference.
54
+ beta_end (`float`, defaults to 0.02):
55
+ The final `beta` value.
56
+ beta_schedule (`str`, defaults to `"linear"`):
57
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
58
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
59
+ variance_type (`str`, defaults to `"fixed_small"`):
60
+ Clip the variance when adding noise to the denoised sample. Choose from `fixed_small`, `fixed_small_log`,
61
+ `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
62
+ clip_sample (`bool`, defaults to `True`):
63
+ Clip the predicted sample for numerical stability.
64
+ clip_sample_range (`float`, defaults to 1.0):
65
+ The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
66
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
67
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
68
+ `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
69
+ Video](https://imagen.research.google/video/paper.pdf) paper).
70
+ thresholding (`bool`, defaults to `False`):
71
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
72
+ as Stable Diffusion.
73
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
74
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
75
+ sample_max_value (`float`, defaults to 1.0):
76
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
77
+ timestep_spacing (`str`, defaults to `"leading"`):
78
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
79
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
80
+ steps_offset (`int`, defaults to 0):
81
+ An offset added to the inference steps. You can use a combination of `offset=1` and
82
+ `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
83
+ Diffusion.
84
+ """
85
+
86
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
87
+ order = 1
88
+
89
+ @register_to_config
90
+ def __init__(
91
+ self,
92
+ num_train_timesteps: int = 1000,
93
+ beta_start: float = 0.0001,
94
+ beta_end: float = 0.02,
95
+ beta_schedule: str = "linear",
96
+ trained_betas: ndarray | List[float] | None = None,
97
+ variance_type: str = "fixed_small",
98
+ clip_sample: bool = True,
99
+ prediction_type: str = "epsilon",
100
+ thresholding: bool = False,
101
+ dynamic_thresholding_ratio: float = 0.995,
102
+ clip_sample_range: float = 1,
103
+ sample_max_value: float = 1,
104
+ timestep_spacing: str = "leading",
105
+ steps_offset: int = 0,
106
+ ):
107
+ super().__init__(
108
+ num_train_timesteps,
109
+ beta_start,
110
+ beta_end,
111
+ beta_schedule,
112
+ trained_betas,
113
+ variance_type,
114
+ clip_sample,
115
+ prediction_type,
116
+ thresholding,
117
+ dynamic_thresholding_ratio,
118
+ clip_sample_range,
119
+ sample_max_value,
120
+ timestep_spacing,
121
+ steps_offset,
122
+ )
123
+
124
+ def step(
125
+ self,
126
+ model_output: torch.FloatTensor,
127
+ timestep: int,
128
+ sample: torch.FloatTensor,
129
+ generator=None,
130
+ return_dict: bool = True,
131
+ w_ind_noise: float = 0.5,
132
+ noise_type: str = "random",
133
+ ) -> Union[DDPMSchedulerOutput, Tuple]:
134
+ """
135
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
136
+ process from the learned model outputs (most often the predicted noise).
137
+
138
+ Args:
139
+ model_output (`torch.FloatTensor`):
140
+ The direct output from learned diffusion model.
141
+ timestep (`float`):
142
+ The current discrete timestep in the diffusion chain.
143
+ sample (`torch.FloatTensor`):
144
+ A current instance of a sample created by the diffusion process.
145
+ generator (`torch.Generator`, *optional*):
146
+ A random number generator.
147
+ return_dict (`bool`, *optional*, defaults to `True`):
148
+ Whether or not to return a [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`.
149
+
150
+ Returns:
151
+ [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`:
152
+ If return_dict is `True`, [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] is returned, otherwise a
153
+ tuple is returned where the first element is the sample tensor.
154
+
155
+ """
156
+ t = timestep
157
+
158
+ prev_t = self.previous_timestep(t)
159
+
160
+ if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in [
161
+ "learned",
162
+ "learned_range",
163
+ ]:
164
+ model_output, predicted_variance = torch.split(
165
+ model_output, sample.shape[1], dim=1
166
+ )
167
+ else:
168
+ predicted_variance = None
169
+
170
+ # 1. compute alphas, betas
171
+ alpha_prod_t = self.alphas_cumprod[t]
172
+ alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
173
+ beta_prod_t = 1 - alpha_prod_t
174
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
175
+ current_alpha_t = alpha_prod_t / alpha_prod_t_prev
176
+ current_beta_t = 1 - current_alpha_t
177
+
178
+ # 2. compute predicted original sample from predicted noise also called
179
+ # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
180
+ if self.config.prediction_type == "epsilon":
181
+ pred_original_sample = (
182
+ sample - beta_prod_t ** (0.5) * model_output
183
+ ) / alpha_prod_t ** (0.5)
184
+ elif self.config.prediction_type == "sample":
185
+ pred_original_sample = model_output
186
+ elif self.config.prediction_type == "v_prediction":
187
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (
188
+ beta_prod_t**0.5
189
+ ) * model_output
190
+ else:
191
+ raise ValueError(
192
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
193
+ " `v_prediction` for the DDPMScheduler."
194
+ )
195
+
196
+ # 3. Clip or threshold "predicted x_0"
197
+ if self.config.thresholding:
198
+ pred_original_sample = self._threshold_sample(pred_original_sample)
199
+ elif self.config.clip_sample:
200
+ pred_original_sample = pred_original_sample.clamp(
201
+ -self.config.clip_sample_range, self.config.clip_sample_range
202
+ )
203
+
204
+ # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
205
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
206
+ pred_original_sample_coeff = (
207
+ alpha_prod_t_prev ** (0.5) * current_beta_t
208
+ ) / beta_prod_t
209
+ current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
210
+
211
+ # 5. Compute predicted previous sample µ_t
212
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
213
+ pred_prev_sample = (
214
+ pred_original_sample_coeff * pred_original_sample
215
+ + current_sample_coeff * sample
216
+ )
217
+
218
+ # 6. Add noise
219
+ variance = 0
220
+ if t > 0:
221
+ device = model_output.device
222
+ # if variance_noise is None:
223
+ # variance_noise = randn_tensor(
224
+ # model_output.shape,
225
+ # generator=generator,
226
+ # device=model_output.device,
227
+ # dtype=model_output.dtype,
228
+ # )
229
+ device = model_output.device
230
+
231
+ if noise_type == "random":
232
+ variance_noise = randn_tensor(
233
+ model_output.shape,
234
+ dtype=model_output.dtype,
235
+ device=device,
236
+ generator=generator,
237
+ )
238
+ elif noise_type == "video_fusion":
239
+ variance_noise = video_fusion_noise(
240
+ model_output, w_ind_noise=w_ind_noise, generator=generator
241
+ )
242
+ if self.variance_type == "fixed_small_log":
243
+ variance = (
244
+ self._get_variance(t, predicted_variance=predicted_variance)
245
+ * variance_noise
246
+ )
247
+ elif self.variance_type == "learned_range":
248
+ variance = self._get_variance(t, predicted_variance=predicted_variance)
249
+ variance = torch.exp(0.5 * variance) * variance_noise
250
+ else:
251
+ variance = (
252
+ self._get_variance(t, predicted_variance=predicted_variance) ** 0.5
253
+ ) * variance_noise
254
+
255
+ pred_prev_sample = pred_prev_sample + variance
256
+
257
+ if not return_dict:
258
+ return (pred_prev_sample,)
259
+
260
+ return DDPMSchedulerOutput(
261
+ prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample
262
+ )
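
The epsilon branch of this `step()` is the classic ancestral DDPM update: estimate x_0, then form the posterior mean as a weighted sum of x_0 and x_t using the coefficients of formula (7), plus a variance term when t > 0. A self-contained sketch with toy values:

    import torch

    alphas_cumprod = torch.linspace(0.999, 0.01, 1000)   # toy cumulative-alpha schedule
    t, t_prev = 500, 499
    alpha_prod_t, alpha_prod_t_prev = alphas_cumprod[t], alphas_cumprod[t_prev]
    beta_prod_t, beta_prod_t_prev = 1 - alpha_prod_t, 1 - alpha_prod_t_prev
    current_alpha_t = alpha_prod_t / alpha_prod_t_prev
    current_beta_t = 1 - current_alpha_t

    sample = torch.randn(1, 4, 8, 8)             # x_t
    model_output = torch.randn_like(sample)      # eps_theta(x_t, t)

    pred_x0 = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
    coef_x0 = alpha_prod_t_prev.sqrt() * current_beta_t / beta_prod_t
    coef_xt = current_alpha_t.sqrt() * beta_prod_t_prev / beta_prod_t
    pred_prev_sample = coef_x0 * pred_x0 + coef_xt * sample
    # when t > 0, step() adds the variance term on top, again drawing the noise from
    # randn_tensor or video_fusion_noise depending on noise_type.
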
musev/schedulers/scheduling_dpmsolver_multistep.py ADDED
@@ -0,0 +1,815 @@
1
+ # Copyright 2023 TSAIL Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
16
+
17
+ import math
18
+ from typing import List, Optional, Tuple, Union
19
+
20
+ import numpy as np
21
+ import torch
22
+
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+
25
+ try:
26
+ from diffusers.utils import randn_tensor
27
+ except ImportError:
28
+ from diffusers.utils.torch_utils import randn_tensor
29
+ from diffusers.schedulers.scheduling_utils import (
30
+ KarrasDiffusionSchedulers,
31
+ SchedulerMixin,
32
+ SchedulerOutput,
33
+ )
34
+
35
+
36
+ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
37
+ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
38
+ """
39
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
40
+ (1-beta) over time from t = [0,1].
41
+
42
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
43
+ to that part of the diffusion process.
44
+
45
+
46
+ Args:
47
+ num_diffusion_timesteps (`int`): the number of betas to produce.
48
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
49
+ prevent singularities.
50
+
51
+ Returns:
52
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
53
+ """
54
+
55
+ def alpha_bar(time_step):
56
+ return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
57
+
58
+ betas = []
59
+ for i in range(num_diffusion_timesteps):
60
+ t1 = i / num_diffusion_timesteps
61
+ t2 = (i + 1) / num_diffusion_timesteps
62
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
63
+ return torch.tensor(betas, dtype=torch.float32)
64
+
65
+
66
+ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
67
+ """
68
+ DPM-Solver (and the improved version DPM-Solver++) is a fast dedicated high-order solver for diffusion ODEs with
69
+ the convergence order guarantee. Empirically, sampling by DPM-Solver with only 20 steps can generate high-quality
70
+ samples, and it can generate quite good samples even in only 10 steps.
71
+
72
+ For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095
73
+
74
+ Currently, we support the multistep DPM-Solver for both noise prediction models and data prediction models. We
75
+ recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling.
76
+
77
+ We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space
78
+ diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic
79
+ thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as
80
+ stable-diffusion).
81
+
82
+ We also support the SDE variant of DPM-Solver and DPM-Solver++, which is a fast SDE solver for the reverse
83
+ diffusion SDE. Currently we only support the first-order and second-order solvers. We recommend using the
84
+ second-order `sde-dpmsolver++`.
85
+
86
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
87
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
88
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
89
+ [`~SchedulerMixin.from_pretrained`] functions.
90
+
91
+ Args:
92
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
93
+ beta_start (`float`): the starting `beta` value of inference.
94
+ beta_end (`float`): the final `beta` value.
95
+ beta_schedule (`str`):
96
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
97
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
98
+ trained_betas (`np.ndarray`, optional):
99
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
100
+ solver_order (`int`, default `2`):
101
+ the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
102
+ sampling, and `solver_order=3` for unconditional sampling.
103
+ prediction_type (`str`, default `epsilon`, optional):
104
+ prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
105
+ process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
106
+ https://imagen.research.google/video/paper.pdf)
107
+ thresholding (`bool`, default `False`):
108
+ whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
109
+ For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
110
+ use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion
111
+ models (such as stable-diffusion).
112
+ dynamic_thresholding_ratio (`float`, default `0.995`):
113
+ the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
114
+ (https://arxiv.org/abs/2205.11487).
115
+ sample_max_value (`float`, default `1.0`):
116
+ the threshold value for dynamic thresholding. Valid only when `thresholding=True` and
117
+ `algorithm_type="dpmsolver++`.
118
+ algorithm_type (`str`, default `dpmsolver++`):
119
+ the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++` or `sde-dpmsolver` or
120
+ `sde-dpmsolver++`. The `dpmsolver` type implements the algorithms in https://arxiv.org/abs/2206.00927, and
121
+ the `dpmsolver++` type implements the algorithms in https://arxiv.org/abs/2211.01095. We recommend to use
122
+ `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling (e.g. stable-diffusion).
123
+ solver_type (`str`, default `midpoint`):
124
+ the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects
125
+ the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are
126
+ slightly better, so we recommend to use the `midpoint` type.
127
+ lower_order_final (`bool`, default `True`):
128
+ whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
129
+ find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
130
+ use_karras_sigmas (`bool`, *optional*, defaults to `False`):
131
+ This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
132
+ noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
133
+ of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
134
+ lambda_min_clipped (`float`, default `-inf`):
135
+ the clipping threshold for the minimum value of lambda(t) for numerical stability. This is critical for
136
+ cosine (squaredcos_cap_v2) noise schedule.
137
+ variance_type (`str`, *optional*):
138
+ Set to "learned" or "learned_range" for diffusion models that predict variance. For example, OpenAI's
139
+ guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the
140
+ Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on
141
+ diffusion ODEs. whether the model's output contains the predicted Gaussian variance. For example, OpenAI's
142
+ guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the
143
+ Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on
144
+ diffusion ODEs.
145
+ """
146
+
147
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
148
+ order = 1
149
+
150
+ @register_to_config
151
+ def __init__(
152
+ self,
153
+ num_train_timesteps: int = 1000,
154
+ beta_start: float = 0.0001,
155
+ beta_end: float = 0.02,
156
+ beta_schedule: str = "linear",
157
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
158
+ solver_order: int = 2,
159
+ prediction_type: str = "epsilon",
160
+ thresholding: bool = False,
161
+ dynamic_thresholding_ratio: float = 0.995,
162
+ sample_max_value: float = 1.0,
163
+ algorithm_type: str = "dpmsolver++",
164
+ solver_type: str = "midpoint",
165
+ lower_order_final: bool = True,
166
+ use_karras_sigmas: Optional[bool] = True,
167
+ lambda_min_clipped: float = -float("inf"),
168
+ variance_type: Optional[str] = None,
169
+ ):
170
+ if trained_betas is not None:
171
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
172
+ elif beta_schedule == "linear":
173
+ self.betas = torch.linspace(
174
+ beta_start, beta_end, num_train_timesteps, dtype=torch.float32
175
+ )
176
+ elif beta_schedule == "scaled_linear":
177
+ # this schedule is very specific to the latent diffusion model.
178
+ self.betas = (
179
+ torch.linspace(
180
+ beta_start**0.5,
181
+ beta_end**0.5,
182
+ num_train_timesteps,
183
+ dtype=torch.float32,
184
+ )
185
+ ** 2
186
+ )
187
+ elif beta_schedule == "squaredcos_cap_v2":
188
+ # Glide cosine schedule
189
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
190
+ else:
191
+ raise NotImplementedError(
192
+ f"{beta_schedule} does is not implemented for {self.__class__}"
193
+ )
194
+
195
+ self.alphas = 1.0 - self.betas
196
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
197
+ # Currently we only support VP-type noise schedule
198
+ self.alpha_t = torch.sqrt(self.alphas_cumprod)
199
+ self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
200
+ self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
201
+
202
+ # standard deviation of the initial noise distribution
203
+ self.init_noise_sigma = 1.0
204
+
205
+ # settings for DPM-Solver
206
+ if algorithm_type not in [
207
+ "dpmsolver",
208
+ "dpmsolver++",
209
+ "sde-dpmsolver",
210
+ "sde-dpmsolver++",
211
+ ]:
212
+ if algorithm_type == "deis":
213
+ self.register_to_config(algorithm_type="dpmsolver++")
214
+ else:
215
+ raise NotImplementedError(
216
+ f"{algorithm_type} does is not implemented for {self.__class__}"
217
+ )
218
+
219
+ if solver_type not in ["midpoint", "heun"]:
220
+ if solver_type in ["logrho", "bh1", "bh2"]:
221
+ self.register_to_config(solver_type="midpoint")
222
+ else:
223
+ raise NotImplementedError(
224
+ f"{solver_type} does is not implemented for {self.__class__}"
225
+ )
226
+
227
+ # setable values
228
+ self.num_inference_steps = None
229
+ timesteps = np.linspace(
230
+ 0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32
231
+ )[::-1].copy()
232
+ self.timesteps = torch.from_numpy(timesteps)
233
+ self.model_outputs = [None] * solver_order
234
+ self.lower_order_nums = 0
235
+ self.use_karras_sigmas = use_karras_sigmas
236
+
237
+ def set_timesteps(
238
+ self, num_inference_steps: int = None, device: Union[str, torch.device] = None
239
+ ):
240
+ """
241
+ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
242
+
243
+ Args:
244
+ num_inference_steps (`int`):
245
+ the number of diffusion steps used when generating samples with a pre-trained model.
246
+ device (`str` or `torch.device`, optional):
247
+ the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
248
+ """
249
+ # Clipping the minimum of all lambda(t) for numerical stability.
250
+ # This is critical for cosine (squaredcos_cap_v2) noise schedule.
251
+ clipped_idx = torch.searchsorted(
252
+ torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped
253
+ )
254
+ timesteps = (
255
+ np.linspace(
256
+ 0,
257
+ self.config.num_train_timesteps - 1 - clipped_idx,
258
+ num_inference_steps + 1,
259
+ )
260
+ .round()[::-1][:-1]
261
+ .copy()
262
+ .astype(np.int64)
263
+ )
264
+
265
+ if self.use_karras_sigmas:
266
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
267
+ log_sigmas = np.log(sigmas)
268
+ sigmas = self._convert_to_karras(
269
+ in_sigmas=sigmas, num_inference_steps=num_inference_steps
270
+ )
271
+ timesteps = np.array(
272
+ [self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]
273
+ ).round()
274
+ timesteps = np.flip(timesteps).copy().astype(np.int64)
275
+
276
+ # when num_inference_steps == num_train_timesteps, we can end up with
277
+ # duplicates in timesteps.
278
+ _, unique_indices = np.unique(timesteps, return_index=True)
279
+ timesteps = timesteps[np.sort(unique_indices)]
280
+
281
+ self.timesteps = torch.from_numpy(timesteps).to(device)
282
+
283
+ self.num_inference_steps = len(timesteps)
284
+
285
+ self.model_outputs = [
286
+ None,
287
+ ] * self.config.solver_order
288
+ self.lower_order_nums = 0
289
+
290
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
291
+ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
292
+ """
293
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
294
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
295
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
296
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
297
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
298
+
299
+ https://arxiv.org/abs/2205.11487
300
+ """
301
+ dtype = sample.dtype
302
+ batch_size, channels, height, width = sample.shape
303
+
304
+ if dtype not in (torch.float32, torch.float64):
305
+ sample = (
306
+ sample.float()
307
+ ) # upcast for quantile calculation, and clamp not implemented for cpu half
308
+
309
+ # Flatten sample for doing quantile calculation along each image
310
+ sample = sample.reshape(batch_size, channels * height * width)
311
+
312
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
313
+
314
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
315
+ s = torch.clamp(
316
+ s, min=1, max=self.config.sample_max_value
317
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
318
+
319
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
320
+ sample = (
321
+ torch.clamp(sample, -s, s) / s
322
+ ) # "we threshold xt0 to the range [-s, s] and then divide by s"
323
+
324
+ sample = sample.reshape(batch_size, channels, height, width)
325
+ sample = sample.to(dtype)
326
+
327
+ return sample
328
+
329
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
330
+ def _sigma_to_t(self, sigma, log_sigmas):
331
+ # get log sigma
332
+ log_sigma = np.log(sigma)
333
+
334
+ # get distribution
335
+ dists = log_sigma - log_sigmas[:, np.newaxis]
336
+
337
+ # get sigmas range
338
+ low_idx = (
339
+ np.cumsum((dists >= 0), axis=0)
340
+ .argmax(axis=0)
341
+ .clip(max=log_sigmas.shape[0] - 2)
342
+ )
343
+ high_idx = low_idx + 1
344
+
345
+ low = log_sigmas[low_idx]
346
+ high = log_sigmas[high_idx]
347
+
348
+ # interpolate sigmas
349
+ w = (low - log_sigma) / (low - high)
350
+ w = np.clip(w, 0, 1)
351
+
352
+ # transform interpolation to time range
353
+ t = (1 - w) * low_idx + w * high_idx
354
+ t = t.reshape(sigma.shape)
355
+ return t
356
+
357
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
358
+ def _convert_to_karras(
359
+ self, in_sigmas: torch.FloatTensor, num_inference_steps
360
+ ) -> torch.FloatTensor:
361
+ """Constructs the noise schedule of Karras et al. (2022)."""
362
+
363
+ sigma_min: float = in_sigmas[-1].item()
364
+ sigma_max: float = in_sigmas[0].item()
365
+
366
+ rho = 7.0 # 7.0 is the value used in the paper
367
+ ramp = np.linspace(0, 1, num_inference_steps)
368
+ min_inv_rho = sigma_min ** (1 / rho)
369
+ max_inv_rho = sigma_max ** (1 / rho)
370
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
371
+ return sigmas
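+ # With rho = 7 the interpolation between sigma_max and sigma_min is done in sigma**(1/rho) space,
+ # which concentrates most of the inference steps at small noise levels (Karras et al., 2022,
+ # https://arxiv.org/abs/2206.00364).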
372
+
373
+ def convert_model_output(
374
+ self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
375
+ ) -> torch.FloatTensor:
376
+ """
377
+ Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs.
378
+
379
+ DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to
380
+ discretize an integral of the data prediction model. So we need to first convert the model output to the
381
+ corresponding type to match the algorithm.
382
+
383
+ Note that the algorithm type and the model type are decoupled. That is to say, we can use either DPM-Solver or
384
+ DPM-Solver++ for both noise prediction model and data prediction model.
385
+
386
+ Args:
387
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
388
+ timestep (`int`): current discrete timestep in the diffusion chain.
389
+ sample (`torch.FloatTensor`):
390
+ current instance of sample being created by diffusion process.
391
+
392
+ Returns:
393
+ `torch.FloatTensor`: the converted model output.
394
+ """
395
+
396
+ # DPM-Solver++ needs to solve an integral of the data prediction model.
397
+ if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
398
+ if self.config.prediction_type == "epsilon":
399
+ # DPM-Solver and DPM-Solver++ only need the "mean" output.
400
+ if self.config.variance_type in ["learned", "learned_range"]:
401
+ model_output = model_output[:, :3]
402
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
403
+ x0_pred = (sample - sigma_t * model_output) / alpha_t
404
+ elif self.config.prediction_type == "sample":
405
+ x0_pred = model_output
406
+ elif self.config.prediction_type == "v_prediction":
407
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
408
+ x0_pred = alpha_t * sample - sigma_t * model_output
409
+ else:
410
+ raise ValueError(
411
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
412
+ " `v_prediction` for the DPMSolverMultistepScheduler."
413
+ )
414
+
415
+ if self.config.thresholding:
416
+ x0_pred = self._threshold_sample(x0_pred)
417
+
418
+ return x0_pred
419
+
420
+ # DPM-Solver needs to solve an integral of the noise prediction model.
421
+ elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
422
+ if self.config.prediction_type == "epsilon":
423
+ # DPM-Solver and DPM-Solver++ only need the "mean" output.
424
+ if self.config.variance_type in ["learned", "learned_range"]:
425
+ epsilon = model_output[:, :3]
426
+ else:
427
+ epsilon = model_output
428
+ elif self.config.prediction_type == "sample":
429
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
430
+ epsilon = (sample - alpha_t * model_output) / sigma_t
431
+ elif self.config.prediction_type == "v_prediction":
432
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
433
+ epsilon = alpha_t * model_output + sigma_t * sample
434
+ else:
435
+ raise ValueError(
436
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
437
+ " `v_prediction` for the DPMSolverMultistepScheduler."
438
+ )
439
+
440
+ if self.config.thresholding:
441
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
442
+ x0_pred = (sample - sigma_t * epsilon) / alpha_t
443
+ x0_pred = self._threshold_sample(x0_pred)
444
+ epsilon = (sample - alpha_t * x0_pred) / sigma_t
445
+
446
+ return epsilon
447
+
448
+ def dpm_solver_first_order_update(
449
+ self,
450
+ model_output: torch.FloatTensor,
451
+ timestep: int,
452
+ prev_timestep: int,
453
+ sample: torch.FloatTensor,
454
+ noise: Optional[torch.FloatTensor] = None,
455
+ ) -> torch.FloatTensor:
456
+ """
457
+ One step for the first-order DPM-Solver (equivalent to DDIM).
458
+
459
+ See https://arxiv.org/abs/2206.00927 for the detailed derivation.
460
+
461
+ Args:
462
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
463
+ timestep (`int`): current discrete timestep in the diffusion chain.
464
+ prev_timestep (`int`): previous discrete timestep in the diffusion chain.
465
+ sample (`torch.FloatTensor`):
466
+ current instance of sample being created by diffusion process.
467
+
468
+ Returns:
469
+ `torch.FloatTensor`: the sample tensor at the previous timestep.
470
+ """
471
+ lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep]
472
+ alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
473
+ sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep]
474
+ h = lambda_t - lambda_s
475
+ if self.config.algorithm_type == "dpmsolver++":
476
+ x_t = (sigma_t / sigma_s) * sample - (
477
+ alpha_t * (torch.exp(-h) - 1.0)
478
+ ) * model_output
479
+ elif self.config.algorithm_type == "dpmsolver":
480
+ x_t = (alpha_t / alpha_s) * sample - (
481
+ sigma_t * (torch.exp(h) - 1.0)
482
+ ) * model_output
483
+ elif self.config.algorithm_type == "sde-dpmsolver++":
484
+ assert noise is not None
485
+ x_t = (
486
+ (sigma_t / sigma_s * torch.exp(-h)) * sample
487
+ + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
488
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
489
+ )
490
+ elif self.config.algorithm_type == "sde-dpmsolver":
491
+ assert noise is not None
492
+ x_t = (
493
+ (alpha_t / alpha_s) * sample
494
+ - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
495
+ + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
496
+ )
497
+ return x_t
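+ # All four branches above are exact first-order exponential-integrator updates with step size
+ # h = lambda_t - lambda_s in log-SNR space (lambda = log(alpha / sigma)); the "sde-" variants
+ # additionally inject Gaussian noise with the variance required to match the reverse SDE.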
498
+
499
+ def multistep_dpm_solver_second_order_update(
500
+ self,
501
+ model_output_list: List[torch.FloatTensor],
502
+ timestep_list: List[int],
503
+ prev_timestep: int,
504
+ sample: torch.FloatTensor,
505
+ noise: Optional[torch.FloatTensor] = None,
506
+ ) -> torch.FloatTensor:
507
+ """
508
+ One step for the second-order multistep DPM-Solver.
509
+
510
+ Args:
511
+ model_output_list (`List[torch.FloatTensor]`):
512
+ direct outputs from learned diffusion model at current and latter timesteps.
513
+ timestep (`int`): current and latter discrete timestep in the diffusion chain.
514
+ prev_timestep (`int`): previous discrete timestep in the diffusion chain.
515
+ sample (`torch.FloatTensor`):
516
+ current instance of sample being created by diffusion process.
517
+
518
+ Returns:
519
+ `torch.FloatTensor`: the sample tensor at the previous timestep.
520
+ """
521
+ t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
522
+ m0, m1 = model_output_list[-1], model_output_list[-2]
523
+ lambda_t, lambda_s0, lambda_s1 = (
524
+ self.lambda_t[t],
525
+ self.lambda_t[s0],
526
+ self.lambda_t[s1],
527
+ )
528
+ alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
529
+ sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
530
+ h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
531
+ r0 = h_0 / h
532
+ D0, D1 = m0, (1.0 / r0) * (m0 - m1)
533
+ if self.config.algorithm_type == "dpmsolver++":
534
+ # See https://arxiv.org/abs/2211.01095 for detailed derivations
535
+ if self.config.solver_type == "midpoint":
536
+ x_t = (
537
+ (sigma_t / sigma_s0) * sample
538
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
539
+ - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
540
+ )
541
+ elif self.config.solver_type == "heun":
542
+ x_t = (
543
+ (sigma_t / sigma_s0) * sample
544
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
545
+ + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
546
+ )
547
+ elif self.config.algorithm_type == "dpmsolver":
548
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
549
+ if self.config.solver_type == "midpoint":
550
+ x_t = (
551
+ (alpha_t / alpha_s0) * sample
552
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
553
+ - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1
554
+ )
555
+ elif self.config.solver_type == "heun":
556
+ x_t = (
557
+ (alpha_t / alpha_s0) * sample
558
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
559
+ - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
560
+ )
561
+ elif self.config.algorithm_type == "sde-dpmsolver++":
562
+ assert noise is not None
563
+ if self.config.solver_type == "midpoint":
564
+ x_t = (
565
+ (sigma_t / sigma_s0 * torch.exp(-h)) * sample
566
+ + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
567
+ + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
568
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
569
+ )
570
+ elif self.config.solver_type == "heun":
571
+ x_t = (
572
+ (sigma_t / sigma_s0 * torch.exp(-h)) * sample
573
+ + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
574
+ + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
575
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
576
+ )
577
+ elif self.config.algorithm_type == "sde-dpmsolver":
578
+ assert noise is not None
579
+ if self.config.solver_type == "midpoint":
580
+ x_t = (
581
+ (alpha_t / alpha_s0) * sample
582
+ - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
583
+ - (sigma_t * (torch.exp(h) - 1.0)) * D1
584
+ + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
585
+ )
586
+ elif self.config.solver_type == "heun":
587
+ x_t = (
588
+ (alpha_t / alpha_s0) * sample
589
+ - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
590
+ - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
591
+ + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
592
+ )
593
+ return x_t
594
+
595
+ def multistep_dpm_solver_third_order_update(
596
+ self,
597
+ model_output_list: List[torch.FloatTensor],
598
+ timestep_list: List[int],
599
+ prev_timestep: int,
600
+ sample: torch.FloatTensor,
601
+ ) -> torch.FloatTensor:
602
+ """
603
+ One step for the third-order multistep DPM-Solver.
604
+
605
+ Args:
606
+ model_output_list (`List[torch.FloatTensor]`):
607
+ direct outputs from learned diffusion model at current and latter timesteps.
608
+ timestep (`int`): current and latter discrete timestep in the diffusion chain.
609
+ prev_timestep (`int`): previous discrete timestep in the diffusion chain.
610
+ sample (`torch.FloatTensor`):
611
+ current instance of sample being created by diffusion process.
612
+
613
+ Returns:
614
+ `torch.FloatTensor`: the sample tensor at the previous timestep.
615
+ """
616
+ t, s0, s1, s2 = (
617
+ prev_timestep,
618
+ timestep_list[-1],
619
+ timestep_list[-2],
620
+ timestep_list[-3],
621
+ )
622
+ m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
623
+ lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
624
+ self.lambda_t[t],
625
+ self.lambda_t[s0],
626
+ self.lambda_t[s1],
627
+ self.lambda_t[s2],
628
+ )
629
+ alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
630
+ sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
631
+ h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
632
+ r0, r1 = h_0 / h, h_1 / h
633
+ D0 = m0
634
+ D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
635
+ D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
636
+ D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
637
+ if self.config.algorithm_type == "dpmsolver++":
638
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
639
+ x_t = (
640
+ (sigma_t / sigma_s0) * sample
641
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
642
+ + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
643
+ - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
644
+ )
645
+ elif self.config.algorithm_type == "dpmsolver":
646
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
647
+ x_t = (
648
+ (alpha_t / alpha_s0) * sample
649
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
650
+ - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
651
+ - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
652
+ )
653
+ return x_t
654
+
655
+ def step(
656
+ self,
657
+ model_output: torch.FloatTensor,
658
+ timestep: int,
659
+ sample: torch.FloatTensor,
660
+ generator=None,
661
+ return_dict: bool = True,
662
+ w_ind_noise: float = 0.5,
663
+ ) -> Union[SchedulerOutput, Tuple]:
664
+ """
665
+ Step function propagating the sample with the multistep DPM-Solver.
666
+
667
+ Args:
668
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
669
+ timestep (`int`): current discrete timestep in the diffusion chain.
670
+ sample (`torch.FloatTensor`):
671
+ current instance of sample being created by diffusion process.
672
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
673
+
674
+ Returns:
675
+ [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
676
+ True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
677
+
678
+ """
679
+ if self.num_inference_steps is None:
680
+ raise ValueError(
681
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
682
+ )
683
+
684
+ if isinstance(timestep, torch.Tensor):
685
+ timestep = timestep.to(self.timesteps.device)
686
+ step_index = (self.timesteps == timestep).nonzero()
687
+ if len(step_index) == 0:
688
+ step_index = len(self.timesteps) - 1
689
+ else:
690
+ step_index = step_index.item()
691
+ prev_timestep = (
692
+ 0
693
+ if step_index == len(self.timesteps) - 1
694
+ else self.timesteps[step_index + 1]
695
+ )
696
+ lower_order_final = (
697
+ (step_index == len(self.timesteps) - 1)
698
+ and self.config.lower_order_final
699
+ and len(self.timesteps) < 15
700
+ )
701
+ lower_order_second = (
702
+ (step_index == len(self.timesteps) - 2)
703
+ and self.config.lower_order_final
704
+ and len(self.timesteps) < 15
705
+ )
706
+
707
+ model_output = self.convert_model_output(model_output, timestep, sample)
708
+ for i in range(self.config.solver_order - 1):
709
+ self.model_outputs[i] = self.model_outputs[i + 1]
710
+ self.model_outputs[-1] = model_output
711
+
712
+ if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
713
+ # noise = randn_tensor(
714
+ # model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
715
+ # )
716
+ common_noise = torch.randn(
717
+ model_output.shape[:2] + (1,) + model_output.shape[3:],
718
+ generator=generator,
719
+ device=model_output.device,
720
+ dtype=model_output.dtype,
721
+ ) # common noise
722
+ ind_noise = randn_tensor(
723
+ model_output.shape,
724
+ generator=generator,
725
+ device=model_output.device,
726
+ dtype=model_output.dtype,
727
+ )
728
+ s = torch.tensor(
729
+ w_ind_noise, device=model_output.device, dtype=model_output.dtype
730
+ )  # already created on model_output.device
731
+ noise = torch.sqrt(1 - s) * common_noise + torch.sqrt(s) * ind_noise
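+ # `common_noise` has size 1 along dim 2 and is broadcast across that axis (presumably the frame
+ # axis of a b c t h w video latent), while `ind_noise` is drawn independently per position;
+ # `w_ind_noise` sets the share of independent noise, trading per-frame variation for temporal consistency.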
732
+
733
+ else:
734
+ noise = None
735
+
736
+ if (
737
+ self.config.solver_order == 1
738
+ or self.lower_order_nums < 1
739
+ or lower_order_final
740
+ ):
741
+ prev_sample = self.dpm_solver_first_order_update(
742
+ model_output, timestep, prev_timestep, sample, noise=noise
743
+ )
744
+ elif (
745
+ self.config.solver_order == 2
746
+ or self.lower_order_nums < 2
747
+ or lower_order_second
748
+ ):
749
+ timestep_list = [self.timesteps[step_index - 1], timestep]
750
+ prev_sample = self.multistep_dpm_solver_second_order_update(
751
+ self.model_outputs, timestep_list, prev_timestep, sample, noise=noise
752
+ )
753
+ else:
754
+ timestep_list = [
755
+ self.timesteps[step_index - 2],
756
+ self.timesteps[step_index - 1],
757
+ timestep,
758
+ ]
759
+ prev_sample = self.multistep_dpm_solver_third_order_update(
760
+ self.model_outputs, timestep_list, prev_timestep, sample
761
+ )
762
+
763
+ if self.lower_order_nums < self.config.solver_order:
764
+ self.lower_order_nums += 1
765
+
766
+ if not return_dict:
767
+ return (prev_sample,)
768
+
769
+ return SchedulerOutput(prev_sample=prev_sample)
770
+
771
+ def scale_model_input(
772
+ self, sample: torch.FloatTensor, *args, **kwargs
773
+ ) -> torch.FloatTensor:
774
+ """
775
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
776
+ current timestep.
777
+
778
+ Args:
779
+ sample (`torch.FloatTensor`): input sample
780
+
781
+ Returns:
782
+ `torch.FloatTensor`: scaled input sample
783
+ """
784
+ return sample
785
+
786
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
787
+ def add_noise(
788
+ self,
789
+ original_samples: torch.FloatTensor,
790
+ noise: torch.FloatTensor,
791
+ timesteps: torch.IntTensor,
792
+ ) -> torch.FloatTensor:
793
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
794
+ alphas_cumprod = self.alphas_cumprod.to(
795
+ device=original_samples.device, dtype=original_samples.dtype
796
+ )
797
+ timesteps = timesteps.to(original_samples.device)
798
+
799
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
800
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
801
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
802
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
803
+
804
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
805
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
806
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
807
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
808
+
809
+ noisy_samples = (
810
+ sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
811
+ )
812
+ return noisy_samples
813
+
814
+ def __len__(self):
815
+ return self.config.num_train_timesteps
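+
+ if __name__ == "__main__":
+     # Minimal usage sketch (illustrative only): it assumes the class defined above keeps the upstream
+     # name `DPMSolverMultistepScheduler` and that its constructor defaults mirror diffusers'.
+     scheduler = DPMSolverMultistepScheduler()
+     scheduler.set_timesteps(10)
+     latents = torch.randn(1, 4, 8, 32, 32)  # hypothetical b c t h w video latents
+     for t in scheduler.timesteps:
+         model_output = torch.randn_like(latents)  # stand-in for a denoiser prediction
+         latents = scheduler.step(model_output, t, latents, w_ind_noise=0.5).prev_sample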
musev/schedulers/scheduling_euler_ancestral_discrete.py ADDED
@@ -0,0 +1,356 @@
1
+ # Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from dataclasses import dataclass
17
+ from typing import List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+
24
+ from diffusers.utils import BaseOutput, logging
25
+
26
+ try:
27
+ from diffusers.utils import randn_tensor
28
+ except ImportError:
29
+ from diffusers.utils.torch_utils import randn_tensor
30
+ from diffusers.schedulers.scheduling_utils import (
31
+ KarrasDiffusionSchedulers,
32
+ SchedulerMixin,
33
+ )
34
+
35
+ from ..utils.noise_util import video_fusion_noise
36
+
37
+
38
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
+
40
+
41
+ @dataclass
42
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerAncestralDiscrete
43
+ class EulerAncestralDiscreteSchedulerOutput(BaseOutput):
44
+ """
45
+ Output class for the scheduler's step function output.
46
+
47
+ Args:
48
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
49
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
50
+ denoising loop.
51
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
52
+ The predicted denoised sample (x_{0}) based on the model output from the current timestep.
53
+ `pred_original_sample` can be used to preview progress or for guidance.
54
+ """
55
+
56
+ prev_sample: torch.FloatTensor
57
+ pred_original_sample: Optional[torch.FloatTensor] = None
58
+
59
+
60
+ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
61
+ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
62
+ """
63
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
64
+ (1-beta) over time from t = [0,1].
65
+
66
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
67
+ to that part of the diffusion process.
68
+
69
+
70
+ Args:
71
+ num_diffusion_timesteps (`int`): the number of betas to produce.
72
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
73
+ prevent singularities.
74
+
75
+ Returns:
76
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
77
+ """
78
+
79
+ def alpha_bar(time_step):
80
+ return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
81
+
82
+ betas = []
83
+ for i in range(num_diffusion_timesteps):
84
+ t1 = i / num_diffusion_timesteps
85
+ t2 = (i + 1) / num_diffusion_timesteps
86
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
87
+ return torch.tensor(betas, dtype=torch.float32)
88
+
89
+
90
+ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
91
+ """
92
+ Ancestral sampling with Euler method steps. Based on the original k-diffusion implementation by Katherine Crowson:
93
+ https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72
94
+
95
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
96
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
97
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
98
+ [`~SchedulerMixin.from_pretrained`] functions.
99
+
100
+ Args:
101
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
102
+ beta_start (`float`): the starting `beta` value of inference.
103
+ beta_end (`float`): the final `beta` value.
104
+ beta_schedule (`str`):
105
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
106
+ `linear` or `scaled_linear`.
107
+ trained_betas (`np.ndarray`, optional):
108
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
109
+ prediction_type (`str`, default `epsilon`, optional):
110
+ prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
111
+ process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
112
+ https://imagen.research.google/video/paper.pdf)
113
+
114
+ """
115
+
116
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
117
+ order = 1
118
+
119
+ @register_to_config
120
+ def __init__(
121
+ self,
122
+ num_train_timesteps: int = 1000,
123
+ beta_start: float = 0.0001,
124
+ beta_end: float = 0.02,
125
+ beta_schedule: str = "linear",
126
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
127
+ prediction_type: str = "epsilon",
128
+ ):
129
+ if trained_betas is not None:
130
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
131
+ elif beta_schedule == "linear":
132
+ self.betas = torch.linspace(
133
+ beta_start, beta_end, num_train_timesteps, dtype=torch.float32
134
+ )
135
+ elif beta_schedule == "scaled_linear":
136
+ # this schedule is very specific to the latent diffusion model.
137
+ self.betas = (
138
+ torch.linspace(
139
+ beta_start**0.5,
140
+ beta_end**0.5,
141
+ num_train_timesteps,
142
+ dtype=torch.float32,
143
+ )
144
+ ** 2
145
+ )
146
+ elif beta_schedule == "squaredcos_cap_v2":
147
+ # Glide cosine schedule
148
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
149
+ else:
150
+ raise NotImplementedError(
151
+ f"{beta_schedule} does is not implemented for {self.__class__}"
152
+ )
153
+
154
+ self.alphas = 1.0 - self.betas
155
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
156
+
157
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
158
+ sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
159
+ self.sigmas = torch.from_numpy(sigmas)
160
+
161
+ # standard deviation of the initial noise distribution
162
+ self.init_noise_sigma = self.sigmas.max()
163
+
164
+ # setable values
165
+ self.num_inference_steps = None
166
+ timesteps = np.linspace(
167
+ 0, num_train_timesteps - 1, num_train_timesteps, dtype=float
168
+ )[::-1].copy()
169
+ self.timesteps = torch.from_numpy(timesteps)
170
+ self.is_scale_input_called = False
171
+
172
+ def scale_model_input(
173
+ self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
174
+ ) -> torch.FloatTensor:
175
+ """
176
+ Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
177
+
178
+ Args:
179
+ sample (`torch.FloatTensor`): input sample
180
+ timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain
181
+
182
+ Returns:
183
+ `torch.FloatTensor`: scaled input sample
184
+ """
185
+ if isinstance(timestep, torch.Tensor):
186
+ timestep = timestep.to(self.timesteps.device)
187
+ step_index = (self.timesteps == timestep).nonzero().item()
188
+ sigma = self.sigmas[step_index]
189
+ sample = sample / ((sigma**2 + 1) ** 0.5)
190
+ self.is_scale_input_called = True
191
+ return sample
192
+
193
+ def set_timesteps(
194
+ self, num_inference_steps: int, device: Union[str, torch.device] = None
195
+ ):
196
+ """
197
+ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
198
+
199
+ Args:
200
+ num_inference_steps (`int`):
201
+ the number of diffusion steps used when generating samples with a pre-trained model.
202
+ device (`str` or `torch.device`, optional):
203
+ the device to which the timesteps should be moved. If `None`, the timesteps are not moved.
204
+ """
205
+ self.num_inference_steps = num_inference_steps
206
+
207
+ timesteps = np.linspace(
208
+ 0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float
209
+ )[::-1].copy()
210
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
211
+ sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
212
+ sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
213
+ self.sigmas = torch.from_numpy(sigmas).to(device=device)
214
+ if str(device).startswith("mps"):
215
+ # mps does not support float64
216
+ self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
217
+ else:
218
+ self.timesteps = torch.from_numpy(timesteps).to(device=device)
219
+
220
+ def step(
221
+ self,
222
+ model_output: torch.FloatTensor,
223
+ timestep: Union[float, torch.FloatTensor],
224
+ sample: torch.FloatTensor,
225
+ generator: Optional[torch.Generator] = None,
226
+ return_dict: bool = True,
227
+ w_ind_noise: float = 0.5,
228
+ noise_type: str = "random",
229
+ ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
230
+ """
231
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
232
+ process from the learned model outputs (most often the predicted noise).
233
+
234
+ Args:
235
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
236
+ timestep (`float`): current timestep in the diffusion chain.
237
+ sample (`torch.FloatTensor`):
238
+ current instance of sample being created by diffusion process.
239
+ generator (`torch.Generator`, optional): Random number generator.
240
+ return_dict (`bool`): option for returning tuple rather than EulerAncestralDiscreteSchedulerOutput class
241
+
242
+ Returns:
243
+ [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
244
+ [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] if `return_dict` is True, otherwise
245
+ a `tuple`. When returning a tuple, the first element is the sample tensor.
246
+
247
+ """
248
+
249
+ if (
250
+ isinstance(timestep, int)
251
+ or isinstance(timestep, torch.IntTensor)
252
+ or isinstance(timestep, torch.LongTensor)
253
+ ):
254
+ raise ValueError(
255
+ (
256
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
257
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
258
+ " one of the `scheduler.timesteps` as a timestep."
259
+ ),
260
+ )
261
+
262
+ if not self.is_scale_input_called:
263
+ logger.warning(
264
+ "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
265
+ "See `StableDiffusionPipeline` for a usage example."
266
+ )
267
+
268
+ if isinstance(timestep, torch.Tensor):
269
+ timestep = timestep.to(self.timesteps.device)
270
+
271
+ step_index = (self.timesteps == timestep).nonzero().item()
272
+ sigma = self.sigmas[step_index]
273
+
274
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
275
+ if self.config.prediction_type == "epsilon":
276
+ pred_original_sample = sample - sigma * model_output
277
+ elif self.config.prediction_type == "v_prediction":
278
+ # * c_out + input * c_skip
279
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (
280
+ sample / (sigma**2 + 1)
281
+ )
282
+ elif self.config.prediction_type == "sample":
283
+ raise NotImplementedError("prediction_type not implemented yet: sample")
284
+ else:
285
+ raise ValueError(
286
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
287
+ )
288
+
289
+ sigma_from = self.sigmas[step_index]
290
+ sigma_to = self.sigmas[step_index + 1]
291
+ sigma_up = (
292
+ sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2
293
+ ) ** 0.5
294
+ sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
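+ # Ancestral decomposition: the deterministic part of the step moves from sigma to sigma_down, and
+ # fresh noise of scale sigma_up is added afterwards, so sigma_down**2 + sigma_up**2 == sigma_to**2
+ # and the marginal variance at the next timestep is preserved.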
295
+
296
+ # 2. Convert to an ODE derivative
297
+ derivative = (sample - pred_original_sample) / sigma
298
+
299
+ dt = sigma_down - sigma
300
+
301
+ prev_sample = sample + derivative * dt
302
+
303
+ device = model_output.device
304
+ if noise_type == "random":
305
+ noise = randn_tensor(
306
+ model_output.shape,
307
+ dtype=model_output.dtype,
308
+ device=device,
309
+ generator=generator,
310
+ )
311
+ elif noise_type == "video_fusion":
312
+ noise = video_fusion_noise(
313
+ model_output, w_ind_noise=w_ind_noise, generator=generator
314
+ )
315
+
316
+ prev_sample = prev_sample + noise * sigma_up
317
+
318
+ if not return_dict:
319
+ return (prev_sample,)
320
+
321
+ return EulerAncestralDiscreteSchedulerOutput(
322
+ prev_sample=prev_sample, pred_original_sample=pred_original_sample
323
+ )
324
+
325
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
326
+ def add_noise(
327
+ self,
328
+ original_samples: torch.FloatTensor,
329
+ noise: torch.FloatTensor,
330
+ timesteps: torch.FloatTensor,
331
+ ) -> torch.FloatTensor:
332
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
333
+ sigmas = self.sigmas.to(
334
+ device=original_samples.device, dtype=original_samples.dtype
335
+ )
336
+ if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
337
+ # mps does not support float64
338
+ schedule_timesteps = self.timesteps.to(
339
+ original_samples.device, dtype=torch.float32
340
+ )
341
+ timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
342
+ else:
343
+ schedule_timesteps = self.timesteps.to(original_samples.device)
344
+ timesteps = timesteps.to(original_samples.device)
345
+
346
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
347
+
348
+ sigma = sigmas[step_indices].flatten()
349
+ while len(sigma.shape) < len(original_samples.shape):
350
+ sigma = sigma.unsqueeze(-1)
351
+
352
+ noisy_samples = original_samples + noise * sigma
353
+ return noisy_samples
354
+
355
+ def __len__(self):
356
+ return self.config.num_train_timesteps
musev/schedulers/scheduling_euler_discrete.py ADDED
@@ -0,0 +1,293 @@
1
+ from __future__ import annotations
2
+ import logging
3
+
4
+ from typing import List, Optional, Tuple, Union
5
+ import numpy as np
6
+ from numpy import ndarray
7
+ import torch
8
+ from torch import Generator, FloatTensor
9
+ from diffusers.schedulers.scheduling_euler_discrete import (
10
+ EulerDiscreteScheduler as DiffusersEulerDiscreteScheduler,
11
+ EulerDiscreteSchedulerOutput,
12
+ )
13
+ from diffusers.utils.torch_utils import randn_tensor
14
+
15
+ from ..utils.noise_util import video_fusion_noise
16
+
17
+ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
18
+
19
+
20
+ class EulerDiscreteScheduler(DiffusersEulerDiscreteScheduler):
21
+ def __init__(
22
+ self,
23
+ num_train_timesteps: int = 1000,
24
+ beta_start: float = 0.0001,
25
+ beta_end: float = 0.02,
26
+ beta_schedule: str = "linear",
27
+ trained_betas: ndarray | List[float] | None = None,
28
+ prediction_type: str = "epsilon",
29
+ interpolation_type: str = "linear",
30
+ use_karras_sigmas: bool | None = False,
31
+ timestep_spacing: str = "linspace",
32
+ steps_offset: int = 0,
33
+ ):
34
+ super().__init__(
35
+ num_train_timesteps,
36
+ beta_start,
37
+ beta_end,
38
+ beta_schedule,
39
+ trained_betas,
40
+ prediction_type,
41
+ interpolation_type,
42
+ use_karras_sigmas,
43
+ timestep_spacing,
44
+ steps_offset,
45
+ )
46
+
47
+ def step(
48
+ self,
49
+ model_output: torch.FloatTensor,
50
+ timestep: Union[float, torch.FloatTensor],
51
+ sample: torch.FloatTensor,
52
+ s_churn: float = 0.0,
53
+ s_tmin: float = 0.0,
54
+ s_tmax: float = float("inf"),
55
+ s_noise: float = 1.0,
56
+ generator: Optional[torch.Generator] = None,
57
+ return_dict: bool = True,
58
+ w_ind_noise: float = 0.5,
59
+ noise_type: str = "random",
60
+ ) -> Union[EulerDiscreteSchedulerOutput, Tuple]:
61
+ """
62
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
63
+ process from the learned model outputs (most often the predicted noise).
64
+
65
+ Args:
66
+ model_output (`torch.FloatTensor`):
67
+ The direct output from learned diffusion model.
68
+ timestep (`float`):
69
+ The current discrete timestep in the diffusion chain.
70
+ sample (`torch.FloatTensor`):
71
+ A current instance of a sample created by the diffusion process.
72
+ s_churn (`float`):
+ Amount of stochastic "churn" to apply; if greater than 0, the sample is re-noised to a slightly higher sigma before the Euler step.
+ s_tmin (`float`):
+ Lower bound of the sigma range in which churn is applied.
+ s_tmax (`float`):
+ Upper bound of the sigma range in which churn is applied.
75
+ s_noise (`float`, defaults to 1.0):
76
+ Scaling factor for noise added to the sample.
77
+ generator (`torch.Generator`, *optional*):
78
+ A random number generator.
79
+ return_dict (`bool`):
80
+ Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
81
+ tuple.
82
+
83
+ Returns:
84
+ [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
85
+ If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
86
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
87
+ """
88
+
89
+ if (
90
+ isinstance(timestep, int)
91
+ or isinstance(timestep, torch.IntTensor)
92
+ or isinstance(timestep, torch.LongTensor)
93
+ ):
94
+ raise ValueError(
95
+ (
96
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
97
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
98
+ " one of the `scheduler.timesteps` as a timestep."
99
+ ),
100
+ )
101
+
102
+ if not self.is_scale_input_called:
103
+ logger.warning(
104
+ "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
105
+ "See `StableDiffusionPipeline` for a usage example."
106
+ )
107
+
108
+ if self.step_index is None:
109
+ self._init_step_index(timestep)
110
+
111
+ sigma = self.sigmas[self.step_index]
112
+
113
+ gamma = (
114
+ min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1)
115
+ if s_tmin <= sigma <= s_tmax
116
+ else 0.0
117
+ )
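+ # Stochastic "churn" (Karras et al., 2022): when s_churn > 0 and sigma lies in [s_tmin, s_tmax],
+ # the sample is first re-noised up to sigma_hat = sigma * (gamma + 1) before the Euler step;
+ # gamma is capped at sqrt(2) - 1 so that sigma_hat never exceeds sqrt(2) * sigma.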
118
+ device = model_output.device
119
+
120
+ if noise_type == "random":
121
+ noise = randn_tensor(
122
+ model_output.shape,
123
+ dtype=model_output.dtype,
124
+ device=device,
125
+ generator=generator,
126
+ )
127
+ elif noise_type == "video_fusion":
128
+ noise = video_fusion_noise(
129
+ model_output, w_ind_noise=w_ind_noise, generator=generator
130
+ )
131
+
132
+ eps = noise * s_noise
133
+ sigma_hat = sigma * (gamma + 1)
134
+
135
+ if gamma > 0:
136
+ sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
137
+
138
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
139
+ # NOTE: "original_sample" should not be an expected prediction_type but is left in for
140
+ # backwards compatibility
141
+ if (
142
+ self.config.prediction_type == "original_sample"
143
+ or self.config.prediction_type == "sample"
144
+ ):
145
+ pred_original_sample = model_output
146
+ elif self.config.prediction_type == "epsilon":
147
+ pred_original_sample = sample - sigma_hat * model_output
148
+ elif self.config.prediction_type == "v_prediction":
149
+ # * c_out + input * c_skip
150
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (
151
+ sample / (sigma**2 + 1)
152
+ )
153
+ else:
154
+ raise ValueError(
155
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
156
+ )
157
+
158
+ # 2. Convert to an ODE derivative
159
+ derivative = (sample - pred_original_sample) / sigma_hat
160
+
161
+ dt = self.sigmas[self.step_index + 1] - sigma_hat
162
+
163
+ prev_sample = sample + derivative * dt
164
+
165
+ # upon completion increase step index by one
166
+ self._step_index += 1
167
+
168
+ if not return_dict:
169
+ return (prev_sample,)
170
+
171
+ return EulerDiscreteSchedulerOutput(
172
+ prev_sample=prev_sample, pred_original_sample=pred_original_sample
173
+ )
174
+
175
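+ # NOTE: `step_bk` below appears to be a kept-for-reference copy of the older implementation that
+ # looks up the step index from `timestep` directly instead of using `self.step_index`.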
+ def step_bk(
176
+ self,
177
+ model_output: FloatTensor,
178
+ timestep: float | FloatTensor,
179
+ sample: FloatTensor,
180
+ s_churn: float = 0,
181
+ s_tmin: float = 0,
182
+ s_tmax: float = float("inf"),
183
+ s_noise: float = 1,
184
+ generator: Generator | None = None,
185
+ return_dict: bool = True,
186
+ w_ind_noise: float = 0.5,
187
+ noise_type: str = "random",
188
+ ) -> EulerDiscreteSchedulerOutput | Tuple:
189
+ """
190
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
191
+ process from the learned model outputs (most often the predicted noise).
192
+
193
+ Args:
194
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
195
+ timestep (`float`): current timestep in the diffusion chain.
196
+ sample (`torch.FloatTensor`):
197
+ current instance of sample being created by diffusion process.
198
+ s_churn (`float`)
199
+ s_tmin (`float`)
200
+ s_tmax (`float`)
201
+ s_noise (`float`)
202
+ generator (`torch.Generator`, optional): Random number generator.
203
+ return_dict (`bool`): option for returning tuple rather than EulerDiscreteSchedulerOutput class
204
+
205
+ Returns:
206
+ [`~schedulers.scheduling_utils.EulerDiscreteSchedulerOutput`] or `tuple`:
207
+ [`~schedulers.scheduling_utils.EulerDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a
208
+ `tuple`. When returning a tuple, the first element is the sample tensor.
209
+
210
+ """
211
+
212
+ if (
213
+ isinstance(timestep, int)
214
+ or isinstance(timestep, torch.IntTensor)
215
+ or isinstance(timestep, torch.LongTensor)
216
+ ):
217
+ raise ValueError(
218
+ (
219
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
220
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
221
+ " one of the `scheduler.timesteps` as a timestep."
222
+ ),
223
+ )
224
+
225
+ if not self.is_scale_input_called:
226
+ logger.warning(
227
+ "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
228
+ "See `StableDiffusionPipeline` for a usage example."
229
+ )
230
+
231
+ if isinstance(timestep, torch.Tensor):
232
+ timestep = timestep.to(self.timesteps.device)
233
+
234
+ step_index = (self.timesteps == timestep).nonzero().item()
235
+ sigma = self.sigmas[step_index]
236
+
237
+ gamma = (
238
+ min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1)
239
+ if s_tmin <= sigma <= s_tmax
240
+ else 0.0
241
+ )
242
+
243
+ device = model_output.device
244
+ if noise_type == "random":
245
+ noise = randn_tensor(
246
+ model_output.shape,
247
+ dtype=model_output.dtype,
248
+ device=device,
249
+ generator=generator,
250
+ )
251
+ elif noise_type == "video_fusion":
252
+ noise = video_fusion_noise(
253
+ model_output, w_ind_noise=w_ind_noise, generator=generator
254
+ )
255
+ eps = noise * s_noise
256
+ sigma_hat = sigma * (gamma + 1)
257
+
258
+ if gamma > 0:
259
+ sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
260
+
261
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
262
+ # NOTE: "original_sample" should not be an expected prediction_type but is left in for
263
+ # backwards compatibility
264
+ if (
265
+ self.config.prediction_type == "original_sample"
266
+ or self.config.prediction_type == "sample"
267
+ ):
268
+ pred_original_sample = model_output
269
+ elif self.config.prediction_type == "epsilon":
270
+ pred_original_sample = sample - sigma_hat * model_output
271
+ elif self.config.prediction_type == "v_prediction":
272
+ # * c_out + input * c_skip
273
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (
274
+ sample / (sigma**2 + 1)
275
+ )
276
+ else:
277
+ raise ValueError(
278
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
279
+ )
280
+
281
+ # 2. Convert to an ODE derivative
282
+ derivative = (sample - pred_original_sample) / sigma_hat
283
+
284
+ dt = self.sigmas[step_index + 1] - sigma_hat
285
+
286
+ prev_sample = sample + derivative * dt
287
+
288
+ if not return_dict:
289
+ return (prev_sample,)
290
+
291
+ return EulerDiscreteSchedulerOutput(
292
+ prev_sample=prev_sample, pred_original_sample=pred_original_sample
293
+ )
musev/schedulers/scheduling_lcm.py ADDED
@@ -0,0 +1,312 @@
1
+ # Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
16
+ # and https://github.com/hojonathanho/diffusion
17
+ from __future__ import annotations
18
+
19
+ import math
20
+ from dataclasses import dataclass
21
+ from typing import List, Optional, Tuple, Union
22
+
23
+ import numpy as np
24
+ import torch
25
+ from numpy import ndarray
26
+
27
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
28
+ from diffusers.utils import BaseOutput, logging
29
+ from diffusers.utils.torch_utils import randn_tensor
30
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
31
+ from diffusers.schedulers.scheduling_lcm import (
32
+ LCMSchedulerOutput,
33
+ betas_for_alpha_bar,
34
+ rescale_zero_terminal_snr,
35
+ LCMScheduler as DiffusersLCMScheduler,
36
+ )
37
+ from ..utils.noise_util import video_fusion_noise
38
+
39
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
40
+
41
+
42
+ class LCMScheduler(DiffusersLCMScheduler):
43
+ def __init__(
44
+ self,
45
+ num_train_timesteps: int = 1000,
46
+ beta_start: float = 0.00085,
47
+ beta_end: float = 0.012,
48
+ beta_schedule: str = "scaled_linear",
49
+ trained_betas: ndarray | List[float] | None = None,
50
+ original_inference_steps: int = 50,
51
+ clip_sample: bool = False,
52
+ clip_sample_range: float = 1,
53
+ set_alpha_to_one: bool = True,
54
+ steps_offset: int = 0,
55
+ prediction_type: str = "epsilon",
56
+ thresholding: bool = False,
57
+ dynamic_thresholding_ratio: float = 0.995,
58
+ sample_max_value: float = 1,
59
+ timestep_spacing: str = "leading",
60
+ timestep_scaling: float = 10,
61
+ rescale_betas_zero_snr: bool = False,
62
+ ):
63
+ super().__init__(
64
+ num_train_timesteps,
65
+ beta_start,
66
+ beta_end,
67
+ beta_schedule,
68
+ trained_betas,
69
+ original_inference_steps,
70
+ clip_sample,
71
+ clip_sample_range,
72
+ set_alpha_to_one,
73
+ steps_offset,
74
+ prediction_type,
75
+ thresholding,
76
+ dynamic_thresholding_ratio,
77
+ sample_max_value,
78
+ timestep_spacing,
79
+ timestep_scaling,
80
+ rescale_betas_zero_snr,
81
+ )
82
+
83
+ def step(
84
+ self,
85
+ model_output: torch.FloatTensor,
86
+ timestep: int,
87
+ sample: torch.FloatTensor,
88
+ generator: Optional[torch.Generator] = None,
89
+ return_dict: bool = True,
90
+ w_ind_noise: float = 0.5,
91
+ noise_type: str = "random",
92
+ ) -> Union[LCMSchedulerOutput, Tuple]:
93
+ """
94
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
95
+ process from the learned model outputs (most often the predicted noise).
96
+
97
+ Args:
98
+ model_output (`torch.FloatTensor`):
99
+ The direct output from learned diffusion model.
100
+ timestep (`float`):
101
+ The current discrete timestep in the diffusion chain.
102
+ sample (`torch.FloatTensor`):
103
+ A current instance of a sample created by the diffusion process.
104
+ generator (`torch.Generator`, *optional*):
105
+ A random number generator.
106
+ return_dict (`bool`, *optional*, defaults to `True`):
107
+ Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
108
+ Returns:
109
+ [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
110
+ If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
111
+ tuple is returned where the first element is the sample tensor.
112
+ """
113
+ if self.num_inference_steps is None:
114
+ raise ValueError(
115
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
116
+ )
117
+
118
+ if self.step_index is None:
119
+ self._init_step_index(timestep)
120
+
121
+ # 1. get previous step value
122
+ prev_step_index = self.step_index + 1
123
+ if prev_step_index < len(self.timesteps):
124
+ prev_timestep = self.timesteps[prev_step_index]
125
+ else:
126
+ prev_timestep = timestep
127
+
128
+ # 2. compute alphas, betas
129
+ alpha_prod_t = self.alphas_cumprod[timestep]
130
+ alpha_prod_t_prev = (
131
+ self.alphas_cumprod[prev_timestep]
132
+ if prev_timestep >= 0
133
+ else self.final_alpha_cumprod
134
+ )
135
+
136
+ beta_prod_t = 1 - alpha_prod_t
137
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
138
+
139
+ # 3. Get scalings for boundary conditions
140
+ c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)
141
+
142
+ # 4. Compute the predicted original sample x_0 based on the model parameterization
143
+ if self.config.prediction_type == "epsilon": # noise-prediction
144
+ predicted_original_sample = (
145
+ sample - beta_prod_t.sqrt() * model_output
146
+ ) / alpha_prod_t.sqrt()
147
+ elif self.config.prediction_type == "sample": # x-prediction
148
+ predicted_original_sample = model_output
149
+ elif self.config.prediction_type == "v_prediction": # v-prediction
150
+ predicted_original_sample = (
151
+ alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
152
+ )
153
+ else:
154
+ raise ValueError(
155
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
156
+ " `v_prediction` for `LCMScheduler`."
157
+ )
158
+
159
+ # 5. Clip or threshold "predicted x_0"
160
+ if self.config.thresholding:
161
+ predicted_original_sample = self._threshold_sample(
162
+ predicted_original_sample
163
+ )
164
+ elif self.config.clip_sample:
165
+ predicted_original_sample = predicted_original_sample.clamp(
166
+ -self.config.clip_sample_range, self.config.clip_sample_range
167
+ )
168
+
169
+ # 6. Denoise model output using boundary conditions
170
+ denoised = c_out * predicted_original_sample + c_skip * sample
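+ # The boundary-condition scalings satisfy c_skip -> 1 and c_out -> 0 as the timestep approaches 0,
+ # so the consistency function reduces to the identity at t = 0, as required for consistency models.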
171
+
172
+ # 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference
173
+ # Noise is not used on the final timestep of the timestep schedule.
174
+ # This also means that noise is not used for one-step sampling.
175
+ device = model_output.device
176
+
177
+ if self.step_index != self.num_inference_steps - 1:
178
+ if noise_type == "random":
179
+ noise = randn_tensor(
180
+ model_output.shape,
181
+ dtype=model_output.dtype,
182
+ device=device,
183
+ generator=generator,
184
+ )
185
+ elif noise_type == "video_fusion":
186
+ noise = video_fusion_noise(
187
+ model_output, w_ind_noise=w_ind_noise, generator=generator
188
+ )
189
+ prev_sample = (
190
+ alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
191
+ )
192
+ else:
193
+ prev_sample = denoised
194
+
195
+ # upon completion increase step index by one
196
+ self._step_index += 1
197
+
198
+ if not return_dict:
199
+ return (prev_sample, denoised)
200
+
201
+ return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)
202
+
203
+ def step_bk(
204
+ self,
205
+ model_output: torch.FloatTensor,
206
+ timestep: int,
207
+ sample: torch.FloatTensor,
208
+ generator: Optional[torch.Generator] = None,
209
+ return_dict: bool = True,
210
+ ) -> Union[LCMSchedulerOutput, Tuple]:
211
+ """
212
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
213
+ process from the learned model outputs (most often the predicted noise).
214
+
215
+ Args:
216
+ model_output (`torch.FloatTensor`):
217
+ The direct output from learned diffusion model.
218
+ timestep (`float`):
219
+ The current discrete timestep in the diffusion chain.
220
+ sample (`torch.FloatTensor`):
221
+ A current instance of a sample created by the diffusion process.
222
+ generator (`torch.Generator`, *optional*):
223
+ A random number generator.
224
+ return_dict (`bool`, *optional*, defaults to `True`):
225
+ Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
226
+ Returns:
227
+ [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
228
+ If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
229
+ tuple is returned where the first element is the sample tensor.
230
+ """
231
+ if self.num_inference_steps is None:
232
+ raise ValueError(
233
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
234
+ )
235
+
236
+ if self.step_index is None:
237
+ self._init_step_index(timestep)
238
+
239
+ # 1. get previous step value
240
+ prev_step_index = self.step_index + 1
241
+ if prev_step_index < len(self.timesteps):
242
+ prev_timestep = self.timesteps[prev_step_index]
243
+ else:
244
+ prev_timestep = timestep
245
+
246
+ # 2. compute alphas, betas
247
+ alpha_prod_t = self.alphas_cumprod[timestep]
248
+ alpha_prod_t_prev = (
249
+ self.alphas_cumprod[prev_timestep]
250
+ if prev_timestep >= 0
251
+ else self.final_alpha_cumprod
252
+ )
253
+
254
+ beta_prod_t = 1 - alpha_prod_t
255
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
256
+
257
+ # 3. Get scalings for boundary conditions
258
+ c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)
259
+
260
+ # 4. Compute the predicted original sample x_0 based on the model parameterization
261
+ if self.config.prediction_type == "epsilon": # noise-prediction
262
+ predicted_original_sample = (
263
+ sample - beta_prod_t.sqrt() * model_output
264
+ ) / alpha_prod_t.sqrt()
265
+ elif self.config.prediction_type == "sample": # x-prediction
266
+ predicted_original_sample = model_output
267
+ elif self.config.prediction_type == "v_prediction": # v-prediction
268
+ predicted_original_sample = (
269
+ alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
270
+ )
271
+ else:
272
+ raise ValueError(
273
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
274
+ " `v_prediction` for `LCMScheduler`."
275
+ )
276
+
277
+ # 5. Clip or threshold "predicted x_0"
278
+ if self.config.thresholding:
279
+ predicted_original_sample = self._threshold_sample(
280
+ predicted_original_sample
281
+ )
282
+ elif self.config.clip_sample:
283
+ predicted_original_sample = predicted_original_sample.clamp(
284
+ -self.config.clip_sample_range, self.config.clip_sample_range
285
+ )
286
+
287
+ # 6. Denoise model output using boundary conditions
288
+ denoised = c_out * predicted_original_sample + c_skip * sample
289
+
290
+ # 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference
291
+ # Noise is not used on the final timestep of the timestep schedule.
292
+ # This also means that noise is not used for one-step sampling.
293
+ if self.step_index != self.num_inference_steps - 1:
294
+ noise = randn_tensor(
295
+ model_output.shape,
296
+ generator=generator,
297
+ device=model_output.device,
298
+ dtype=denoised.dtype,
299
+ )
300
+ prev_sample = (
301
+ alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
302
+ )
303
+ else:
304
+ prev_sample = denoised
305
+
306
+ # upon completion increase step index by one
307
+ self._step_index += 1
308
+
309
+ if not return_dict:
310
+ return (prev_sample, denoised)
311
+
312
+ return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)
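For orientation, here is a minimal, self-contained sketch of how a multistep LCM-style `step()` like the one above is normally driven. It uses only the upstream diffusers `LCMScheduler` API (`set_timesteps`, `step`) with random tensors standing in for the latents and the UNet prediction; the modified `step` in this file additionally accepts `noise_type` and `w_ind_noise` for the video-fusion noise shown above.

```python
# Sketch only: upstream diffusers LCMScheduler, random stand-in tensors.
import torch
from diffusers import LCMScheduler

scheduler = LCMScheduler(num_train_timesteps=1000, prediction_type="epsilon")
scheduler.set_timesteps(num_inference_steps=4)

sample = torch.randn(1, 4, 8, 8)             # stand-in for latents
for t in scheduler.timesteps:
    model_output = torch.randn_like(sample)  # stand-in for the UNet prediction
    out = scheduler.step(model_output, t, sample)
    sample = out.prev_sample                 # re-noised input for the next step
    denoised = out.denoised                  # current x0 estimate
```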
musev/utils/__init__.py ADDED
File without changes
musev/utils/attention_util.py ADDED
@@ -0,0 +1,74 @@
1
+ from typing import Tuple, Union, Literal
2
+
3
+ from einops import repeat
4
+ import torch
5
+ import numpy as np
6
+
7
+
8
+ def get_diags_indices(
9
+ shape: Union[int, Tuple[int, int]], k_min: int = 0, k_max: int = 0
10
+ ):
11
+ if isinstance(shape, int):
12
+ shape = (shape, shape)
13
+ rows, cols = np.indices(shape)
14
+ diag = cols - rows
15
+ return np.where((diag >= k_min) & (diag <= k_max))
16
+
17
+
18
+ def generate_mask_from_indices(
19
+ shape: Tuple[int, int],
20
+ indices: Tuple[np.ndarray, np.ndarray],
21
+ big_value: float = 0,
22
+ small_value: float = -1e9,
23
+ ):
24
+ matrix = np.ones(shape) * small_value
25
+ matrix[indices] = big_value
26
+ return matrix
27
+
28
+
29
+ def generate_sparse_causcal_attn_mask(
30
+ batch_size: int,
31
+ n: int,
32
+ n_near: int = 1,
33
+ big_value: float = 0,
34
+ small_value: float = -1e9,
35
+ out_type: Literal["torch", "numpy"] = "numpy",
36
+ expand: int = 1,
37
+ ) -> np.ndarray:
38
+ """generate b (n expand) (n expand) mask,
39
+ where entries on the diagonals at offsets -n_near..0 and in the first column of the (n, n) base matrix are set to big_value; all other entries are set to small_value.
40
+ About the expand argument:
41
+ when attn has shape (b, n, d) the mask has shape (b, n, n); when attn has shape (b, expand * n, d) the mask has shape (b, n * expand, n * expand).
42
+ Args:
43
+ batch_size (int): batch dimension of the returned mask.
44
+ n (int): number of positions (e.g. frames) along each attention axis.
45
+ n_near (int, optional): number of nearest previous positions kept visible. Defaults to 1.
46
+ big_value (float, optional): value assigned to visible positions. Defaults to 0.
47
+ small_value (float, optional): value assigned to masked positions. Defaults to -1e9.
48
+ out_type (Literal["torch", "numpy"], optional): return type of the mask. Defaults to "numpy".
49
+ expand (int, optional): repetition factor applied to both mask dimensions. Defaults to 1.
50
+
51
+ Returns:
52
+ np.ndarray: mask of shape (batch_size, n * expand, n * expand).
53
+ """
54
+ shape = (n, n)
55
+ diag_indices = get_diags_indices(n, k_min=-n_near, k_max=0)
56
+ first_column = (np.arange(n), np.zeros(n).astype(np.int64))  # np.int was removed in NumPy 1.24+
57
+ indices = (
58
+ np.concatenate([diag_indices[0], first_column[0]]),
59
+ np.concatenate([diag_indices[1], first_column[1]]),
60
+ )
61
+ mask = generate_mask_from_indices(
62
+ shape=shape, indices=indices, big_value=big_value, small_value=small_value
63
+ )
64
+ mask = repeat(mask, "m n-> b m n", b=batch_size)
65
+ if expand > 1:
66
+ mask = repeat(
67
+ mask,
68
+ "b m n -> b (m d1) (n d2)",
69
+ d1=expand,
70
+ d2=expand,
71
+ )
72
+ if out_type == "torch":
73
+ mask = torch.from_numpy(mask)
74
+ return mask
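A hedged usage sketch for `generate_sparse_causcal_attn_mask` (illustrative values; assumes the package is importable as `musev.utils.attention_util`):

```python
from musev.utils.attention_util import generate_sparse_causcal_attn_mask

mask = generate_sparse_causcal_attn_mask(
    batch_size=2, n=4, n_near=1, expand=2, out_type="numpy"
)
print(mask.shape)  # (2, 8, 8) == (batch_size, n * expand, n * expand)

# Entries equal to big_value (0 by default) are attendable: each position sees
# itself, its n_near predecessors, and the first position; the rest are -1e9.
print((mask[0] == 0).astype(int))
```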
musev/utils/convert_from_ckpt.py ADDED
@@ -0,0 +1,963 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Conversion script for the Stable Diffusion checkpoints."""
16
+
17
+ import re
18
+ from io import BytesIO
19
+ from typing import Optional
20
+
21
+ import requests
22
+ import torch
23
+ from transformers import (
24
+ AutoFeatureExtractor,
25
+ BertTokenizerFast,
26
+ CLIPImageProcessor,
27
+ CLIPTextModel,
28
+ CLIPTextModelWithProjection,
29
+ CLIPTokenizer,
30
+ CLIPVisionConfig,
31
+ CLIPVisionModelWithProjection,
32
+ )
33
+
34
+ from diffusers.models import (
35
+ AutoencoderKL,
36
+ PriorTransformer,
37
+ UNet2DConditionModel,
38
+ )
39
+ from diffusers.schedulers import (
40
+ DDIMScheduler,
41
+ DDPMScheduler,
42
+ DPMSolverMultistepScheduler,
43
+ EulerAncestralDiscreteScheduler,
44
+ EulerDiscreteScheduler,
45
+ HeunDiscreteScheduler,
46
+ LMSDiscreteScheduler,
47
+ PNDMScheduler,
48
+ UnCLIPScheduler,
49
+ )
50
+ from diffusers.utils.import_utils import BACKENDS_MAPPING
+
+ # additional imports required by the LDMBert / PaintByExample / unCLIP / ControlNet helpers below
+ from diffusers.models import ControlNetModel
+ from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
+ from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
+ from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
51
+
52
+
53
+ def shave_segments(path, n_shave_prefix_segments=1):
54
+ """
55
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
56
+ """
57
+ if n_shave_prefix_segments >= 0:
58
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
59
+ else:
60
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
61
+
62
+
63
+ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
64
+ """
65
+ Updates paths inside resnets to the new naming scheme (local renaming)
66
+ """
67
+ mapping = []
68
+ for old_item in old_list:
69
+ new_item = old_item.replace("in_layers.0", "norm1")
70
+ new_item = new_item.replace("in_layers.2", "conv1")
71
+
72
+ new_item = new_item.replace("out_layers.0", "norm2")
73
+ new_item = new_item.replace("out_layers.3", "conv2")
74
+
75
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
76
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
77
+
78
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
79
+
80
+ mapping.append({"old": old_item, "new": new_item})
81
+
82
+ return mapping
83
+
84
+
85
+ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
86
+ """
87
+ Updates paths inside resnets to the new naming scheme (local renaming)
88
+ """
89
+ mapping = []
90
+ for old_item in old_list:
91
+ new_item = old_item
92
+
93
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
94
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
95
+
96
+ mapping.append({"old": old_item, "new": new_item})
97
+
98
+ return mapping
99
+
100
+
101
+ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
102
+ """
103
+ Updates paths inside attentions to the new naming scheme (local renaming)
104
+ """
105
+ mapping = []
106
+ for old_item in old_list:
107
+ new_item = old_item
108
+
109
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
110
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
111
+
112
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
113
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
114
+
115
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
116
+
117
+ mapping.append({"old": old_item, "new": new_item})
118
+
119
+ return mapping
120
+
121
+
122
+ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
123
+ """
124
+ Updates paths inside attentions to the new naming scheme (local renaming)
125
+ """
126
+ mapping = []
127
+ for old_item in old_list:
128
+ new_item = old_item
129
+
130
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
131
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
132
+
133
+ new_item = new_item.replace("q.weight", "to_q.weight")
134
+ new_item = new_item.replace("q.bias", "to_q.bias")
135
+
136
+ new_item = new_item.replace("k.weight", "to_k.weight")
137
+ new_item = new_item.replace("k.bias", "to_k.bias")
138
+
139
+ new_item = new_item.replace("v.weight", "to_v.weight")
140
+ new_item = new_item.replace("v.bias", "to_v.bias")
141
+
142
+ new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
143
+ new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
144
+
145
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
146
+
147
+ mapping.append({"old": old_item, "new": new_item})
148
+
149
+ return mapping
150
+
151
+
152
+ def assign_to_checkpoint(
153
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
154
+ ):
155
+ """
156
+ This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
157
+ attention layers, and takes into account additional replacements that may arise.
158
+
159
+ Assigns the weights to the new checkpoint.
160
+ """
161
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
162
+
163
+ # Splits the attention layers into three variables.
164
+ if attention_paths_to_split is not None:
165
+ for path, path_map in attention_paths_to_split.items():
166
+ old_tensor = old_checkpoint[path]
167
+ channels = old_tensor.shape[0] // 3
168
+
169
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
170
+
171
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
172
+
173
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
174
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
175
+
176
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
177
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
178
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
179
+
180
+ for path in paths:
181
+ new_path = path["new"]
182
+
183
+ # These have already been assigned
184
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
185
+ continue
186
+
187
+ # Global renaming happens here
188
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
189
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
190
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
191
+
192
+ if additional_replacements is not None:
193
+ for replacement in additional_replacements:
194
+ new_path = new_path.replace(replacement["old"], replacement["new"])
195
+
196
+ # proj_attn.weight has to be converted from conv 1D to linear
197
+ is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path)
198
+ shape = old_checkpoint[path["old"]].shape
199
+ if is_attn_weight and len(shape) == 3:
200
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
201
+ elif is_attn_weight and len(shape) == 4:
202
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
203
+ else:
204
+ checkpoint[new_path] = old_checkpoint[path["old"]]
205
+
206
+
207
+ def conv_attn_to_linear(checkpoint):
208
+ keys = list(checkpoint.keys())
209
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
210
+ for key in keys:
211
+ if ".".join(key.split(".")[-2:]) in attn_keys:
212
+ if checkpoint[key].ndim > 2:
213
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
214
+ elif "proj_attn.weight" in key:
215
+ if checkpoint[key].ndim > 2:
216
+ checkpoint[key] = checkpoint[key][:, :, 0]
217
+
218
+
219
+ def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
220
+ """
221
+ Creates a config for the diffusers based on the config of the LDM model.
222
+ """
223
+ if controlnet:
224
+ unet_params = original_config.model.params.control_stage_config.params
225
+ else:
226
+ unet_params = original_config.model.params.unet_config.params
227
+
228
+ vae_params = original_config.model.params.first_stage_config.params.ddconfig
229
+
230
+ block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
231
+
232
+ down_block_types = []
233
+ resolution = 1
234
+ for i in range(len(block_out_channels)):
235
+ block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
236
+ down_block_types.append(block_type)
237
+ if i != len(block_out_channels) - 1:
238
+ resolution *= 2
239
+
240
+ up_block_types = []
241
+ for i in range(len(block_out_channels)):
242
+ block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
243
+ up_block_types.append(block_type)
244
+ resolution //= 2
245
+
246
+ vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
247
+
248
+ head_dim = unet_params.num_heads if "num_heads" in unet_params else None
249
+ use_linear_projection = (
250
+ unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
251
+ )
252
+ if use_linear_projection:
253
+ # stable diffusion 2-base-512 and 2-768
254
+ if head_dim is None:
255
+ head_dim = [5, 10, 20, 20]
256
+
257
+ class_embed_type = None
258
+ projection_class_embeddings_input_dim = None
259
+
260
+ if "num_classes" in unet_params:
261
+ if unet_params.num_classes == "sequential":
262
+ class_embed_type = "projection"
263
+ assert "adm_in_channels" in unet_params
264
+ projection_class_embeddings_input_dim = unet_params.adm_in_channels
265
+ else:
266
+ raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}")
267
+
268
+ config = {
269
+ "sample_size": image_size // vae_scale_factor,
270
+ "in_channels": unet_params.in_channels,
271
+ "down_block_types": tuple(down_block_types),
272
+ "block_out_channels": tuple(block_out_channels),
273
+ "layers_per_block": unet_params.num_res_blocks,
274
+ "cross_attention_dim": unet_params.context_dim,
275
+ "attention_head_dim": head_dim,
276
+ "use_linear_projection": use_linear_projection,
277
+ "class_embed_type": class_embed_type,
278
+ "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
279
+ }
280
+
281
+ if not controlnet:
282
+ config["out_channels"] = unet_params.out_channels
283
+ config["up_block_types"] = tuple(up_block_types)
284
+
285
+ return config
286
+
287
+
288
+ def create_vae_diffusers_config(original_config, image_size: int):
289
+ """
290
+ Creates a config for the diffusers based on the config of the LDM model.
291
+ """
292
+ vae_params = original_config.model.params.first_stage_config.params.ddconfig
293
+ _ = original_config.model.params.first_stage_config.params.embed_dim
294
+
295
+ block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
296
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
297
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
298
+
299
+ config = {
300
+ "sample_size": image_size,
301
+ "in_channels": vae_params.in_channels,
302
+ "out_channels": vae_params.out_ch,
303
+ "down_block_types": tuple(down_block_types),
304
+ "up_block_types": tuple(up_block_types),
305
+ "block_out_channels": tuple(block_out_channels),
306
+ "latent_channels": vae_params.z_channels,
307
+ "layers_per_block": vae_params.num_res_blocks,
308
+ }
309
+ return config
310
+
311
+
312
+ def create_diffusers_schedular(original_config):
313
+ schedular = DDIMScheduler(
314
+ num_train_timesteps=original_config.model.params.timesteps,
315
+ beta_start=original_config.model.params.linear_start,
316
+ beta_end=original_config.model.params.linear_end,
317
+ beta_schedule="scaled_linear",
318
+ )
319
+ return schedular
320
+
321
+
322
+ def create_ldm_bert_config(original_config):
323
+ bert_params = original_config.model.params.cond_stage_config.params
324
+ config = LDMBertConfig(
325
+ d_model=bert_params.n_embed,
326
+ encoder_layers=bert_params.n_layer,
327
+ encoder_ffn_dim=bert_params.n_embed * 4,
328
+ )
329
+ return config
330
+
331
+
332
+ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False):
333
+ """
334
+ Takes a state dict and a config, and returns a converted checkpoint.
335
+ """
336
+
337
+ # extract state_dict for UNet
338
+ unet_state_dict = {}
339
+ keys = list(checkpoint.keys())
340
+
341
+ if controlnet:
342
+ unet_key = "control_model."
343
+ else:
344
+ unet_key = "model.diffusion_model."
345
+
346
+ # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
347
+ if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
348
+ print(f"Checkpoint {path} has both EMA and non-EMA weights.")
349
+ print(
350
+ "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
351
+ " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
352
+ )
353
+ for key in keys:
354
+ if key.startswith("model.diffusion_model"):
355
+ flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
356
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
357
+ else:
358
+ if sum(k.startswith("model_ema") for k in keys) > 100:
359
+ print(
360
+ "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
361
+ " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
362
+ )
363
+
364
+ for key in keys:
365
+ if key.startswith(unet_key):
366
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
367
+
368
+ new_checkpoint = {}
369
+
370
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
371
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
372
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
373
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
374
+
375
+ if config["class_embed_type"] is None:
376
+ # No parameters to port
377
+ ...
378
+ elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
379
+ new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
380
+ new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
381
+ new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
382
+ new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
383
+ else:
384
+ raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
385
+
386
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
387
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
388
+
389
+ if not controlnet:
390
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
391
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
392
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
393
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
394
+
395
+ # Retrieves the keys for the input blocks only
396
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
397
+ input_blocks = {
398
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
399
+ for layer_id in range(num_input_blocks)
400
+ }
401
+
402
+ # Retrieves the keys for the middle blocks only
403
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
404
+ middle_blocks = {
405
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
406
+ for layer_id in range(num_middle_blocks)
407
+ }
408
+
409
+ # Retrieves the keys for the output blocks only
410
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
411
+ output_blocks = {
412
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
413
+ for layer_id in range(num_output_blocks)
414
+ }
415
+
416
+ for i in range(1, num_input_blocks):
417
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
418
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
419
+
420
+ resnets = [
421
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
422
+ ]
423
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
424
+
425
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
426
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
427
+ f"input_blocks.{i}.0.op.weight"
428
+ )
429
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
430
+ f"input_blocks.{i}.0.op.bias"
431
+ )
432
+
433
+ paths = renew_resnet_paths(resnets)
434
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
435
+ assign_to_checkpoint(
436
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
437
+ )
438
+
439
+ if len(attentions):
440
+ paths = renew_attention_paths(attentions)
441
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
442
+ assign_to_checkpoint(
443
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
444
+ )
445
+
446
+ resnet_0 = middle_blocks[0]
447
+ attentions = middle_blocks[1]
448
+ resnet_1 = middle_blocks[2]
449
+
450
+ resnet_0_paths = renew_resnet_paths(resnet_0)
451
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
452
+
453
+ resnet_1_paths = renew_resnet_paths(resnet_1)
454
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
455
+
456
+ attentions_paths = renew_attention_paths(attentions)
457
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
458
+ assign_to_checkpoint(
459
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
460
+ )
461
+
462
+ for i in range(num_output_blocks):
463
+ block_id = i // (config["layers_per_block"] + 1)
464
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
465
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
466
+ output_block_list = {}
467
+
468
+ for layer in output_block_layers:
469
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
470
+ if layer_id in output_block_list:
471
+ output_block_list[layer_id].append(layer_name)
472
+ else:
473
+ output_block_list[layer_id] = [layer_name]
474
+
475
+ if len(output_block_list) > 1:
476
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
477
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
478
+
479
+ resnet_0_paths = renew_resnet_paths(resnets)
480
+ paths = renew_resnet_paths(resnets)
481
+
482
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
483
+ assign_to_checkpoint(
484
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
485
+ )
486
+
487
+ output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
488
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
489
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
490
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
491
+ f"output_blocks.{i}.{index}.conv.weight"
492
+ ]
493
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
494
+ f"output_blocks.{i}.{index}.conv.bias"
495
+ ]
496
+
497
+ # Clear attentions as they have been attributed above.
498
+ if len(attentions) == 2:
499
+ attentions = []
500
+
501
+ if len(attentions):
502
+ paths = renew_attention_paths(attentions)
503
+ meta_path = {
504
+ "old": f"output_blocks.{i}.1",
505
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
506
+ }
507
+ assign_to_checkpoint(
508
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
509
+ )
510
+ else:
511
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
512
+ for path in resnet_0_paths:
513
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
514
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
515
+
516
+ new_checkpoint[new_path] = unet_state_dict[old_path]
517
+
518
+ if controlnet:
519
+ # conditioning embedding
520
+
521
+ orig_index = 0
522
+
523
+ new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
524
+ f"input_hint_block.{orig_index}.weight"
525
+ )
526
+ new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
527
+ f"input_hint_block.{orig_index}.bias"
528
+ )
529
+
530
+ orig_index += 2
531
+
532
+ diffusers_index = 0
533
+
534
+ while diffusers_index < 6:
535
+ new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
536
+ f"input_hint_block.{orig_index}.weight"
537
+ )
538
+ new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
539
+ f"input_hint_block.{orig_index}.bias"
540
+ )
541
+ diffusers_index += 1
542
+ orig_index += 2
543
+
544
+ new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
545
+ f"input_hint_block.{orig_index}.weight"
546
+ )
547
+ new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
548
+ f"input_hint_block.{orig_index}.bias"
549
+ )
550
+
551
+ # down blocks
552
+ for i in range(num_input_blocks):
553
+ new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
554
+ new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")
555
+
556
+ # mid block
557
+ new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
558
+ new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")
559
+
560
+ return new_checkpoint
561
+
562
+
563
+ def convert_ldm_vae_checkpoint(checkpoint, config):
564
+ # extract state dict for VAE
565
+ vae_state_dict = {}
566
+ vae_key = "first_stage_model."
567
+ keys = list(checkpoint.keys())
568
+ for key in keys:
569
+ if key.startswith(vae_key):
570
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
571
+
572
+ new_checkpoint = {}
573
+
574
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
575
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
576
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
577
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
578
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
579
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
580
+
581
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
582
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
583
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
584
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
585
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
586
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
587
+
588
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
589
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
590
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
591
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
592
+
593
+ # Retrieves the keys for the encoder down blocks only
594
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
595
+ down_blocks = {
596
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
597
+ }
598
+
599
+ # Retrieves the keys for the decoder up blocks only
600
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
601
+ up_blocks = {
602
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
603
+ }
604
+
605
+ for i in range(num_down_blocks):
606
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
607
+
608
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
609
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
610
+ f"encoder.down.{i}.downsample.conv.weight"
611
+ )
612
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
613
+ f"encoder.down.{i}.downsample.conv.bias"
614
+ )
615
+
616
+ paths = renew_vae_resnet_paths(resnets)
617
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
618
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
619
+
620
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
621
+ num_mid_res_blocks = 2
622
+ for i in range(1, num_mid_res_blocks + 1):
623
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
624
+
625
+ paths = renew_vae_resnet_paths(resnets)
626
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
627
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
628
+
629
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
630
+ paths = renew_vae_attention_paths(mid_attentions)
631
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
632
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
633
+ conv_attn_to_linear(new_checkpoint)
634
+
635
+ for i in range(num_up_blocks):
636
+ block_id = num_up_blocks - 1 - i
637
+ resnets = [
638
+ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
639
+ ]
640
+
641
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
642
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
643
+ f"decoder.up.{block_id}.upsample.conv.weight"
644
+ ]
645
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
646
+ f"decoder.up.{block_id}.upsample.conv.bias"
647
+ ]
648
+
649
+ paths = renew_vae_resnet_paths(resnets)
650
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
651
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
652
+
653
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
654
+ num_mid_res_blocks = 2
655
+ for i in range(1, num_mid_res_blocks + 1):
656
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
657
+
658
+ paths = renew_vae_resnet_paths(resnets)
659
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
660
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
661
+
662
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
663
+ paths = renew_vae_attention_paths(mid_attentions)
664
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
665
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
666
+ conv_attn_to_linear(new_checkpoint)
667
+ return new_checkpoint
668
+
669
+
670
+ def convert_ldm_bert_checkpoint(checkpoint, config):
671
+ def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
672
+ hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
673
+ hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
674
+ hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
675
+
676
+ hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
677
+ hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
678
+
679
+ def _copy_linear(hf_linear, pt_linear):
680
+ hf_linear.weight = pt_linear.weight
681
+ hf_linear.bias = pt_linear.bias
682
+
683
+ def _copy_layer(hf_layer, pt_layer):
684
+ # copy layer norms
685
+ _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
686
+ _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
687
+
688
+ # copy attn
689
+ _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
690
+
691
+ # copy MLP
692
+ pt_mlp = pt_layer[1][1]
693
+ _copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
694
+ _copy_linear(hf_layer.fc2, pt_mlp.net[2])
695
+
696
+ def _copy_layers(hf_layers, pt_layers):
697
+ for i, hf_layer in enumerate(hf_layers):
698
+ if i != 0:
699
+ i += i
700
+ pt_layer = pt_layers[i : i + 2]
701
+ _copy_layer(hf_layer, pt_layer)
702
+
703
+ hf_model = LDMBertModel(config).eval()
704
+
705
+ # copy embeds
706
+ hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
707
+ hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
708
+
709
+ # copy layer norm
710
+ _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
711
+
712
+ # copy hidden layers
713
+ _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
714
+
715
+ _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
716
+
717
+ return hf_model
718
+
719
+
720
+ def convert_ldm_clip_checkpoint(checkpoint, pretrained_model_path):
721
+ text_model = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
722
+ keys = list(checkpoint.keys())
723
+
724
+ text_model_dict = {}
725
+
726
+ for key in keys:
727
+ if key.startswith("cond_stage_model.transformer"):
728
+ text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
729
+
730
+ text_model.load_state_dict(text_model_dict)
731
+
732
+ return text_model
733
+
734
+
735
+ textenc_conversion_lst = [
736
+ ("cond_stage_model.model.positional_embedding", "text_model.embeddings.position_embedding.weight"),
737
+ ("cond_stage_model.model.token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
738
+ ("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"),
739
+ ("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"),
740
+ ]
741
+ textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}
742
+
743
+ textenc_transformer_conversion_lst = [
744
+ # (stable-diffusion, HF Diffusers)
745
+ ("resblocks.", "text_model.encoder.layers."),
746
+ ("ln_1", "layer_norm1"),
747
+ ("ln_2", "layer_norm2"),
748
+ (".c_fc.", ".fc1."),
749
+ (".c_proj.", ".fc2."),
750
+ (".attn", ".self_attn"),
751
+ ("ln_final.", "transformer.text_model.final_layer_norm."),
752
+ ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
753
+ ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
754
+ ]
755
+ protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
756
+ textenc_pattern = re.compile("|".join(protected.keys()))
757
+
758
+
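As an aside, a small illustration (not part of the committed script) of the regex-based key renaming that `textenc_pattern` performs, using the `protected` mapping defined just above:

```python
import re

key = "resblocks.0.ln_1.weight"
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], key)
print(new_key)  # text_model.encoder.layers.0.layer_norm1.weight
```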
759
+ def convert_paint_by_example_checkpoint(checkpoint):
760
+ config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
761
+ model = PaintByExampleImageEncoder(config)
762
+
763
+ keys = list(checkpoint.keys())
764
+
765
+ text_model_dict = {}
766
+
767
+ for key in keys:
768
+ if key.startswith("cond_stage_model.transformer"):
769
+ text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
770
+
771
+ # load clip vision
772
+ model.model.load_state_dict(text_model_dict)
773
+
774
+ # load mapper
775
+ keys_mapper = {
776
+ k[len("cond_stage_model.mapper.res") :]: v
777
+ for k, v in checkpoint.items()
778
+ if k.startswith("cond_stage_model.mapper")
779
+ }
780
+
781
+ MAPPING = {
782
+ "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
783
+ "attn.c_proj": ["attn1.to_out.0"],
784
+ "ln_1": ["norm1"],
785
+ "ln_2": ["norm3"],
786
+ "mlp.c_fc": ["ff.net.0.proj"],
787
+ "mlp.c_proj": ["ff.net.2"],
788
+ }
789
+
790
+ mapped_weights = {}
791
+ for key, value in keys_mapper.items():
792
+ prefix = key[: len("blocks.i")]
793
+ suffix = key.split(prefix)[-1].split(".")[-1]
794
+ name = key.split(prefix)[-1].split(suffix)[0][1:-1]
795
+ mapped_names = MAPPING[name]
796
+
797
+ num_splits = len(mapped_names)
798
+ for i, mapped_name in enumerate(mapped_names):
799
+ new_name = ".".join([prefix, mapped_name, suffix])
800
+ shape = value.shape[0] // num_splits
801
+ mapped_weights[new_name] = value[i * shape : (i + 1) * shape]
802
+
803
+ model.mapper.load_state_dict(mapped_weights)
804
+
805
+ # load final layer norm
806
+ model.final_layer_norm.load_state_dict(
807
+ {
808
+ "bias": checkpoint["cond_stage_model.final_ln.bias"],
809
+ "weight": checkpoint["cond_stage_model.final_ln.weight"],
810
+ }
811
+ )
812
+
813
+ # load final proj
814
+ model.proj_out.load_state_dict(
815
+ {
816
+ "bias": checkpoint["proj_out.bias"],
817
+ "weight": checkpoint["proj_out.weight"],
818
+ }
819
+ )
820
+
821
+ # load uncond vector
822
+ model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
823
+ return model
824
+
825
+
826
+ def convert_open_clip_checkpoint(checkpoint):
827
+ text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
828
+
829
+ keys = list(checkpoint.keys())
830
+
831
+ text_model_dict = {}
832
+
833
+ if "cond_stage_model.model.text_projection" in checkpoint:
834
+ d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
835
+ else:
836
+ d_model = 1024
837
+
838
+ text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
839
+
840
+ for key in keys:
841
+ if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer
842
+ continue
843
+ if key in textenc_conversion_map:
844
+ text_model_dict[textenc_conversion_map[key]] = checkpoint[key]
845
+ if key.startswith("cond_stage_model.model.transformer."):
846
+ new_key = key[len("cond_stage_model.model.transformer.") :]
847
+ if new_key.endswith(".in_proj_weight"):
848
+ new_key = new_key[: -len(".in_proj_weight")]
849
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
850
+ text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
851
+ text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :]
852
+ text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :]
853
+ elif new_key.endswith(".in_proj_bias"):
854
+ new_key = new_key[: -len(".in_proj_bias")]
855
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
856
+ text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model]
857
+ text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2]
858
+ text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :]
859
+ else:
860
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
861
+
862
+ text_model_dict[new_key] = checkpoint[key]
863
+
864
+ text_model.load_state_dict(text_model_dict)
865
+
866
+ return text_model
867
+
868
+
869
+ def stable_unclip_image_encoder(original_config):
870
+ """
871
+ Returns the image processor and clip image encoder for the img2img unclip pipeline.
872
+
873
+ We currently know of two types of stable unclip models which separately use the clip and the openclip image
874
+ encoders.
875
+ """
876
+
877
+ image_embedder_config = original_config.model.params.embedder_config
878
+
879
+ sd_clip_image_embedder_class = image_embedder_config.target
880
+ sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1]
881
+
882
+ if sd_clip_image_embedder_class == "ClipImageEmbedder":
883
+ clip_model_name = image_embedder_config.params.model
884
+
885
+ if clip_model_name == "ViT-L/14":
886
+ feature_extractor = CLIPImageProcessor()
887
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
888
+ else:
889
+ raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}")
890
+
891
+ elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder":
892
+ feature_extractor = CLIPImageProcessor()
893
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
894
+ else:
895
+ raise NotImplementedError(
896
+ f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}"
897
+ )
898
+
899
+ return feature_extractor, image_encoder
900
+
901
+
902
+ def stable_unclip_image_noising_components(
903
+ original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None
904
+ ):
905
+ """
906
+ Returns the noising components for the img2img and txt2img unclip pipelines.
907
+
908
+ Converts the stability noise augmentor into
909
+ 1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats
910
+ 2. a `DDPMScheduler` for holding the noise schedule
911
+
912
+ If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided.
913
+ """
914
+ noise_aug_config = original_config.model.params.noise_aug_config
915
+ noise_aug_class = noise_aug_config.target
916
+ noise_aug_class = noise_aug_class.split(".")[-1]
917
+
918
+ if noise_aug_class == "CLIPEmbeddingNoiseAugmentation":
919
+ noise_aug_config = noise_aug_config.params
920
+ embedding_dim = noise_aug_config.timestep_dim
921
+ max_noise_level = noise_aug_config.noise_schedule_config.timesteps
922
+ beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule
923
+
924
+ image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim)
925
+ image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule)
926
+
927
+ if "clip_stats_path" in noise_aug_config:
928
+ if clip_stats_path is None:
929
+ raise ValueError("This stable unclip config requires a `clip_stats_path`")
930
+
931
+ clip_mean, clip_std = torch.load(clip_stats_path, map_location=device)
932
+ clip_mean = clip_mean[None, :]
933
+ clip_std = clip_std[None, :]
934
+
935
+ clip_stats_state_dict = {
936
+ "mean": clip_mean,
937
+ "std": clip_std,
938
+ }
939
+
940
+ image_normalizer.load_state_dict(clip_stats_state_dict)
941
+ else:
942
+ raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}")
943
+
944
+ return image_normalizer, image_noising_scheduler
945
+
946
+
947
+ def convert_controlnet_checkpoint(
948
+ checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
949
+ ):
950
+ ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
951
+ ctrlnet_config["upcast_attention"] = upcast_attention
952
+
953
+ ctrlnet_config.pop("sample_size")
954
+
955
+ controlnet_model = ControlNetModel(**ctrlnet_config)
956
+
957
+ converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
958
+ checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True
959
+ )
960
+
961
+ controlnet_model.load_state_dict(converted_ctrl_checkpoint)
962
+
963
+ return controlnet_model
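A hedged sketch of how the UNet conversion helpers above are typically combined; the YAML/ckpt file names are placeholders and `omegaconf` is assumed to be installed:

```python
import torch
from omegaconf import OmegaConf
from diffusers import UNet2DConditionModel
from musev.utils.convert_from_ckpt import (
    create_unet_diffusers_config,
    convert_ldm_unet_checkpoint,
)

original_config = OmegaConf.load("v1-inference.yaml")        # LDM-style config (placeholder)
state_dict = torch.load("sd-v1-5.ckpt", map_location="cpu")  # original checkpoint (placeholder)
state_dict = state_dict.get("state_dict", state_dict)

unet_config = create_unet_diffusers_config(original_config, image_size=512)
converted = convert_ldm_unet_checkpoint(state_dict, unet_config, path="sd-v1-5.ckpt")

unet = UNet2DConditionModel(**unet_config)
unet.load_state_dict(converted)
```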
musev/utils/convert_lora_safetensor_to_diffusers.py ADDED
@@ -0,0 +1,154 @@
1
+ # coding=utf-8
2
+ # Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """ Conversion script for the LoRA's safetensors checkpoints. """
17
+
18
+ import argparse
19
+
20
+ import torch
21
+ from safetensors.torch import load_file
22
+
23
+ from diffusers import StableDiffusionPipeline
24
+ import pdb
25
+
26
+
27
+
28
+ def convert_motion_lora_ckpt_to_diffusers(pipeline, state_dict, alpha=1.0):
29
+ # directly update weight in diffusers model
30
+ for key in state_dict:
31
+ # only process lora down key
32
+ if "up." in key: continue
33
+
34
+ up_key = key.replace(".down.", ".up.")
35
+ model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "")
36
+ model_key = model_key.replace("to_out.", "to_out.0.")
37
+ layer_infos = model_key.split(".")[:-1]
38
+
39
+ curr_layer = pipeline.unet
40
+ while len(layer_infos) > 0:
41
+ temp_name = layer_infos.pop(0)
42
+ curr_layer = curr_layer.__getattr__(temp_name)
43
+
44
+ weight_down = state_dict[key]
45
+ weight_up = state_dict[up_key]
46
+ curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
47
+
48
+ return pipeline
49
+
50
+
51
+
52
+ def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6):
53
+ # load base model
54
+ # pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
55
+
56
+ # load LoRA weight from .safetensors
57
+ # state_dict = load_file(checkpoint_path)
58
+
59
+ visited = []
60
+
61
+ # directly update weight in diffusers model
62
+ for key in state_dict:
63
+ # it is suggested to print out the key, it usually will be something like below
64
+ # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
65
+
66
+ # as we have set the alpha beforehand, so just skip
67
+ if ".alpha" in key or key in visited:
68
+ continue
69
+
70
+ if "text" in key:
71
+ layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
72
+ curr_layer = pipeline.text_encoder
73
+ else:
74
+ layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
75
+ curr_layer = pipeline.unet
76
+
77
+ # find the target layer
78
+ temp_name = layer_infos.pop(0)
79
+ while len(layer_infos) > -1:
80
+ try:
81
+ curr_layer = curr_layer.__getattr__(temp_name)
82
+ if len(layer_infos) > 0:
83
+ temp_name = layer_infos.pop(0)
84
+ elif len(layer_infos) == 0:
85
+ break
86
+ except Exception:
87
+ if len(temp_name) > 0:
88
+ temp_name += "_" + layer_infos.pop(0)
89
+ else:
90
+ temp_name = layer_infos.pop(0)
91
+
92
+ pair_keys = []
93
+ if "lora_down" in key:
94
+ pair_keys.append(key.replace("lora_down", "lora_up"))
95
+ pair_keys.append(key)
96
+ else:
97
+ pair_keys.append(key)
98
+ pair_keys.append(key.replace("lora_up", "lora_down"))
99
+
100
+ # update weight
101
+ if len(state_dict[pair_keys[0]].shape) == 4:
102
+ weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
103
+ weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
104
+ curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device)
105
+ else:
106
+ weight_up = state_dict[pair_keys[0]].to(torch.float32)
107
+ weight_down = state_dict[pair_keys[1]].to(torch.float32)
108
+ curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
109
+
110
+ # update visited list
111
+ for item in pair_keys:
112
+ visited.append(item)
113
+
114
+ return pipeline
115
+
116
+
117
+ if __name__ == "__main__":
118
+ parser = argparse.ArgumentParser()
119
+
120
+ parser.add_argument(
121
+ "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format."
122
+ )
123
+ parser.add_argument(
124
+ "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
125
+ )
126
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
127
+ parser.add_argument(
128
+ "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors"
129
+ )
130
+ parser.add_argument(
131
+ "--lora_prefix_text_encoder",
132
+ default="lora_te",
133
+ type=str,
134
+ help="The prefix of text encoder weight in safetensors",
135
+ )
136
+ parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW")
137
+ parser.add_argument(
138
+ "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not."
139
+ )
140
+ parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
141
+
142
+ args = parser.parse_args()
143
+
144
+ base_model_path = args.base_model_path
145
+ checkpoint_path = args.checkpoint_path
146
+ dump_path = args.dump_path
147
+ lora_prefix_unet = args.lora_prefix_unet
148
+ lora_prefix_text_encoder = args.lora_prefix_text_encoder
149
+ alpha = args.alpha
150
+
151
+ pipe = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
+ state_dict = load_file(checkpoint_path)
+ pipe = convert_lora(pipe, state_dict, lora_prefix_unet, lora_prefix_text_encoder, alpha)
152
+
153
+ pipe = pipe.to(args.device)
154
+ pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
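A hedged example of calling `convert_lora` programmatically instead of through the CLI above; the model id and LoRA file name are placeholders:

```python
import torch
from safetensors.torch import load_file
from diffusers import StableDiffusionPipeline
from musev.utils.convert_lora_safetensor_to_diffusers import convert_lora

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float32  # any SD 1.x checkpoint
)
lora_state_dict = load_file("my_style_lora.safetensors")  # placeholder path
pipe = convert_lora(pipe, lora_state_dict, alpha=0.6)     # LoRA weights merged in place
```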