Spaces:
Runtime error
Runtime error
jmanhype
commited on
Commit
·
0a72c84
0
Parent(s):
Initial Space setup
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- Dockerfile +1 -0
- README.md +496 -0
- musev/__init__.py +9 -0
- musev/auto_prompt/__init__.py +0 -0
- musev/auto_prompt/attributes/__init__.py +8 -0
- musev/auto_prompt/attributes/attr2template.py +127 -0
- musev/auto_prompt/attributes/attributes.py +227 -0
- musev/auto_prompt/attributes/human.py +424 -0
- musev/auto_prompt/attributes/render.py +33 -0
- musev/auto_prompt/attributes/style.py +12 -0
- musev/auto_prompt/human.py +40 -0
- musev/auto_prompt/load_template.py +37 -0
- musev/auto_prompt/util.py +25 -0
- musev/data/__init__.py +0 -0
- musev/data/data_util.py +681 -0
- musev/logging.conf +32 -0
- musev/models/__init__.py +3 -0
- musev/models/attention.py +431 -0
- musev/models/attention_processor.py +750 -0
- musev/models/controlnet.py +399 -0
- musev/models/embeddings.py +87 -0
- musev/models/facein_loader.py +120 -0
- musev/models/ip_adapter_face_loader.py +179 -0
- musev/models/ip_adapter_loader.py +340 -0
- musev/models/referencenet.py +1216 -0
- musev/models/referencenet_loader.py +124 -0
- musev/models/resnet.py +135 -0
- musev/models/super_model.py +253 -0
- musev/models/temporal_transformer.py +308 -0
- musev/models/text_model.py +40 -0
- musev/models/transformer_2d.py +445 -0
- musev/models/unet_2d_blocks.py +1537 -0
- musev/models/unet_3d_blocks.py +1413 -0
- musev/models/unet_3d_condition.py +1740 -0
- musev/models/unet_loader.py +273 -0
- musev/pipelines/__init__.py +0 -0
- musev/pipelines/context.py +149 -0
- musev/pipelines/pipeline_controlnet.py +0 -0
- musev/pipelines/pipeline_controlnet_predictor.py +1290 -0
- musev/schedulers/__init__.py +6 -0
- musev/schedulers/scheduling_ddim.py +302 -0
- musev/schedulers/scheduling_ddpm.py +262 -0
- musev/schedulers/scheduling_dpmsolver_multistep.py +815 -0
- musev/schedulers/scheduling_euler_ancestral_discrete.py +356 -0
- musev/schedulers/scheduling_euler_discrete.py +293 -0
- musev/schedulers/scheduling_lcm.py +312 -0
- musev/utils/__init__.py +0 -0
- musev/utils/attention_util.py +74 -0
- musev/utils/convert_from_ckpt.py +963 -0
- musev/utils/convert_lora_safetensor_to_diffusers.py +154 -0
Dockerfile
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
README.md
ADDED
@@ -0,0 +1,496 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: MuseV
|
3 |
+
emoji: 🎨
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: purple
|
6 |
+
sdk: docker
|
7 |
+
app_port: 7860
|
8 |
+
pinned: false
|
9 |
+
---
|
10 |
+
|
11 |
+
# MuseV [English](README.md) [中文](README-zh.md)
|
12 |
+
|
13 |
+
<font size=5>MuseV: Infinite-length and High Fidelity Virtual Human Video Generation with Visual Conditioned Parallel Denoising
|
14 |
+
</br>
|
15 |
+
Zhiqiang Xia <sup>\*</sup>,
|
16 |
+
Zhaokang Chen<sup>\*</sup>,
|
17 |
+
Bin Wu<sup>†</sup>,
|
18 |
+
Chao Li,
|
19 |
+
Kwok-Wai Hung,
|
20 |
+
Chao Zhan,
|
21 |
+
Yingjie He,
|
22 |
+
Wenjiang Zhou
|
23 |
+
(<sup>*</sup>co-first author, <sup>†</sup>Corresponding Author, [email protected])
|
24 |
+
</font>
|
25 |
+
|
26 |
+
**[github](https://github.com/TMElyralab/MuseV)** **[huggingface](https://huggingface.co/TMElyralab/MuseV)** **[HuggingfaceSpace](https://huggingface.co/spaces/AnchorFake/MuseVDemo)** **[project](https://tmelyralab.github.io/MuseV_Page/)** **Technical report (comming soon)**
|
27 |
+
|
28 |
+
|
29 |
+
We have setup **the world simulator vision since March 2023, believing diffusion models can simulate the world**. `MuseV` was a milestone achieved around **July 2023**. Amazed by the progress of Sora, we decided to opensource `MuseV`, hopefully it will benefit the community. Next we will move on to the promising diffusion+transformer scheme.
|
30 |
+
|
31 |
+
|
32 |
+
Update: We have released <a href="https://github.com/TMElyralab/MuseTalk" style="font-size:24px; color:red;">MuseTalk</a>, a real-time high quality lip sync model, which can be applied with MuseV as a complete virtual human generation solution.
|
33 |
+
|
34 |
+
# Overview
|
35 |
+
`MuseV` is a diffusion-based virtual human video generation framework, which
|
36 |
+
1. supports **infinite length** generation using a novel **Visual Conditioned Parallel Denoising scheme**.
|
37 |
+
2. checkpoint available for virtual human video generation trained on human dataset.
|
38 |
+
3. supports Image2Video, Text2Image2Video, Video2Video.
|
39 |
+
4. compatible with the **Stable Diffusion ecosystem**, including `base_model`, `lora`, `controlnet`, etc.
|
40 |
+
5. supports multi reference image technology, including `IPAdapter`, `ReferenceOnly`, `ReferenceNet`, `IPAdapterFaceID`.
|
41 |
+
6. training codes (comming very soon).
|
42 |
+
|
43 |
+
# Important bug fixes
|
44 |
+
1. `musev_referencenet_pose`: model_name of `unet`, `ip_adapter` of Command is not correct, please use `musev_referencenet_pose` instead of `musev_referencenet`.
|
45 |
+
|
46 |
+
# News
|
47 |
+
- [03/27/2024] release `MuseV` project and trained model `musev`, `muse_referencenet`.
|
48 |
+
- [03/30/2024] add huggingface space gradio to generate video in gui
|
49 |
+
|
50 |
+
## Model
|
51 |
+
### Overview of model structure
|
52 |
+
![model_structure](./data/models/musev_structure.png)
|
53 |
+
### Parallel denoising
|
54 |
+
![parallel_denoise](./data//models/parallel_denoise.png)
|
55 |
+
|
56 |
+
## Cases
|
57 |
+
All frames were generated directly from text2video model, without any post process.
|
58 |
+
MoreCase is in **[project](https://tmelyralab.github.io/MuseV_Page/)**, including **1-2 minute video**.
|
59 |
+
|
60 |
+
<!-- # TODO: // use youtu video link? -->
|
61 |
+
Examples bellow can be accessed at `configs/tasks/example.yaml`
|
62 |
+
|
63 |
+
|
64 |
+
### Text/Image2Video
|
65 |
+
|
66 |
+
#### Human
|
67 |
+
|
68 |
+
<table class="center">
|
69 |
+
<tr style="font-weight: bolder;text-align:center;">
|
70 |
+
<td width="50%">image</td>
|
71 |
+
<td width="45%">video </td>
|
72 |
+
<td width="5%">prompt</td>
|
73 |
+
</tr>
|
74 |
+
|
75 |
+
<tr>
|
76 |
+
<td>
|
77 |
+
<img src=./data/images/yongen.jpeg width="400">
|
78 |
+
</td>
|
79 |
+
<td >
|
80 |
+
<video src="https://github.com/TMElyralab/MuseV/assets/163980830/732cf1fd-25e7-494e-b462-969c9425d277" width="100" controls preload></video>
|
81 |
+
</td>
|
82 |
+
<td>(masterpiece, best quality, highres:1),(1boy, solo:1),(eye blinks:1.8),(head wave:1.3)
|
83 |
+
</td>
|
84 |
+
</tr>
|
85 |
+
|
86 |
+
<tr>
|
87 |
+
<td>
|
88 |
+
<img src=./data/images/seaside4.jpeg width="400">
|
89 |
+
</td>
|
90 |
+
<td>
|
91 |
+
<video src="https://github.com/TMElyralab/MuseV/assets/163980830/9b75a46c-f4e6-45ef-ad02-05729f091c8f" width="100" controls preload></video>
|
92 |
+
</td>
|
93 |
+
<td>
|
94 |
+
(masterpiece, best quality, highres:1), peaceful beautiful sea scene
|
95 |
+
</td>
|
96 |
+
</tr>
|
97 |
+
<tr>
|
98 |
+
<td>
|
99 |
+
<img src=./data/images/seaside_girl.jpeg width="400">
|
100 |
+
</td>
|
101 |
+
<td>
|
102 |
+
<video src="https://github.com/TMElyralab/MuseV/assets/163980830/d0f3b401-09bf-4018-81c3-569ec24a4de9" width="100" controls preload></video>
|
103 |
+
</td>
|
104 |
+
<td>
|
105 |
+
(masterpiece, best quality, highres:1), peaceful beautiful sea scene
|
106 |
+
</td>
|
107 |
+
</tr>
|
108 |
+
<!-- guitar -->
|
109 |
+
<tr>
|
110 |
+
<td>
|
111 |
+
<img src=./data/images/boy_play_guitar.jpeg width="400">
|
112 |
+
</td>
|
113 |
+
<td>
|
114 |
+
<video src="https://github.com/TMElyralab/MuseV/assets/163980830/61bf955e-7161-44c8-a498-8811c4f4eb4f" width="100" controls preload></video>
|
115 |
+
</td>
|
116 |
+
<td>
|
117 |
+
(masterpiece, best quality, highres:1), playing guitar
|
118 |
+
</td>
|
119 |
+
</tr>
|
120 |
+
<tr>
|
121 |
+
<td>
|
122 |
+
<img src=./data/images/girl_play_guitar2.jpeg width="400">
|
123 |
+
</td>
|
124 |
+
<td>
|
125 |
+
<video src="https://github.com/TMElyralab/MuseV/assets/163980830/40982aa7-9f6a-4e44-8ef6-3f185d284e6a" width="100" controls preload></video>
|
126 |
+
</td>
|
127 |
+
<td>
|
128 |
+
(masterpiece, best quality, highres:1), playing guitar
|
129 |
+
</td>
|
130 |
+
</tr>
|
131 |
+
<!-- famous people -->
|
132 |
+
<tr>
|
133 |
+
<td>
|
134 |
+
<img src=./data/images/dufu.jpeg width="400">
|
135 |
+
</td>
|
136 |
+
<td>
|
137 |
+
<video src="https://github.com/TMElyralab/MuseV/assets/163980830/28294baa-b996-420f-b1fb-046542adf87d" width="100" controls preload></video>
|
138 |
+
</td>
|
139 |
+
<td>
|
140 |
+
(masterpiece, best quality, highres:1),(1man, solo:1),(eye blinks:1.8),(head wave:1.3),Chinese ink painting style
|
141 |
+
</td>
|
142 |
+
</tr>
|
143 |
+
|
144 |
+
<tr>
|
145 |
+
<td>
|
146 |
+
<img src=./data/images/Mona_Lisa.jpg width="400">
|
147 |
+
</td>
|
148 |
+
<td>
|
149 |
+
<video src="https://github.com/TMElyralab/MuseV/assets/163980830/1ce11da6-14c6-4dcd-b7f9-7a5f060d71fb" width="100" controls preload></video>
|
150 |
+
</td>
|
151 |
+
<td>
|
152 |
+
(masterpiece, best quality, highres:1),(1girl, solo:1),(beautiful face,
|
153 |
+
soft skin, costume:1),(eye blinks:{eye_blinks_factor}),(head wave:1.3)
|
154 |
+
</td>
|
155 |
+
</tr>
|
156 |
+
</table >
|
157 |
+
|
158 |
+
#### Scene
|
159 |
+
|
160 |
+
<table class="center">
|
161 |
+
<tr style="font-weight: bolder;text-align:center;">
|
162 |
+
<td width="35%">image</td>
|
163 |
+
<td width="50%">video</td>
|
164 |
+
<td width="15%">prompt</td>
|
165 |
+
</tr>
|
166 |
+
|
167 |
+
<tr>
|
168 |
+
<td>
|
169 |
+
<img src=./data/images/waterfall4.jpeg width="400">
|
170 |
+
</td>
|
171 |
+
<td>
|
172 |
+
<video src="https://github.com/TMElyralab/MuseV/assets/163980830/852daeb6-6b58-4931-81f9-0dddfa1b4ea5" width="100" controls preload></video>
|
173 |
+
</td>
|
174 |
+
<td>
|
175 |
+
(masterpiece, best quality, highres:1), peaceful beautiful waterfall, an
|
176 |
+
endless waterfall
|
177 |
+
</td>
|
178 |
+
</tr>
|
179 |
+
|
180 |
+
<tr>
|
181 |
+
<td>
|
182 |
+
<img src=./data/images/seaside2.jpeg width="400">
|
183 |
+
</td>
|
184 |
+
<td>
|
185 |
+
<video src="https://github.com/TMElyralab/MuseV/assets/163980830/4a4d527a-6203-411f-afe9-31c992d26816" width="100" controls preload></video>
|
186 |
+
</td>
|
187 |
+
<td>(masterpiece, best quality, highres:1), peaceful beautiful sea scene
|
188 |
+
</td>
|
189 |
+
</tr>
|
190 |
+
</table >
|
191 |
+
|
192 |
+
### VideoMiddle2Video
|
193 |
+
|
194 |
+
**pose2video**
|
195 |
+
In `duffy` mode, pose of the vision condition frame is not aligned with the first frame of control video. `posealign` will solve the problem.
|
196 |
+
|
197 |
+
<table class="center">
|
198 |
+
<tr style="font-weight: bolder;text-align:center;">
|
199 |
+
<td width="25%">image</td>
|
200 |
+
<td width="65%">video</td>
|
201 |
+
<td width="10%">prompt</td>
|
202 |
+
</tr>
|
203 |
+
<tr>
|
204 |
+
<td>
|
205 |
+
<img src=./data/images/spark_girl.png width="200">
|
206 |
+
<img src=./data/images/cyber_girl.png width="200">
|
207 |
+
</td>
|
208 |
+
<td>
|
209 |
+
<video src="https://github.com/TMElyralab/MuseV/assets/163980830/484cc69d-c316-4464-a55b-3df929780a8e" width="400" controls preload></video>
|
210 |
+
</td>
|
211 |
+
<td>
|
212 |
+
(masterpiece, best quality, highres:1) , a girl is dancing, animation
|
213 |
+
</td>
|
214 |
+
</tr>
|
215 |
+
<tr>
|
216 |
+
<td>
|
217 |
+
<img src=./data/images/duffy.png width="400">
|
218 |
+
</td>
|
219 |
+
<td>
|
220 |
+
<video src="https://github.com/TMElyralab/MuseV/assets/163980830/c44682e6-aafc-4730-8fc1-72825c1bacf2" width="400" controls preload></video>
|
221 |
+
</td>
|
222 |
+
<td>
|
223 |
+
(masterpiece, best quality, highres:1), is dancing, animation
|
224 |
+
</td>
|
225 |
+
</tr>
|
226 |
+
</table >
|
227 |
+
|
228 |
+
### MuseTalk
|
229 |
+
The character of talk, `Sun Xinying` is a supermodel KOL. You can follow her on [douyin](https://www.douyin.com/user/MS4wLjABAAAAWDThbMPN_6Xmm_JgXexbOii1K-httbu2APdG8DvDyM8).
|
230 |
+
|
231 |
+
<table class="center">
|
232 |
+
<tr style="font-weight: bolder;">
|
233 |
+
<td width="35%">name</td>
|
234 |
+
<td width="50%">video</td>
|
235 |
+
</tr>
|
236 |
+
|
237 |
+
<tr>
|
238 |
+
<td>
|
239 |
+
talk
|
240 |
+
</td>
|
241 |
+
<td>
|
242 |
+
<video src="https://github.com/TMElyralab/MuseV/assets/163980830/951188d1-4731-4e7f-bf40-03cacba17f2f" width="100" controls preload></video>
|
243 |
+
</td>
|
244 |
+
<tr>
|
245 |
+
<td>
|
246 |
+
sing
|
247 |
+
</td>
|
248 |
+
<td>
|
249 |
+
<video src="https://github.com/TMElyralab/MuseV/assets/163980830/50b8ffab-9307-4836-99e5-947e6ce7d112" width="100" controls preload></video>
|
250 |
+
</td>
|
251 |
+
</tr>
|
252 |
+
</table >
|
253 |
+
|
254 |
+
|
255 |
+
# TODO:
|
256 |
+
- [ ] technical report (comming soon).
|
257 |
+
- [ ] training codes.
|
258 |
+
- [ ] release pretrained unet model, which is trained with controlnet、referencenet、IPAdapter, which is better on pose2video.
|
259 |
+
- [ ] support diffusion transformer generation framework.
|
260 |
+
- [ ] release `posealign` module
|
261 |
+
|
262 |
+
# Quickstart
|
263 |
+
Prepare python environment and install extra package like `diffusers`, `controlnet_aux`, `mmcm`.
|
264 |
+
|
265 |
+
## Third party integration
|
266 |
+
Thanks for the third-party integration, which makes installation and use more convenient for everyone.
|
267 |
+
We also hope you note that we have not verified, maintained, or updated third-party. Please refer to this project for specific results.
|
268 |
+
|
269 |
+
### [ComfyUI](https://github.com/chaojie/ComfyUI-MuseV)
|
270 |
+
### [One click integration package in windows](https://www.bilibili.com/video/BV1ux4y1v7pF/?vd_source=fe03b064abab17b79e22a692551405c3)
|
271 |
+
netdisk:https://www.123pan.com/s/Pf5Yjv-Bb9W3.html
|
272 |
+
|
273 |
+
code: glut
|
274 |
+
|
275 |
+
## Prepare environment
|
276 |
+
You are recommended to use `docker` primarily to prepare python environment.
|
277 |
+
### prepare python env
|
278 |
+
**Attention**: we only test with docker, there are maybe trouble with conda, or requirement. We will try to fix it. Use `docker` Please.
|
279 |
+
|
280 |
+
#### Method 1: docker
|
281 |
+
1. pull docker image
|
282 |
+
```bash
|
283 |
+
docker pull anchorxia/musev:latest
|
284 |
+
```
|
285 |
+
2. run docker
|
286 |
+
```bash
|
287 |
+
docker run --gpus all -it --entrypoint /bin/bash anchorxia/musev:latest
|
288 |
+
```
|
289 |
+
The default conda env is `musev`.
|
290 |
+
|
291 |
+
#### Method 2: conda
|
292 |
+
create conda environment from environment.yaml
|
293 |
+
```
|
294 |
+
conda env create --name musev --file ./environment.yml
|
295 |
+
```
|
296 |
+
#### Method 3: pip requirements
|
297 |
+
```
|
298 |
+
pip install -r requirements.txt
|
299 |
+
```
|
300 |
+
#### Prepare mmlab package
|
301 |
+
if not use docker, should install mmlab package additionally.
|
302 |
+
```bash
|
303 |
+
pip install --no-cache-dir -U openmim
|
304 |
+
mim install mmengine
|
305 |
+
mim install "mmcv>=2.0.1"
|
306 |
+
mim install "mmdet>=3.1.0"
|
307 |
+
mim install "mmpose>=1.1.0"
|
308 |
+
```
|
309 |
+
|
310 |
+
### Prepare custom package / modified package
|
311 |
+
#### clone
|
312 |
+
```bash
|
313 |
+
git clone --recursive https://github.com/TMElyralab/MuseV.git
|
314 |
+
```
|
315 |
+
#### prepare PYTHONPATH
|
316 |
+
```bash
|
317 |
+
current_dir=$(pwd)
|
318 |
+
export PYTHONPATH=${PYTHONPATH}:${current_dir}/MuseV
|
319 |
+
export PYTHONPATH=${PYTHONPATH}:${current_dir}/MuseV/MMCM
|
320 |
+
export PYTHONPATH=${PYTHONPATH}:${current_dir}/MuseV/diffusers/src
|
321 |
+
export PYTHONPATH=${PYTHONPATH}:${current_dir}/MuseV/controlnet_aux/src
|
322 |
+
cd MuseV
|
323 |
+
```
|
324 |
+
|
325 |
+
1. `MMCM`: multi media, cross modal process package。
|
326 |
+
1. `diffusers`: modified diffusers package based on [diffusers](https://github.com/huggingface/diffusers)
|
327 |
+
1. `controlnet_aux`: modified based on [controlnet_aux](https://github.com/TMElyralab/controlnet_aux)
|
328 |
+
|
329 |
+
|
330 |
+
## Download models
|
331 |
+
```bash
|
332 |
+
git clone https://huggingface.co/TMElyralab/MuseV ./checkpoints
|
333 |
+
```
|
334 |
+
- `motion`: text2video model, trained on tiny `ucf101` and tiny `webvid` dataset, approximately 60K videos text pairs. GPU memory consumption testing on `resolution`$=512*512$, `time_size=12`.
|
335 |
+
- `musev/unet`: only has and train `unet` motion module. `GPU memory consumption` $\approx 8G$.
|
336 |
+
- `musev_referencenet`: train `unet` module, `referencenet`, `IPAdapter`. `GPU memory consumption` $\approx 12G$.
|
337 |
+
- `unet`: `motion` module, which has `to_k`, `to_v` in `Attention` layer refer to `IPAdapter`
|
338 |
+
- `referencenet`: similar to `AnimateAnyone`
|
339 |
+
- `ip_adapter_image_proj.bin`: images clip emb project layer, refer to `IPAdapter`
|
340 |
+
- `musev_referencenet_pose`: based on `musev_referencenet`, fix `referencenet`and `controlnet_pose`, train `unet motion` and `IPAdapter`. `GPU memory consumption` $\approx 12G$
|
341 |
+
- `t2i/sd1.5`: text2image model, parameter are frozen when training motion module. Different `t2i` base_model has a significant impact.could be replaced with other t2i base.
|
342 |
+
- `majicmixRealv6Fp16`: example, download from [majicmixRealv6Fp16](https://civitai.com/models/43331?modelVersionId=94640)
|
343 |
+
- `fantasticmix_v10`: example, download from [fantasticmix_v10](https://civitai.com/models/22402?modelVersionId=26744)
|
344 |
+
- `IP-Adapter/models`: download from [IPAdapter](https://huggingface.co/h94/IP-Adapter/tree/main)
|
345 |
+
- `image_encoder`: vision clip model.
|
346 |
+
- `ip-adapter_sd15.bin`: original IPAdapter model checkpoint.
|
347 |
+
- `ip-adapter-faceid_sd15.bin`: original IPAdapter model checkpoint.
|
348 |
+
|
349 |
+
## Inference
|
350 |
+
|
351 |
+
### Prepare model_path
|
352 |
+
Skip this step when run example task with example inference command.
|
353 |
+
Set model path and abbreviation in config, to use abbreviation in inference script.
|
354 |
+
- T2I SD:ref to `musev/configs/model/T2I_all_model.py`
|
355 |
+
- Motion Unet: refer to `musev/configs/model/motion_model.py`
|
356 |
+
- Task: refer to `musev/configs/tasks/example.yaml`
|
357 |
+
|
358 |
+
### musev_referencenet
|
359 |
+
#### text2video
|
360 |
+
```bash
|
361 |
+
python scripts/inference/text2video.py --sd_model_name majicmixRealv6Fp16 --unet_model_name musev_referencenet --referencenet_model_name musev_referencenet --ip_adapter_model_name musev_referencenet -test_data_path ./configs/tasks/example.yaml --output_dir ./output --n_batch 1 --target_datas yongen --vision_clip_extractor_class_name ImageClipVisionFeatureExtractor --vision_clip_model_path ./checkpoints/IP-Adapter/models/image_encoder --time_size 12 --fps 12
|
362 |
+
```
|
363 |
+
**common parameters**:
|
364 |
+
- `test_data_path`: task_path in yaml extention
|
365 |
+
- `target_datas`: sep is `,`, sample subtasks if `name` in `test_data_path` is in `target_datas`.
|
366 |
+
- `sd_model_cfg_path`: T2I sd models path, model config path or model path.
|
367 |
+
- `sd_model_name`: sd model name, which use to choose full model path in sd_model_cfg_path. multi model names with sep =`,`, or `all`
|
368 |
+
- `unet_model_cfg_path`: motion unet model config path or model path。
|
369 |
+
- `unet_model_name`: unet model name, use to get model path in `unet_model_cfg_path`, and init unet class instance in `musev/models/unet_loader.py`. multi model names with sep=`,`, or `all`. If `unet_model_cfg_path` is model path, `unet_name` must be supported in `musev/models/unet_loader.py`
|
370 |
+
- `time_size`: num_frames per diffusion denoise generation。default=`12`.
|
371 |
+
- `n_batch`: generation numbers of shot, $total\_frames=n\_batch * time\_size + n\_viscond$, default=`1`。
|
372 |
+
- `context_frames`: context_frames num. If `time_size` > `context_frame`,`time_size` window is split into many sub-windows for parallel denoising"。 default=`12`。
|
373 |
+
|
374 |
+
**To generate long videos**, there two ways:
|
375 |
+
1. `visual conditioned parallel denoise`: set `n_batch=1`, `time_size` = all frames you want.
|
376 |
+
1. `traditional end-to-end`: set `time_size` = `context_frames` = frames of a shot (`12`), `context_overlap` = 0;
|
377 |
+
|
378 |
+
|
379 |
+
**model parameters**:
|
380 |
+
supports `referencenet`, `IPAdapter`, `IPAdapterFaceID`, `Facein`.
|
381 |
+
- referencenet_model_name: `referencenet` model name.
|
382 |
+
- ImageClipVisionFeatureExtractor: `ImageEmbExtractor` name, extractor vision clip emb used in `IPAdapter`.
|
383 |
+
- vision_clip_model_path: `ImageClipVisionFeatureExtractor` model path.
|
384 |
+
- ip_adapter_model_name: from `IPAdapter`, it's `ImagePromptEmbProj`, used with `ImageEmbExtractor`。
|
385 |
+
- ip_adapter_face_model_name: `IPAdapterFaceID`, from `IPAdapter` to keep faceid,should set `face_image_path`。
|
386 |
+
|
387 |
+
**Some parameters that affect the motion range and generation results**:
|
388 |
+
- `video_guidance_scale`: Similar to text2image, control influence between cond and uncond,default=`3.5`
|
389 |
+
- `use_condition_image`: Whether to use the given first frame for video generation, if not generate vision condition frames first. Default=`True`.
|
390 |
+
- `redraw_condition_image`: Whether to redraw the given first frame image.
|
391 |
+
- `video_negative_prompt`: Abbreviation of full `negative_prompt` in config path. default=`V2`.
|
392 |
+
|
393 |
+
|
394 |
+
#### video2video
|
395 |
+
`t2i` base_model has a significant impact. In this case, `fantasticmix_v10` performs better than `majicmixRealv6Fp16`.
|
396 |
+
```bash
|
397 |
+
python scripts/inference/video2video.py --sd_model_name fantasticmix_v10 --unet_model_name musev_referencenet --referencenet_model_name musev_referencenet --ip_adapter_model_name musev_referencenet -test_data_path ./configs/tasks/example.yaml --vision_clip_extractor_class_name ImageClipVisionFeatureExtractor --vision_clip_model_path ./checkpoints/IP-Adapter/models/image_encoder --output_dir ./output --n_batch 1 --controlnet_name dwpose_body_hand --which2video "video_middle" --target_datas dance1 --fps 12 --time_size 12
|
398 |
+
```
|
399 |
+
**import parameters**
|
400 |
+
|
401 |
+
Most of the parameters are same as `musev_text2video`. Special parameters of `video2video` are:
|
402 |
+
1. need to set `video_path` as reference video in `test_data`. Now reference video supports `rgb video` and `controlnet_middle_video`。
|
403 |
+
- `which2video`: whether `rgb` video influences initial noise, influence of `rgb` is stronger than of controlnet condition.
|
404 |
+
- `controlnet_name`:whether to use `controlnet condition`, such as `dwpose,depth`.
|
405 |
+
- `video_is_middle`: `video_path` is `rgb video` or `controlnet_middle_video`. Can be set for every `test_data` in test_data_path.
|
406 |
+
- `video_has_condition`: whether condtion_images is aligned with the first frame of video_path. If Not, exrtact condition of `condition_images` firstly generate, and then align with concatation. set in `test_data`。
|
407 |
+
|
408 |
+
all controlnet_names refer to [mmcm](https://github.com/TMElyralab/MMCM/blob/main/mmcm/vision/feature_extractor/controlnet.py#L513)
|
409 |
+
```python
|
410 |
+
['pose', 'pose_body', 'pose_hand', 'pose_face', 'pose_hand_body', 'pose_hand_face', 'dwpose', 'dwpose_face', 'dwpose_hand', 'dwpose_body', 'dwpose_body_hand', 'canny', 'tile', 'hed', 'hed_scribble', 'depth', 'pidi', 'normal_bae', 'lineart', 'lineart_anime', 'zoe', 'sam', 'mobile_sam', 'leres', 'content', 'face_detector']
|
411 |
+
```
|
412 |
+
|
413 |
+
### musev_referencenet_pose
|
414 |
+
Only used for `pose2video`
|
415 |
+
train based on `musev_referencenet`, fix `referencenet`, `pose-controlnet`, and `T2I`, train `motion` module and `IPAdapter`.
|
416 |
+
|
417 |
+
`t2i` base_model has a significant impact. In this case, `fantasticmix_v10` performs better than `majicmixRealv6Fp16`.
|
418 |
+
|
419 |
+
```bash
|
420 |
+
python scripts/inference/video2video.py --sd_model_name fantasticmix_v10 --unet_model_name musev_referencenet_pose --referencenet_model_name musev_referencenet --ip_adapter_model_name musev_referencenet_pose -test_data_path ./configs/tasks/example.yaml --vision_clip_extractor_class_name ImageClipVisionFeatureExtractor --vision_clip_model_path ./checkpoints/IP-Adapter/models/image_encoder --output_dir ./output --n_batch 1 --controlnet_name dwpose_body_hand --which2video "video_middle" --target_datas dance1 --fps 12 --time_size 12
|
421 |
+
```
|
422 |
+
|
423 |
+
### musev
|
424 |
+
Only has motion module, no referencenet, requiring less gpu memory.
|
425 |
+
#### text2video
|
426 |
+
```bash
|
427 |
+
python scripts/inference/text2video.py --sd_model_name majicmixRealv6Fp16 --unet_model_name musev -test_data_path ./configs/tasks/example.yaml --output_dir ./output --n_batch 1 --target_datas yongen --time_size 12 --fps 12
|
428 |
+
```
|
429 |
+
#### video2video
|
430 |
+
##### pose align
|
431 |
+
```bash
|
432 |
+
python ./pose_align/pose_align.py --max_frame 200 --vidfn ./data/source_video/dance.mp4 --imgfn_refer ./data/images/man.jpg --outfn_ref_img_pose ./data/pose_align_results/ref_img_pose.jpg --outfn_align_pose_video ./data/pose_align_results/align_pose_video.mp4 --outfn ./data/pose_align_results/align_demo.mp4
|
433 |
+
```
|
434 |
+
- `max_frame`: how many frames to align (count from the first frame)
|
435 |
+
- `vidfn`:real dance video in rgb
|
436 |
+
- `imgfn_refer`: refer image path
|
437 |
+
- `outfn_ref_img_pose`: output path of the pose of the refer img
|
438 |
+
- `outfn_align_pose_video`: output path of the aligned video of the refer img
|
439 |
+
- `outfn`: output path of the alignment visualization
|
440 |
+
|
441 |
+
|
442 |
+
|
443 |
+
https://github.com/TMElyralab/MuseV/assets/47803475/787d7193-ec69-43f4-a0e5-73986a808f51
|
444 |
+
|
445 |
+
|
446 |
+
|
447 |
+
|
448 |
+
then you can use the aligned pose `outfn_align_pose_video` for pose guided generation. You may need to modify the example in the config file `./configs/tasks/example.yaml`
|
449 |
+
##### generation
|
450 |
+
```bash
|
451 |
+
python scripts/inference/video2video.py --sd_model_name fantasticmix_v10 --unet_model_name musev -test_data_path ./configs/tasks/example.yaml --output_dir ./output --n_batch 1 --controlnet_name dwpose_body_hand --which2video "video_middle" --target_datas dance1 --fps 12 --time_size 12
|
452 |
+
```
|
453 |
+
|
454 |
+
### Gradio demo
|
455 |
+
MuseV provides gradio script to generate a GUI in a local machine to generate video conveniently.
|
456 |
+
|
457 |
+
```bash
|
458 |
+
cd scripts/gradio
|
459 |
+
python app.py
|
460 |
+
```
|
461 |
+
|
462 |
+
|
463 |
+
# Acknowledgements
|
464 |
+
|
465 |
+
1. MuseV has referred much to [TuneAVideo](https://github.com/showlab/Tune-A-Video), [diffusers](https://github.com/huggingface/diffusers), [Moore-AnimateAnyone](https://github.com/MooreThreads/Moore-AnimateAnyone/tree/master/src/pipelines), [animatediff](https://github.com/guoyww/AnimateDiff), [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter), [AnimateAnyone](https://arxiv.org/abs/2311.17117), [VideoFusion](https://arxiv.org/abs/2303.08320), [insightface](https://github.com/deepinsight/insightface).
|
466 |
+
2. MuseV has been built on `ucf101` and `webvid` datasets.
|
467 |
+
|
468 |
+
Thanks for open-sourcing!
|
469 |
+
|
470 |
+
# Limitation
|
471 |
+
There are still many limitations, including
|
472 |
+
|
473 |
+
1. Lack of generalization ability. Some visual condition image perform well, some perform bad. Some t2i pretraied model perform well, some perform bad.
|
474 |
+
1. Limited types of video generation and limited motion range, partly because of limited types of training data. The released `MuseV` has been trained on approximately 60K human text-video pairs with resolution `512*320`. `MuseV` has greater motion range while lower video quality at lower resolution. `MuseV` tends to generate less motion range with high video quality. Trained on larger, higher resolution, higher quality text-video dataset may make `MuseV` better.
|
475 |
+
1. Watermarks may appear because of `webvid`. A cleaner dataset without watermarks may solve this issue.
|
476 |
+
1. Limited types of long video generation. Visual Conditioned Parallel Denoise can solve accumulated error of video generation, but the current method is only suitable for relatively fixed camera scenes.
|
477 |
+
1. Undertrained referencenet and IP-Adapter, beacause of limited time and limited resources.
|
478 |
+
1. Understructured code. `MuseV` supports rich and dynamic features, but with complex and unrefacted codes. It takes time to familiarize.
|
479 |
+
|
480 |
+
|
481 |
+
<!-- # Contribution 暂时不需要组织开源共建 -->
|
482 |
+
# Citation
|
483 |
+
```bib
|
484 |
+
@article{musev,
|
485 |
+
title={MuseV: Infinite-length and High Fidelity Virtual Human Video Generation with Visual Conditioned Parallel Denoising},
|
486 |
+
author={Xia, Zhiqiang and Chen, Zhaokang and Wu, Bin and Li, Chao and Hung, Kwok-Wai and Zhan, Chao and He, Yingjie and Zhou, Wenjiang},
|
487 |
+
journal={arxiv},
|
488 |
+
year={2024}
|
489 |
+
}
|
490 |
+
```
|
491 |
+
# Disclaimer/License
|
492 |
+
1. `code`: The code of MuseV is released under the MIT License. There is no limitation for both academic and commercial usage.
|
493 |
+
1. `model`: The trained model are available for non-commercial research purposes only.
|
494 |
+
1. `other opensource model`: Other open-source models used must comply with their license, such as `insightface`, `IP-Adapter`, `ft-mse-vae`, etc.
|
495 |
+
1. The testdata are collected from internet, which are available for non-commercial research purposes only.
|
496 |
+
1. `AIGC`: This project strives to impact the domain of AI-driven video generation positively. Users are granted the freedom to create videos using this tool, but they are expected to comply with local laws and utilize it responsibly. The developers do not assume any responsibility for potential misuse by users.
|
musev/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import logging
|
3 |
+
import logging.config
|
4 |
+
|
5 |
+
# 读取日志配置文件内容
|
6 |
+
logging.config.fileConfig(os.path.join(os.path.dirname(__file__), "logging.conf"))
|
7 |
+
|
8 |
+
# 创建一个日志器logger
|
9 |
+
logger = logging.getLogger("musev")
|
musev/auto_prompt/__init__.py
ADDED
File without changes
|
musev/auto_prompt/attributes/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ...utils.register import Register
|
2 |
+
|
3 |
+
AttrRegister = Register(registry_name="attributes")
|
4 |
+
|
5 |
+
# must import like bellow to ensure that each class is registered with AttrRegister:
|
6 |
+
from .human import *
|
7 |
+
from .render import *
|
8 |
+
from .style import *
|
musev/auto_prompt/attributes/attr2template.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
r"""
|
2 |
+
中文
|
3 |
+
该模块将关键词字典转化为描述文本,生成完整的提词,从而降低对比实验成本、提升控制能力和效率。
|
4 |
+
提词(prompy)对比实验会需要控制关键属性发生变化、其他属性不变的文本对。当需要控制的属性变量发生较大变化时,靠人为复制粘贴进行完成文本撰写工作量会非常大。
|
5 |
+
该模块主要有三种类,分别是:
|
6 |
+
1. `BaseAttribute2Text`: 单属性文本转换类
|
7 |
+
2. `MultiAttr2Text` 多属性文本转化类,输出`List[Tuple[str, str]`。具体如何转换为文本在 `MultiAttr2PromptTemplate`中实现。
|
8 |
+
3. `MultiAttr2PromptTemplate`:先将2生成的多属性文本字典列表转化为完整的文本,然后再使用内置的模板`template`拼接。拼接后的文本作为实际模型输入的提词。
|
9 |
+
1. `template`字段若没有{},且有字符,则认为输入就是完整输入网络的`prompt`;
|
10 |
+
2. `template`字段若含有{key},则认为是带关键词的字符串目标,多个属性由`template`字符串中顺序完全决定。关键词内容由表格中相关列通过`attr2text`转化而来;
|
11 |
+
3. `template`字段有且只含有一个{},如`a portrait of {}`,则相关内容由 `PresetMultiAttr2PromptTemplate`中预定义好的`attrs`列表指定先后顺序;
|
12 |
+
|
13 |
+
English
|
14 |
+
This module converts a keyword dictionary into descriptive text, generating complete prompts to reduce the cost of comparison experiments, and improve control and efficiency.
|
15 |
+
|
16 |
+
Prompt-based comparison experiments require text pairs where the key attributes are controlled while other attributes remain constant. When the variable attributes to be controlled undergo significant changes, manually copying and pasting to write text can be very time-consuming.
|
17 |
+
|
18 |
+
This module mainly consists of three classes:
|
19 |
+
|
20 |
+
BaseAttribute2Text: A class for converting single attribute text.
|
21 |
+
MultiAttr2Text: A class for converting multi-attribute text, outputting List[Tuple[str, str]]. The specific implementation of how to convert to text is implemented in MultiAttr2PromptTemplate.
|
22 |
+
MultiAttr2PromptTemplate: First, the list of multi-attribute text dictionaries generated by 2 is converted into complete text, and then the built-in template template is used for concatenation. The concatenated text serves as the prompt for the actual model input.
|
23 |
+
If the template field does not contain {}, and there are characters, the input is considered the complete prompt for the network.
|
24 |
+
If the template field contains {key}, it is considered a string target with keywords, and the order of multiple attributes is completely determined by the template string. The keyword content is generated by attr2text from the relevant columns in the table.
|
25 |
+
If the template field contains only one {}, such as a portrait of {}, the relevant content is specified in the order defined by the attrs list predefined in PresetMultiAttr2PromptTemplate.
|
26 |
+
"""
|
27 |
+
|
28 |
+
from typing import List, Tuple, Union
|
29 |
+
|
30 |
+
from mmcm.utils.str_util import (
|
31 |
+
has_key_brace,
|
32 |
+
merge_near_same_char,
|
33 |
+
get_word_from_key_brace_string,
|
34 |
+
)
|
35 |
+
|
36 |
+
from .attributes import MultiAttr2Text, merge_multi_attrtext, AttriributeIsText
|
37 |
+
from . import AttrRegister
|
38 |
+
|
39 |
+
|
40 |
+
class MultiAttr2PromptTemplate(object):
|
41 |
+
"""
|
42 |
+
将多属性转化为模型输入文本的实际类
|
43 |
+
The actual class that converts multiple attributes into model input text is
|
44 |
+
"""
|
45 |
+
|
46 |
+
def __init__(
|
47 |
+
self,
|
48 |
+
template: str,
|
49 |
+
attr2text: MultiAttr2Text,
|
50 |
+
name: str,
|
51 |
+
) -> None:
|
52 |
+
"""
|
53 |
+
Args:
|
54 |
+
template (str): 提词模板, prompt template.
|
55 |
+
如果`template`含有{key},则根据key来取值。 if the template field contains {key}, it means that the actual value for that part of the prompt will be determined by the corresponding key
|
56 |
+
如果`template`有且只有1个{},则根据先后顺序对texts中的值进行拼接。if the template field in MultiAttr2PromptTemplate contains only one {} placeholder, such as "a portrait of {}", the order of the attributes is determined by the attrs list predefined in PresetMultiAttr2PromptTemplate. The values of the attributes in the texts list are concatenated in the order specified by the attrs list.
|
57 |
+
attr2text (MultiAttr2Text): 多属性转换类。Class for converting multiple attributes into text prompt.
|
58 |
+
name (str): 该多属性文本模板类的名字,便于记忆. Class Instance name
|
59 |
+
"""
|
60 |
+
self.attr2text = attr2text
|
61 |
+
self.name = name
|
62 |
+
if template == "":
|
63 |
+
template = "{}"
|
64 |
+
self.template = template
|
65 |
+
self.template_has_key_brace = has_key_brace(template)
|
66 |
+
|
67 |
+
def __call__(self, attributes: dict) -> Union[str, List[str]]:
|
68 |
+
texts = self.attr2text(attributes)
|
69 |
+
if not isinstance(texts, list):
|
70 |
+
texts = [texts]
|
71 |
+
prompts = [merge_multi_attrtext(text, self.template) for text in texts]
|
72 |
+
prompts = [merge_near_same_char(prompt) for prompt in prompts]
|
73 |
+
if len(prompts) == 1:
|
74 |
+
prompts = prompts[0]
|
75 |
+
return prompts
|
76 |
+
|
77 |
+
|
78 |
+
class KeywordMultiAttr2PromptTemplate(MultiAttr2PromptTemplate):
|
79 |
+
def __init__(self, template: str, name: str = "keywords") -> None:
|
80 |
+
"""关键词模板属性2文本转化类
|
81 |
+
1. 获取关键词模板字符串中的关键词属性;
|
82 |
+
2. 从import * 存储在locals()中变量中获取对应的类;
|
83 |
+
3. 将集成了多属性转换类的`MultiAttr2Text`
|
84 |
+
Args:
|
85 |
+
template (str): 含有{key}的模板字符串
|
86 |
+
name (str, optional): 该模板字符串名字,暂无实际用处. Defaults to "keywords".
|
87 |
+
|
88 |
+
class for converting keyword template attributes to text
|
89 |
+
1. Get the keyword attributes in the keyword template string;
|
90 |
+
2. Get the corresponding class from the variables stored in locals() by import *;
|
91 |
+
3. The `MultiAttr2Text` integrated with multiple attribute conversion classes
|
92 |
+
Args:
|
93 |
+
template (str): template string containing {key}
|
94 |
+
name (str, optional): the name of the template string, no actual use. Defaults to "keywords".
|
95 |
+
"""
|
96 |
+
assert has_key_brace(
|
97 |
+
template
|
98 |
+
), "template should have key brace, but given {}".format(template)
|
99 |
+
keywords = get_word_from_key_brace_string(template)
|
100 |
+
funcs = []
|
101 |
+
for word in keywords:
|
102 |
+
if word in AttrRegister:
|
103 |
+
func = AttrRegister[word](name=word)
|
104 |
+
else:
|
105 |
+
func = AttriributeIsText(name=word)
|
106 |
+
funcs.append(func)
|
107 |
+
attr2text = MultiAttr2Text(funcs, name=name)
|
108 |
+
super().__init__(template, attr2text, name)
|
109 |
+
|
110 |
+
|
111 |
+
class OnlySpacePromptTemplate(MultiAttr2PromptTemplate):
|
112 |
+
def __init__(self, template: str, name: str = "space_prompt") -> None:
|
113 |
+
"""纯空模板,无论输入啥,都只返回空格字符串作为prompt。
|
114 |
+
Args:
|
115 |
+
template (str): 符合只输出空格字符串的模板,
|
116 |
+
name (str, optional): 该模板字符串名字,暂无实际用处. Defaults to "space_prompt".
|
117 |
+
|
118 |
+
Pure empty template, no matter what the input is, it will only return a space string as the prompt.
|
119 |
+
Args:
|
120 |
+
template (str): template that only outputs a space string,
|
121 |
+
name (str, optional): the name of the template string, no actual use. Defaults to "space_prompt".
|
122 |
+
"""
|
123 |
+
attr2text = None
|
124 |
+
super().__init__(template, attr2text, name)
|
125 |
+
|
126 |
+
def __call__(self, attributes: dict) -> Union[str, List[str]]:
|
127 |
+
return ""
|
musev/auto_prompt/attributes/attributes.py
ADDED
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from copy import deepcopy
|
2 |
+
from typing import List, Tuple, Dict
|
3 |
+
|
4 |
+
from mmcm.utils.str_util import has_key_brace
|
5 |
+
|
6 |
+
|
7 |
+
class BaseAttribute2Text(object):
|
8 |
+
"""
|
9 |
+
属性转化为文本的基类,该类作用就是输入属性,转化为描述文本。
|
10 |
+
Base class for converting attributes to text which converts attributes to prompt text.
|
11 |
+
"""
|
12 |
+
|
13 |
+
name = "base_attribute"
|
14 |
+
|
15 |
+
def __init__(self, name: str = None) -> None:
|
16 |
+
"""这里类实例初始化设置`name`参数,主要是为了便于一些没有提前实现、通过字符串参数实现的新属性。
|
17 |
+
Theses class instances are initialized with the `name` parameter to facilitate the implementation of new attributes that are not implemented in advance and are implemented through string parameters.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
name (str, optional): _description_. Defaults to None.
|
21 |
+
"""
|
22 |
+
if name is not None:
|
23 |
+
self.name = name
|
24 |
+
|
25 |
+
def __call__(self, attributes) -> str:
|
26 |
+
raise NotImplementedError
|
27 |
+
|
28 |
+
|
29 |
+
class AttributeIsTextAndName(BaseAttribute2Text):
|
30 |
+
"""
|
31 |
+
属性文本转换功能类,将key和value拼接在一起作为文本.
|
32 |
+
class for converting attributes to text which concatenates the key and value together as text.
|
33 |
+
"""
|
34 |
+
|
35 |
+
name = "attribute_is_text_name"
|
36 |
+
|
37 |
+
def __call__(self, attributes) -> str:
|
38 |
+
if attributes == "" or attributes is None:
|
39 |
+
return ""
|
40 |
+
attributes = attributes.split(",")
|
41 |
+
text = ", ".join(
|
42 |
+
[
|
43 |
+
"{} {}".format(attr, self.name) if attr != "" else ""
|
44 |
+
for attr in attributes
|
45 |
+
]
|
46 |
+
)
|
47 |
+
return text
|
48 |
+
|
49 |
+
|
50 |
+
class AttriributeIsText(BaseAttribute2Text):
|
51 |
+
"""
|
52 |
+
属性文本转换功能类,将value作为文本.
|
53 |
+
class for converting attributes to text which only uses the value as text.
|
54 |
+
"""
|
55 |
+
|
56 |
+
name = "attribute_is_text"
|
57 |
+
|
58 |
+
def __call__(self, attributes: str) -> str:
|
59 |
+
if attributes == "" or attributes is None:
|
60 |
+
return ""
|
61 |
+
attributes = str(attributes)
|
62 |
+
attributes = attributes.split(",")
|
63 |
+
text = ", ".join(["{}".format(attr) for attr in attributes])
|
64 |
+
return text
|
65 |
+
|
66 |
+
|
67 |
+
class MultiAttr2Text(object):
|
68 |
+
"""将多属性组成的字典转换成完整的文本描述,目前采用简单的前后拼接方式,以`, `作为拼接符号
|
69 |
+
class for converting a dictionary of multiple attributes into a complete text description. Currently, a simple front and back splicing method is used, with `, ` as the splicing symbol.
|
70 |
+
|
71 |
+
Args:
|
72 |
+
object (_type_): _description_
|
73 |
+
"""
|
74 |
+
|
75 |
+
def __init__(self, funcs: list, name) -> None:
|
76 |
+
"""
|
77 |
+
Args:
|
78 |
+
funcs (list): 继承`BaseAttribute2Text`并实现了`__call__`函数的类. Inherited `BaseAttribute2Text` and implemented the `__call__` function of the class.
|
79 |
+
name (_type_): 该多属性的一个名字,可通过该类方便了解对应相关属性都是关于啥的。 name of the multi-attribute, which can be used to easily understand what the corresponding related attributes are about.
|
80 |
+
"""
|
81 |
+
if not isinstance(funcs, list):
|
82 |
+
funcs = [funcs]
|
83 |
+
self.funcs = funcs
|
84 |
+
self.name = name
|
85 |
+
|
86 |
+
def __call__(
|
87 |
+
self, dct: dict, ignored_blank_str: bool = False
|
88 |
+
) -> List[Tuple[str, str]]:
|
89 |
+
"""
|
90 |
+
有时候一个属性可能会返回多个文本,如 style cartoon会返回宫崎骏和皮克斯两种风格,采用外积增殖成多个字典。
|
91 |
+
sometimes an attribute may return multiple texts, such as style cartoon will return two styles, Miyazaki and Pixar, which are multiplied into multiple dictionaries by the outer product.
|
92 |
+
Args:
|
93 |
+
dct (dict): 多属性组成的字典,可能有self.funcs关注的属性也可能没有,self.funcs按照各自的名字按需提取关注的属性和值,并转化成文本.
|
94 |
+
Dict of multiple attributes, may or may not have the attributes that self.funcs is concerned with. self.funcs extracts the attributes and values of interest according to their respective names and converts them into text.
|
95 |
+
ignored_blank_str (bool): 如果某个attr2text返回的是空字符串,是否要过滤掉该属性。默认`False`.
|
96 |
+
If the text returned by an attr2text is an empty string, whether to filter out the attribute. Defaults to `False`.
|
97 |
+
Returns:
|
98 |
+
Union[List[List[Tuple[str, str]]], List[Tuple[str, str]]: 多组多属性文本字典列表. Multiple sets of multi-attribute text dictionaries.
|
99 |
+
"""
|
100 |
+
attrs_lst = [[]]
|
101 |
+
for func in self.funcs:
|
102 |
+
if func.name in dct:
|
103 |
+
attrs = func(dct[func.name])
|
104 |
+
if isinstance(attrs, str):
|
105 |
+
for i in range(len(attrs_lst)):
|
106 |
+
attrs_lst[i].append((func.name, attrs))
|
107 |
+
else:
|
108 |
+
# 一个属性可能会返回多个文本
|
109 |
+
n_attrs = len(attrs)
|
110 |
+
new_attrs_lst = []
|
111 |
+
for n in range(n_attrs):
|
112 |
+
attrs_lst_cp = deepcopy(attrs_lst)
|
113 |
+
for i in range(len(attrs_lst_cp)):
|
114 |
+
attrs_lst_cp[i].append((func.name, attrs[n]))
|
115 |
+
new_attrs_lst.extend(attrs_lst_cp)
|
116 |
+
attrs_lst = new_attrs_lst
|
117 |
+
|
118 |
+
texts = [
|
119 |
+
[
|
120 |
+
(attr, text)
|
121 |
+
for (attr, text) in attrs
|
122 |
+
if not (text == "" and ignored_blank_str)
|
123 |
+
]
|
124 |
+
for attrs in attrs_lst
|
125 |
+
]
|
126 |
+
return texts
|
127 |
+
|
128 |
+
|
129 |
+
def format_tuple_texts(template: str, texts: Tuple[str, str]) -> str:
|
130 |
+
"""使用含有"{}" 的模板对多属性文本元组进行拼接,形成新文本
|
131 |
+
concatenate multiple attribute text tuples using a template containing "{}" to form a new text
|
132 |
+
Args:
|
133 |
+
template (str):
|
134 |
+
texts (Tuple[str, str]): 多属性文本元组. multiple attribute text tuples
|
135 |
+
|
136 |
+
Returns:
|
137 |
+
str: 拼接后的新文本, merged new text
|
138 |
+
"""
|
139 |
+
merged_text = ", ".join([text[1] for text in texts if text[1] != ""])
|
140 |
+
merged_text = template.format(merged_text)
|
141 |
+
return merged_text
|
142 |
+
|
143 |
+
|
144 |
+
def format_dct_texts(template: str, texts: Dict[str, str]) -> str:
|
145 |
+
"""使用含有"{key}" 的模板对多属性文本字典进行拼接,形成新文本
|
146 |
+
concatenate multiple attribute text dictionaries using a template containing "{key}" to form a new text
|
147 |
+
Args:
|
148 |
+
template (str):
|
149 |
+
texts (Tuple[str, str]): 多属性文本字典. multiple attribute text dictionaries
|
150 |
+
|
151 |
+
Returns:
|
152 |
+
str: 拼接后的新文本, merged new text
|
153 |
+
"""
|
154 |
+
merged_text = template.format(**texts)
|
155 |
+
return merged_text
|
156 |
+
|
157 |
+
|
158 |
+
def merge_multi_attrtext(texts: List[Tuple[str, str]], template: str = None) -> str:
|
159 |
+
"""对多属性文本元组进行拼接,形成新文本。
|
160 |
+
如果`template`含有{key},则根据key来取值;
|
161 |
+
如果`template`有且只有1个{},则根据先后顺序对texts中的值进行拼接。
|
162 |
+
|
163 |
+
concatenate multiple attribute text tuples to form a new text.
|
164 |
+
if `template` contains {key}, the value is taken according to the key;
|
165 |
+
if `template` contains only one {}, the values in texts are concatenated in order.
|
166 |
+
Args:
|
167 |
+
texts (List[Tuple[str, str]]): Tuple[str, str]第一个str是属性名,第二个str是属性转化的文本.
|
168 |
+
Tuple[str, str] The first str is the attribute name, and the second str is the text of the attribute conversion.
|
169 |
+
template (str, optional): template . Defaults to None.
|
170 |
+
|
171 |
+
Returns:
|
172 |
+
str: 拼接后的新文本, merged new text
|
173 |
+
"""
|
174 |
+
if not isinstance(texts, List):
|
175 |
+
texts = [texts]
|
176 |
+
if template is None or template == "":
|
177 |
+
template = "{}"
|
178 |
+
if has_key_brace(template):
|
179 |
+
texts = {k: v for k, v in texts}
|
180 |
+
merged_text = format_dct_texts(template, texts)
|
181 |
+
else:
|
182 |
+
merged_text = format_tuple_texts(template, texts)
|
183 |
+
return merged_text
|
184 |
+
|
185 |
+
|
186 |
+
class PresetMultiAttr2Text(MultiAttr2Text):
|
187 |
+
"""预置了多种关注属性转换的类,方便维护
|
188 |
+
class for multiple attribute conversion with multiple attention attributes preset for easy maintenance
|
189 |
+
|
190 |
+
"""
|
191 |
+
|
192 |
+
preset_attributes = []
|
193 |
+
|
194 |
+
def __init__(
|
195 |
+
self, funcs: List = None, use_preset: bool = True, name: str = "preset"
|
196 |
+
) -> None:
|
197 |
+
"""虽然预置了关注的属性列表和转换类,但也允许定义示例时,进行更新。
|
198 |
+
注意`self.preset_attributes`的元素只是类名字,以便减少实例化的资源消耗。而funcs是实例化后的属性转换列表。
|
199 |
+
|
200 |
+
Although the list of attention attributes and conversion classes is preset, it is also allowed to be updated when defining an instance.
|
201 |
+
Note that the elements of `self.preset_attributes` are only class names, in order to reduce the resource consumption of instantiation. And funcs is a list of instantiated attribute conversions.
|
202 |
+
|
203 |
+
Args:
|
204 |
+
funcs (List, optional): list of funcs . Defaults to None.
|
205 |
+
use_preset (bool, optional): _description_. Defaults to True.
|
206 |
+
name (str, optional): _description_. Defaults to "preset".
|
207 |
+
"""
|
208 |
+
if use_preset:
|
209 |
+
preset_funcs = self.preset()
|
210 |
+
else:
|
211 |
+
preset_funcs = []
|
212 |
+
if funcs is None:
|
213 |
+
funcs = []
|
214 |
+
if not isinstance(funcs, list):
|
215 |
+
funcs = [funcs]
|
216 |
+
funcs_names = [func.name for func in funcs]
|
217 |
+
preset_funcs = [
|
218 |
+
preset_func
|
219 |
+
for preset_func in preset_funcs
|
220 |
+
if preset_func.name not in funcs_names
|
221 |
+
]
|
222 |
+
funcs = funcs + preset_funcs
|
223 |
+
super().__init__(funcs, name)
|
224 |
+
|
225 |
+
def preset(self):
|
226 |
+
funcs = [cls() for cls in self.preset_attributes]
|
227 |
+
return funcs
|
musev/auto_prompt/attributes/human.py
ADDED
@@ -0,0 +1,424 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from copy import deepcopy
|
2 |
+
import numpy as np
|
3 |
+
import random
|
4 |
+
import json
|
5 |
+
|
6 |
+
from .attributes import (
|
7 |
+
MultiAttr2Text,
|
8 |
+
AttriributeIsText,
|
9 |
+
AttributeIsTextAndName,
|
10 |
+
PresetMultiAttr2Text,
|
11 |
+
)
|
12 |
+
from .style import Style
|
13 |
+
from .render import Render
|
14 |
+
from . import AttrRegister
|
15 |
+
|
16 |
+
|
17 |
+
__all__ = [
|
18 |
+
"Age",
|
19 |
+
"Sex",
|
20 |
+
"Singing",
|
21 |
+
"Country",
|
22 |
+
"Lighting",
|
23 |
+
"Headwear",
|
24 |
+
"Eyes",
|
25 |
+
"Irises",
|
26 |
+
"Hair",
|
27 |
+
"Skin",
|
28 |
+
"Face",
|
29 |
+
"Smile",
|
30 |
+
"Expression",
|
31 |
+
"Clothes",
|
32 |
+
"Nose",
|
33 |
+
"Mouth",
|
34 |
+
"Beard",
|
35 |
+
"Necklace",
|
36 |
+
"KeyWords",
|
37 |
+
"InsightFace",
|
38 |
+
"Caption",
|
39 |
+
"Env",
|
40 |
+
"Decoration",
|
41 |
+
"Festival",
|
42 |
+
"SpringHeadwear",
|
43 |
+
"SpringClothes",
|
44 |
+
"Animal",
|
45 |
+
]
|
46 |
+
|
47 |
+
|
48 |
+
@AttrRegister.register
|
49 |
+
class Sex(AttriributeIsText):
|
50 |
+
name = "sex"
|
51 |
+
|
52 |
+
def __init__(self, name: str = None) -> None:
|
53 |
+
super().__init__(name)
|
54 |
+
|
55 |
+
|
56 |
+
@AttrRegister.register
|
57 |
+
class Headwear(AttriributeIsText):
|
58 |
+
name = "headwear"
|
59 |
+
|
60 |
+
def __init__(self, name: str = None) -> None:
|
61 |
+
super().__init__(name)
|
62 |
+
|
63 |
+
|
64 |
+
@AttrRegister.register
|
65 |
+
class Expression(AttriributeIsText):
|
66 |
+
name = "expression"
|
67 |
+
|
68 |
+
def __init__(self, name: str = None) -> None:
|
69 |
+
super().__init__(name)
|
70 |
+
|
71 |
+
|
72 |
+
@AttrRegister.register
|
73 |
+
class KeyWords(AttriributeIsText):
|
74 |
+
name = "keywords"
|
75 |
+
|
76 |
+
def __init__(self, name: str = None) -> None:
|
77 |
+
super().__init__(name)
|
78 |
+
|
79 |
+
|
80 |
+
@AttrRegister.register
|
81 |
+
class Singing(AttriributeIsText):
|
82 |
+
def __init__(self, name: str = "singing") -> None:
|
83 |
+
super().__init__(name)
|
84 |
+
|
85 |
+
|
86 |
+
@AttrRegister.register
|
87 |
+
class Country(AttriributeIsText):
|
88 |
+
name = "country"
|
89 |
+
|
90 |
+
def __init__(self, name: str = None) -> None:
|
91 |
+
super().__init__(name)
|
92 |
+
|
93 |
+
|
94 |
+
@AttrRegister.register
|
95 |
+
class Clothes(AttriributeIsText):
|
96 |
+
name = "clothes"
|
97 |
+
|
98 |
+
def __init__(self, name: str = None) -> None:
|
99 |
+
super().__init__(name)
|
100 |
+
|
101 |
+
|
102 |
+
@AttrRegister.register
|
103 |
+
class Age(AttributeIsTextAndName):
|
104 |
+
name = "age"
|
105 |
+
|
106 |
+
def __init__(self, name: str = None) -> None:
|
107 |
+
super().__init__(name)
|
108 |
+
|
109 |
+
def __call__(self, attributes: str) -> str:
|
110 |
+
if not isinstance(attributes, str):
|
111 |
+
attributes = str(attributes)
|
112 |
+
attributes = attributes.split(",")
|
113 |
+
text = ", ".join(
|
114 |
+
["{}-year-old".format(attr) if attr != "" else "" for attr in attributes]
|
115 |
+
)
|
116 |
+
return text
|
117 |
+
|
118 |
+
|
119 |
+
@AttrRegister.register
|
120 |
+
class Eyes(AttributeIsTextAndName):
|
121 |
+
name = "eyes"
|
122 |
+
|
123 |
+
def __init__(self, name: str = None) -> None:
|
124 |
+
super().__init__(name)
|
125 |
+
|
126 |
+
|
127 |
+
@AttrRegister.register
|
128 |
+
class Hair(AttributeIsTextAndName):
|
129 |
+
name = "hair"
|
130 |
+
|
131 |
+
def __init__(self, name: str = None) -> None:
|
132 |
+
super().__init__(name)
|
133 |
+
|
134 |
+
|
135 |
+
@AttrRegister.register
|
136 |
+
class Background(AttributeIsTextAndName):
|
137 |
+
name = "background"
|
138 |
+
|
139 |
+
def __init__(self, name: str = None) -> None:
|
140 |
+
super().__init__(name)
|
141 |
+
|
142 |
+
|
143 |
+
@AttrRegister.register
|
144 |
+
class Skin(AttributeIsTextAndName):
|
145 |
+
name = "skin"
|
146 |
+
|
147 |
+
def __init__(self, name: str = None) -> None:
|
148 |
+
super().__init__(name)
|
149 |
+
|
150 |
+
|
151 |
+
@AttrRegister.register
|
152 |
+
class Face(AttributeIsTextAndName):
|
153 |
+
name = "face"
|
154 |
+
|
155 |
+
def __init__(self, name: str = None) -> None:
|
156 |
+
super().__init__(name)
|
157 |
+
|
158 |
+
|
159 |
+
@AttrRegister.register
|
160 |
+
class Smile(AttributeIsTextAndName):
|
161 |
+
name = "smile"
|
162 |
+
|
163 |
+
def __init__(self, name: str = None) -> None:
|
164 |
+
super().__init__(name)
|
165 |
+
|
166 |
+
|
167 |
+
@AttrRegister.register
|
168 |
+
class Nose(AttributeIsTextAndName):
|
169 |
+
name = "nose"
|
170 |
+
|
171 |
+
def __init__(self, name: str = None) -> None:
|
172 |
+
super().__init__(name)
|
173 |
+
|
174 |
+
|
175 |
+
@AttrRegister.register
|
176 |
+
class Mouth(AttributeIsTextAndName):
|
177 |
+
name = "mouth"
|
178 |
+
|
179 |
+
def __init__(self, name: str = None) -> None:
|
180 |
+
super().__init__(name)
|
181 |
+
|
182 |
+
|
183 |
+
@AttrRegister.register
|
184 |
+
class Beard(AttriributeIsText):
|
185 |
+
name = "beard"
|
186 |
+
|
187 |
+
def __init__(self, name: str = None) -> None:
|
188 |
+
super().__init__(name)
|
189 |
+
|
190 |
+
|
191 |
+
@AttrRegister.register
|
192 |
+
class Necklace(AttributeIsTextAndName):
|
193 |
+
name = "necklace"
|
194 |
+
|
195 |
+
def __init__(self, name: str = None) -> None:
|
196 |
+
super().__init__(name)
|
197 |
+
|
198 |
+
|
199 |
+
@AttrRegister.register
|
200 |
+
class Irises(AttributeIsTextAndName):
|
201 |
+
name = "irises"
|
202 |
+
|
203 |
+
def __init__(self, name: str = None) -> None:
|
204 |
+
super().__init__(name)
|
205 |
+
|
206 |
+
|
207 |
+
@AttrRegister.register
|
208 |
+
class Lighting(AttributeIsTextAndName):
|
209 |
+
name = "lighting"
|
210 |
+
|
211 |
+
def __init__(self, name: str = None) -> None:
|
212 |
+
super().__init__(name)
|
213 |
+
|
214 |
+
|
215 |
+
PresetPortraitAttributes = [
|
216 |
+
Age,
|
217 |
+
Sex,
|
218 |
+
Singing,
|
219 |
+
Country,
|
220 |
+
Lighting,
|
221 |
+
Headwear,
|
222 |
+
Eyes,
|
223 |
+
Irises,
|
224 |
+
Hair,
|
225 |
+
Skin,
|
226 |
+
Face,
|
227 |
+
Smile,
|
228 |
+
Expression,
|
229 |
+
Clothes,
|
230 |
+
Nose,
|
231 |
+
Mouth,
|
232 |
+
Beard,
|
233 |
+
Necklace,
|
234 |
+
Style,
|
235 |
+
KeyWords,
|
236 |
+
Render,
|
237 |
+
]
|
238 |
+
|
239 |
+
|
240 |
+
class PortraitMultiAttr2Text(PresetMultiAttr2Text):
|
241 |
+
preset_attributes = PresetPortraitAttributes
|
242 |
+
|
243 |
+
def __init__(self, funcs: list = None, use_preset=True, name="portrait") -> None:
|
244 |
+
super().__init__(funcs, use_preset, name)
|
245 |
+
|
246 |
+
|
247 |
+
@AttrRegister.register
|
248 |
+
class InsightFace(AttriributeIsText):
|
249 |
+
name = "insight_face"
|
250 |
+
face_render_dict = {
|
251 |
+
"boy": "handsome,elegant",
|
252 |
+
"girl": "gorgeous,kawaii,colorful",
|
253 |
+
}
|
254 |
+
key_words = "delicate face,beautiful eyes"
|
255 |
+
|
256 |
+
def __call__(self, attributes: str) -> str:
|
257 |
+
"""将insight faces 检测的结果转化成prompt
|
258 |
+
convert the results of insight faces detection to prompt
|
259 |
+
Args:
|
260 |
+
face_list (_type_): _description_
|
261 |
+
|
262 |
+
Returns:
|
263 |
+
_type_: _description_
|
264 |
+
"""
|
265 |
+
attributes = json.loads(attributes)
|
266 |
+
face_list = attributes["info"]
|
267 |
+
if len(face_list) == 0:
|
268 |
+
return ""
|
269 |
+
|
270 |
+
if attributes["image_type"] == "body":
|
271 |
+
for face in face_list:
|
272 |
+
if "black" in face and face["black"]:
|
273 |
+
return "african,dark skin"
|
274 |
+
return ""
|
275 |
+
|
276 |
+
gender_dict = {"girl": 0, "boy": 0}
|
277 |
+
face_render_list = []
|
278 |
+
black = False
|
279 |
+
|
280 |
+
for face in face_list:
|
281 |
+
if face["ratio"] < 0.02:
|
282 |
+
continue
|
283 |
+
|
284 |
+
if face["gender"] == 0:
|
285 |
+
gender_dict["girl"] += 1
|
286 |
+
face_render_list.append(self.face_render_dict["girl"])
|
287 |
+
else:
|
288 |
+
gender_dict["boy"] += 1
|
289 |
+
face_render_list.append(self.face_render_dict["boy"])
|
290 |
+
|
291 |
+
if "black" in face and face["black"]:
|
292 |
+
black = True
|
293 |
+
|
294 |
+
if len(face_render_list) == 0:
|
295 |
+
return ""
|
296 |
+
elif len(face_render_list) == 1:
|
297 |
+
solo = True
|
298 |
+
else:
|
299 |
+
solo = False
|
300 |
+
|
301 |
+
gender = ""
|
302 |
+
for g, num in gender_dict.items():
|
303 |
+
if num > 0:
|
304 |
+
if gender:
|
305 |
+
gender += ", "
|
306 |
+
gender += "{}{}".format(num, g)
|
307 |
+
if num > 1:
|
308 |
+
gender += "s"
|
309 |
+
|
310 |
+
face_render_list = ",".join(face_render_list)
|
311 |
+
face_render_list = face_render_list.split(",")
|
312 |
+
face_render = list(set(face_render_list))
|
313 |
+
face_render.sort(key=face_render_list.index)
|
314 |
+
face_render = ",".join(face_render)
|
315 |
+
if gender_dict["girl"] == 0:
|
316 |
+
face_render = "male focus," + face_render
|
317 |
+
|
318 |
+
insightface_prompt = "{},{},{}".format(gender, face_render, self.key_words)
|
319 |
+
|
320 |
+
if solo:
|
321 |
+
insightface_prompt += ",solo"
|
322 |
+
if black:
|
323 |
+
insightface_prompt = "african,dark skin," + insightface_prompt
|
324 |
+
|
325 |
+
return insightface_prompt
|
326 |
+
|
327 |
+
|
328 |
+
@AttrRegister.register
|
329 |
+
class Caption(AttriributeIsText):
|
330 |
+
name = "caption"
|
331 |
+
|
332 |
+
|
333 |
+
@AttrRegister.register
|
334 |
+
class Env(AttriributeIsText):
|
335 |
+
name = "env"
|
336 |
+
envs_list = [
|
337 |
+
"east asian architecture",
|
338 |
+
"fireworks",
|
339 |
+
"snow, snowflakes",
|
340 |
+
"snowing, snowflakes",
|
341 |
+
]
|
342 |
+
|
343 |
+
def __call__(self, attributes: str = None) -> str:
|
344 |
+
if attributes != "" and attributes != " " and attributes is not None:
|
345 |
+
return attributes
|
346 |
+
else:
|
347 |
+
return random.choice(self.envs_list)
|
348 |
+
|
349 |
+
|
350 |
+
@AttrRegister.register
|
351 |
+
class Decoration(AttriributeIsText):
|
352 |
+
name = "decoration"
|
353 |
+
|
354 |
+
def __init__(self, name: str = None) -> None:
|
355 |
+
self.decoration_list = [
|
356 |
+
"chinese knot",
|
357 |
+
"flowers",
|
358 |
+
"food",
|
359 |
+
"lanterns",
|
360 |
+
"red envelop",
|
361 |
+
]
|
362 |
+
super().__init__(name)
|
363 |
+
|
364 |
+
def __call__(self, attributes: str = None) -> str:
|
365 |
+
if attributes != "" and attributes != " " and attributes is not None:
|
366 |
+
return attributes
|
367 |
+
else:
|
368 |
+
return random.choice(self.decoration_list)
|
369 |
+
|
370 |
+
|
371 |
+
@AttrRegister.register
|
372 |
+
class Festival(AttriributeIsText):
|
373 |
+
name = "festival"
|
374 |
+
festival_list = ["new year"]
|
375 |
+
|
376 |
+
def __init__(self, name: str = None) -> None:
|
377 |
+
super().__init__(name)
|
378 |
+
|
379 |
+
def __call__(self, attributes: str = None) -> str:
|
380 |
+
if attributes != "" and attributes != " " and attributes is not None:
|
381 |
+
return attributes
|
382 |
+
else:
|
383 |
+
return random.choice(self.festival_list)
|
384 |
+
|
385 |
+
|
386 |
+
@AttrRegister.register
|
387 |
+
class SpringHeadwear(AttriributeIsText):
|
388 |
+
name = "spring_headwear"
|
389 |
+
headwear_list = ["rabbit ears", "rabbit ears, fur hat"]
|
390 |
+
|
391 |
+
def __call__(self, attributes: str = None) -> str:
|
392 |
+
if attributes != "" and attributes != " " and attributes is not None:
|
393 |
+
return attributes
|
394 |
+
else:
|
395 |
+
return random.choice(self.headwear_list)
|
396 |
+
|
397 |
+
|
398 |
+
@AttrRegister.register
|
399 |
+
class SpringClothes(AttriributeIsText):
|
400 |
+
name = "spring_clothes"
|
401 |
+
clothes_list = [
|
402 |
+
"mittens,chinese clothes",
|
403 |
+
"mittens,fur trim",
|
404 |
+
"mittens,red scarf",
|
405 |
+
"mittens,winter clothes",
|
406 |
+
]
|
407 |
+
|
408 |
+
def __call__(self, attributes: str = None) -> str:
|
409 |
+
if attributes != "" and attributes != " " and attributes is not None:
|
410 |
+
return attributes
|
411 |
+
else:
|
412 |
+
return random.choice(self.clothes_list)
|
413 |
+
|
414 |
+
|
415 |
+
@AttrRegister.register
|
416 |
+
class Animal(AttriributeIsText):
|
417 |
+
name = "animal"
|
418 |
+
animal_list = ["rabbit", "holding rabbits"]
|
419 |
+
|
420 |
+
def __call__(self, attributes: str = None) -> str:
|
421 |
+
if attributes != "" and attributes != " " and attributes is not None:
|
422 |
+
return attributes
|
423 |
+
else:
|
424 |
+
return random.choice(self.animal_list)
|
musev/auto_prompt/attributes/render.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from mmcm.utils.util import flatten
|
2 |
+
|
3 |
+
from .attributes import BaseAttribute2Text
|
4 |
+
from . import AttrRegister
|
5 |
+
|
6 |
+
__all__ = ["Render"]
|
7 |
+
|
8 |
+
RenderMap = {
|
9 |
+
"Epic": "artstation, epic environment, highly detailed, 8k, HD",
|
10 |
+
"HD": "8k, highly detailed",
|
11 |
+
"EpicHD": "hyper detailed, beautiful lighting, epic environment, octane render, cinematic, 8k",
|
12 |
+
"Digital": "detailed illustration, crisp lines, digital art, 8k, trending on artstation",
|
13 |
+
"Unreal1": "artstation, concept art, smooth, sharp focus, illustration, unreal engine 5, 8k",
|
14 |
+
"Unreal2": "concept art, octane render, artstation, epic environment, highly detailed, 8k",
|
15 |
+
}
|
16 |
+
|
17 |
+
|
18 |
+
@AttrRegister.register
|
19 |
+
class Render(BaseAttribute2Text):
|
20 |
+
name = "render"
|
21 |
+
|
22 |
+
def __init__(self, name: str = None) -> None:
|
23 |
+
super().__init__(name)
|
24 |
+
|
25 |
+
def __call__(self, attributes: str) -> str:
|
26 |
+
if attributes == "" or attributes is None:
|
27 |
+
return ""
|
28 |
+
attributes = attributes.split(",")
|
29 |
+
render = [RenderMap[attr] for attr in attributes if attr in RenderMap]
|
30 |
+
render = flatten(render, ignored_iterable_types=[str])
|
31 |
+
if len(render) == 1:
|
32 |
+
render = render[0]
|
33 |
+
return render
|
musev/auto_prompt/attributes/style.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .attributes import AttriributeIsText
|
2 |
+
from . import AttrRegister
|
3 |
+
|
4 |
+
__all__ = ["Style"]
|
5 |
+
|
6 |
+
|
7 |
+
@AttrRegister.register
|
8 |
+
class Style(AttriributeIsText):
|
9 |
+
name = "style"
|
10 |
+
|
11 |
+
def __init__(self, name: str = None) -> None:
|
12 |
+
super().__init__(name)
|
musev/auto_prompt/human.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""负责按照人相关的属性转化成提词
|
2 |
+
"""
|
3 |
+
from typing import List
|
4 |
+
|
5 |
+
from .attributes.human import PortraitMultiAttr2Text
|
6 |
+
from .attributes.attributes import BaseAttribute2Text
|
7 |
+
from .attributes.attr2template import MultiAttr2PromptTemplate
|
8 |
+
|
9 |
+
|
10 |
+
class PortraitAttr2PromptTemplate(MultiAttr2PromptTemplate):
|
11 |
+
"""可以将任务字典转化为形象提词模板类
|
12 |
+
template class for converting task dictionaries into image prompt templates
|
13 |
+
Args:
|
14 |
+
MultiAttr2PromptTemplate (_type_): _description_
|
15 |
+
"""
|
16 |
+
|
17 |
+
templates = "a portrait of {}"
|
18 |
+
|
19 |
+
def __init__(
|
20 |
+
self, templates: str = None, attr2text: List = None, name: str = "portrait"
|
21 |
+
) -> None:
|
22 |
+
"""
|
23 |
+
|
24 |
+
Args:
|
25 |
+
templates (str, optional): 形象提词模板,若为None,则使用默认的类属性. Defaults to None.
|
26 |
+
portrait prompt template, if None, the default class attribute is used.
|
27 |
+
attr2text (List, optional): 形象类需要新增、更新的属性列表,默认使用PortraitMultiAttr2Text中定义的形象属性. Defaults to None.
|
28 |
+
the list of attributes that need to be added or updated in the image class, by default, the image attributes defined in PortraitMultiAttr2Text are used.
|
29 |
+
name (str, optional): 该形象类的名字. Defaults to "portrait".
|
30 |
+
class name of this class instance
|
31 |
+
"""
|
32 |
+
if (
|
33 |
+
attr2text is None
|
34 |
+
or isinstance(attr2text, list)
|
35 |
+
or isinstance(attr2text, BaseAttribute2Text)
|
36 |
+
):
|
37 |
+
attr2text = PortraitMultiAttr2Text(funcs=attr2text)
|
38 |
+
if templates is None:
|
39 |
+
templates = self.templates
|
40 |
+
super().__init__(templates, attr2text, name=name)
|
musev/auto_prompt/load_template.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from mmcm.utils.str_util import has_key_brace
|
2 |
+
|
3 |
+
from .human import PortraitAttr2PromptTemplate
|
4 |
+
from .attributes.attr2template import (
|
5 |
+
KeywordMultiAttr2PromptTemplate,
|
6 |
+
OnlySpacePromptTemplate,
|
7 |
+
)
|
8 |
+
|
9 |
+
|
10 |
+
def get_template_by_name(template: str, name: str = None):
|
11 |
+
"""根据 template_name 确定 prompt 生成器类
|
12 |
+
choose prompt generator class according to template_name
|
13 |
+
Args:
|
14 |
+
name (str): template 的名字简称,便于指定. template name abbreviation, for easy reference
|
15 |
+
|
16 |
+
Raises:
|
17 |
+
ValueError: ValueError: 如果name不在支持的列表中,则报错. if name is not in the supported list, an error is reported.
|
18 |
+
|
19 |
+
Returns:
|
20 |
+
MultiAttr2PromptTemplate: 能够将任务字典转化为提词的 实现了__call__功能的类. class that can convert task dictionaries into prompts and implements the __call__ function
|
21 |
+
|
22 |
+
"""
|
23 |
+
if template == "" or template is None:
|
24 |
+
template = OnlySpacePromptTemplate(template=template)
|
25 |
+
elif has_key_brace(template):
|
26 |
+
# if has_key_brace(template):
|
27 |
+
template = KeywordMultiAttr2PromptTemplate(template=template)
|
28 |
+
else:
|
29 |
+
if name == "portrait":
|
30 |
+
template = PortraitAttr2PromptTemplate(templates=template)
|
31 |
+
else:
|
32 |
+
raise ValueError(
|
33 |
+
"PresetAttr2PromptTemplate only support one of [portrait], but given {}".format(
|
34 |
+
name
|
35 |
+
)
|
36 |
+
)
|
37 |
+
return template
|
musev/auto_prompt/util.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from copy import deepcopy
|
2 |
+
from typing import Dict, List
|
3 |
+
|
4 |
+
from .load_template import get_template_by_name
|
5 |
+
|
6 |
+
|
7 |
+
def generate_prompts(tasks: List[Dict]) -> List[Dict]:
|
8 |
+
new_tasks = []
|
9 |
+
for task in tasks:
|
10 |
+
task["origin_prompt"] = deepcopy(task["prompt"])
|
11 |
+
# 如果prompt单元值含有模板 {},或者 没有填写任何值(默认为空模板),则使用原prompt值
|
12 |
+
if "{" not in task["prompt"] and len(task["prompt"]) != 0:
|
13 |
+
new_tasks.append(task)
|
14 |
+
else:
|
15 |
+
template = get_template_by_name(
|
16 |
+
template=task["prompt"], name=task.get("template_name", None)
|
17 |
+
)
|
18 |
+
prompts = template(task)
|
19 |
+
if not isinstance(prompts, list) and isinstance(prompts, str):
|
20 |
+
prompts = [prompts]
|
21 |
+
for prompt in prompts:
|
22 |
+
task_cp = deepcopy(task)
|
23 |
+
task_cp["prompt"] = prompt
|
24 |
+
new_tasks.append(task_cp)
|
25 |
+
return new_tasks
|
musev/data/__init__.py
ADDED
File without changes
|
musev/data/data_util.py
ADDED
@@ -0,0 +1,681 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Dict, Literal, Union, Tuple
|
2 |
+
import os
|
3 |
+
import string
|
4 |
+
import logging
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import numpy as np
|
8 |
+
from einops import rearrange, repeat
|
9 |
+
|
10 |
+
logger = logging.getLogger(__name__)
|
11 |
+
|
12 |
+
|
13 |
+
def generate_tasks_of_dir(
|
14 |
+
path: str,
|
15 |
+
output_dir: str,
|
16 |
+
exts: Tuple[str],
|
17 |
+
same_dir_name: bool = False,
|
18 |
+
**kwargs,
|
19 |
+
) -> List[Dict]:
|
20 |
+
"""covert video directory into tasks
|
21 |
+
|
22 |
+
Args:
|
23 |
+
path (str): _description_
|
24 |
+
output_dir (str): _description_
|
25 |
+
exts (Tuple[str]): _description_
|
26 |
+
same_dir_name (bool, optional): 存储路径是否保留和源视频相同的父文件名. Defaults to False.
|
27 |
+
whether keep the same parent dir name as the source video
|
28 |
+
Returns:
|
29 |
+
List[Dict]: _description_
|
30 |
+
"""
|
31 |
+
tasks = []
|
32 |
+
for rootdir, dirs, files in os.walk(path):
|
33 |
+
for basename in files:
|
34 |
+
if basename.lower().endswith(exts):
|
35 |
+
video_path = os.path.join(rootdir, basename)
|
36 |
+
filename, ext = basename.split(".")
|
37 |
+
rootdir_name = os.path.basename(rootdir)
|
38 |
+
if same_dir_name:
|
39 |
+
save_path = os.path.join(
|
40 |
+
output_dir, rootdir_name, f"{filename}.h5py"
|
41 |
+
)
|
42 |
+
save_dir = os.path.join(output_dir, rootdir_name)
|
43 |
+
else:
|
44 |
+
save_path = os.path.join(output_dir, f"{filename}.h5py")
|
45 |
+
save_dir = output_dir
|
46 |
+
task = {
|
47 |
+
"video_path": video_path,
|
48 |
+
"output_path": save_path,
|
49 |
+
"output_dir": save_dir,
|
50 |
+
"filename": filename,
|
51 |
+
"ext": ext,
|
52 |
+
}
|
53 |
+
task.update(kwargs)
|
54 |
+
tasks.append(task)
|
55 |
+
return tasks
|
56 |
+
|
57 |
+
|
58 |
+
def sample_by_idx(
|
59 |
+
T: int,
|
60 |
+
n_sample: int,
|
61 |
+
sample_rate: int,
|
62 |
+
sample_start_idx: int = None,
|
63 |
+
change_sample_rate: bool = False,
|
64 |
+
seed: int = None,
|
65 |
+
whether_random: bool = True,
|
66 |
+
n_independent: int = 0,
|
67 |
+
) -> List[int]:
|
68 |
+
"""given a int to represent candidate list, sample n_sample with sample_rate from the candidate list
|
69 |
+
|
70 |
+
Args:
|
71 |
+
T (int): _description_
|
72 |
+
n_sample (int): 目标采样数目. sample number
|
73 |
+
sample_rate (int): 采样率, 每隔sample_rate个采样一个. sample interval, pick one per sample_rate number
|
74 |
+
sample_start_idx (int, optional): 采样开始位置的选择. start position to sample . Defaults to 0.
|
75 |
+
change_sample_rate (bool, optional): 是否可以通过降低sample_rate的方式来完成采样. whether allow changing sample_rate to finish sample process. Defaults to False.
|
76 |
+
whether_random (bool, optional): 是否最后随机选择开始点. whether randomly choose sample start position. Defaults to False.
|
77 |
+
|
78 |
+
Raises:
|
79 |
+
ValueError: T / sample_rate should be larger than n_sample
|
80 |
+
Returns:
|
81 |
+
List[int]: 采样的索引位置. sampled index position
|
82 |
+
"""
|
83 |
+
if T < n_sample:
|
84 |
+
raise ValueError(f"T({T}) < n_sample({n_sample})")
|
85 |
+
else:
|
86 |
+
if T / sample_rate < n_sample:
|
87 |
+
if not change_sample_rate:
|
88 |
+
raise ValueError(
|
89 |
+
f"T({T}) / sample_rate({sample_rate}) < n_sample({n_sample})"
|
90 |
+
)
|
91 |
+
else:
|
92 |
+
while T / sample_rate < n_sample:
|
93 |
+
sample_rate -= 1
|
94 |
+
logger.error(
|
95 |
+
f"sample_rate{sample_rate+1} is too large, decrease to {sample_rate}"
|
96 |
+
)
|
97 |
+
if sample_rate == 0:
|
98 |
+
raise ValueError("T / sample_rate < n_sample")
|
99 |
+
|
100 |
+
if sample_start_idx is None:
|
101 |
+
if whether_random:
|
102 |
+
sample_start_idx_candidates = np.arange(T - n_sample * sample_rate)
|
103 |
+
if seed is not None:
|
104 |
+
np.random.seed(seed)
|
105 |
+
sample_start_idx = np.random.choice(sample_start_idx_candidates, 1)[0]
|
106 |
+
|
107 |
+
else:
|
108 |
+
sample_start_idx = 0
|
109 |
+
sample_end_idx = sample_start_idx + sample_rate * n_sample
|
110 |
+
sample = list(range(sample_start_idx, sample_end_idx, sample_rate))
|
111 |
+
if n_independent == 0:
|
112 |
+
n_independent_sample = None
|
113 |
+
else:
|
114 |
+
left_candidate = np.array(
|
115 |
+
list(range(0, sample_start_idx)) + list(range(sample_end_idx, T))
|
116 |
+
)
|
117 |
+
if len(left_candidate) >= n_independent:
|
118 |
+
# 使用两端的剩余空间采样, use the left space to sample
|
119 |
+
n_independent_sample = np.random.choice(left_candidate, n_independent)
|
120 |
+
else:
|
121 |
+
# 当两端没有剩余采样空间时,使用任意不是sample中的帧
|
122 |
+
# if no enough space to sample, use any frame not in sample
|
123 |
+
left_candidate = np.array(list(set(range(T) - set(sample))))
|
124 |
+
n_independent_sample = np.random.choice(left_candidate, n_independent)
|
125 |
+
|
126 |
+
return sample, sample_rate, n_independent_sample
|
127 |
+
|
128 |
+
|
129 |
+
def sample_tensor_by_idx(
|
130 |
+
tensor: Union[torch.Tensor, np.ndarray],
|
131 |
+
n_sample: int,
|
132 |
+
sample_rate: int,
|
133 |
+
sample_start_idx: int = 0,
|
134 |
+
change_sample_rate: bool = False,
|
135 |
+
seed: int = None,
|
136 |
+
dim: int = 0,
|
137 |
+
return_type: Literal["numpy", "torch"] = "torch",
|
138 |
+
whether_random: bool = True,
|
139 |
+
n_independent: int = 0,
|
140 |
+
) -> Tuple[torch.Tensor, torch.Tensor, int, torch.Tensor, torch.Tensor]:
|
141 |
+
"""sample sub_tensor
|
142 |
+
|
143 |
+
Args:
|
144 |
+
tensor (Union[torch.Tensor, np.ndarray]): _description_
|
145 |
+
n_sample (int): _description_
|
146 |
+
sample_rate (int): _description_
|
147 |
+
sample_start_idx (int, optional): _description_. Defaults to 0.
|
148 |
+
change_sample_rate (bool, optional): _description_. Defaults to False.
|
149 |
+
seed (int, optional): _description_. Defaults to None.
|
150 |
+
dim (int, optional): _description_. Defaults to 0.
|
151 |
+
return_type (Literal["numpy", "torch"], optional): _description_. Defaults to "torch".
|
152 |
+
whether_random (bool, optional): _description_. Defaults to True.
|
153 |
+
n_independent (int, optional): 独立于n_sample的采样数量. Defaults to 0.
|
154 |
+
n_independent sample number that is independent of n_sample
|
155 |
+
|
156 |
+
Returns:
|
157 |
+
Tuple[torch.Tensor, torch.Tensor, int, torch.Tensor, torch.Tensor]: sampled tensor
|
158 |
+
"""
|
159 |
+
if isinstance(tensor, np.ndarray):
|
160 |
+
tensor = torch.from_numpy(tensor)
|
161 |
+
T = tensor.shape[dim]
|
162 |
+
sample_idx, sample_rate, independent_sample_idx = sample_by_idx(
|
163 |
+
T,
|
164 |
+
n_sample,
|
165 |
+
sample_rate,
|
166 |
+
sample_start_idx,
|
167 |
+
change_sample_rate,
|
168 |
+
seed,
|
169 |
+
whether_random=whether_random,
|
170 |
+
n_independent=n_independent,
|
171 |
+
)
|
172 |
+
sample_idx = torch.LongTensor(sample_idx)
|
173 |
+
sample = torch.index_select(tensor, dim, sample_idx)
|
174 |
+
if independent_sample_idx is not None:
|
175 |
+
independent_sample_idx = torch.LongTensor(independent_sample_idx)
|
176 |
+
independent_sample = torch.index_select(tensor, dim, independent_sample_idx)
|
177 |
+
else:
|
178 |
+
independent_sample = None
|
179 |
+
independent_sample_idx = None
|
180 |
+
if return_type == "numpy":
|
181 |
+
sample = sample.cpu().numpy()
|
182 |
+
return sample, sample_idx, sample_rate, independent_sample, independent_sample_idx
|
183 |
+
|
184 |
+
|
185 |
+
def concat_two_tensor(
|
186 |
+
data1: torch.Tensor,
|
187 |
+
data2: torch.Tensor,
|
188 |
+
dim: int,
|
189 |
+
method: Literal[
|
190 |
+
"first_in_first_out", "first_in_last_out", "intertwine", "index"
|
191 |
+
] = "first_in_first_out",
|
192 |
+
data1_index: torch.long = None,
|
193 |
+
data2_index: torch.long = None,
|
194 |
+
return_index: bool = False,
|
195 |
+
):
|
196 |
+
"""concat two tensor along dim with given method
|
197 |
+
|
198 |
+
Args:
|
199 |
+
data1 (torch.Tensor): first in data
|
200 |
+
data2 (torch.Tensor): last in data
|
201 |
+
dim (int): _description_
|
202 |
+
method (Literal[ "first_in_first_out", "first_in_last_out", "intertwine" ], optional): _description_. Defaults to "first_in_first_out".
|
203 |
+
|
204 |
+
Raises:
|
205 |
+
NotImplementedError: unsupported method
|
206 |
+
ValueError: unsupported method
|
207 |
+
|
208 |
+
Returns:
|
209 |
+
_type_: _description_
|
210 |
+
"""
|
211 |
+
len_data1 = data1.shape[dim]
|
212 |
+
len_data2 = data2.shape[dim]
|
213 |
+
|
214 |
+
if method == "first_in_first_out":
|
215 |
+
res = torch.concat([data1, data2], dim=dim)
|
216 |
+
data1_index = range(len_data1)
|
217 |
+
data2_index = [len_data1 + x for x in range(len_data2)]
|
218 |
+
elif method == "first_in_last_out":
|
219 |
+
res = torch.concat([data2, data1], dim=dim)
|
220 |
+
data2_index = range(len_data2)
|
221 |
+
data1_index = [len_data2 + x for x in range(len_data1)]
|
222 |
+
elif method == "intertwine":
|
223 |
+
raise NotImplementedError("intertwine")
|
224 |
+
elif method == "index":
|
225 |
+
res = concat_two_tensor_with_index(
|
226 |
+
data1=data1,
|
227 |
+
data1_index=data1_index,
|
228 |
+
data2=data2,
|
229 |
+
data2_index=data2_index,
|
230 |
+
dim=dim,
|
231 |
+
)
|
232 |
+
else:
|
233 |
+
raise ValueError(
|
234 |
+
"only support first_in_first_out, first_in_last_out, intertwine, index"
|
235 |
+
)
|
236 |
+
if return_index:
|
237 |
+
return res, data1_index, data2_index
|
238 |
+
else:
|
239 |
+
return res
|
240 |
+
|
241 |
+
|
242 |
+
def concat_two_tensor_with_index(
|
243 |
+
data1: torch.Tensor,
|
244 |
+
data1_index: torch.LongTensor,
|
245 |
+
data2: torch.Tensor,
|
246 |
+
data2_index: torch.LongTensor,
|
247 |
+
dim: int,
|
248 |
+
) -> torch.Tensor:
|
249 |
+
"""_summary_
|
250 |
+
|
251 |
+
Args:
|
252 |
+
data1 (torch.Tensor): b1*c1*h1*w1*...
|
253 |
+
data1_index (torch.LongTensor): N, if dim=1, N=c1
|
254 |
+
data2 (torch.Tensor): b2*c2*h2*w2*...
|
255 |
+
data2_index (torch.LongTensor): M, if dim=1, M=c2
|
256 |
+
dim (int): int
|
257 |
+
|
258 |
+
Returns:
|
259 |
+
torch.Tensor: b*c*h*w*..., if dim=1, b=b1=b2, c=c1+c2, h=h1=h2, w=w1=w2,...
|
260 |
+
"""
|
261 |
+
shape1 = list(data1.shape)
|
262 |
+
shape2 = list(data2.shape)
|
263 |
+
target_shape = list(shape1)
|
264 |
+
target_shape[dim] = shape1[dim] + shape2[dim]
|
265 |
+
target = torch.zeros(target_shape, device=data1.device, dtype=data1.dtype)
|
266 |
+
target = batch_index_copy(target, dim=dim, index=data1_index, source=data1)
|
267 |
+
target = batch_index_copy(target, dim=dim, index=data2_index, source=data2)
|
268 |
+
return target
|
269 |
+
|
270 |
+
|
271 |
+
def repeat_index_to_target_size(
|
272 |
+
index: torch.LongTensor, target_size: int
|
273 |
+
) -> torch.LongTensor:
|
274 |
+
if len(index.shape) == 1:
|
275 |
+
index = repeat(index, "n -> b n", b=target_size)
|
276 |
+
if len(index.shape) == 2:
|
277 |
+
remainder = target_size % index.shape[0]
|
278 |
+
assert (
|
279 |
+
remainder == 0
|
280 |
+
), f"target_size % index.shape[0] must be zero, but give {target_size % index.shape[0]}"
|
281 |
+
index = repeat(index, "b n -> (b c) n", c=int(target_size / index.shape[0]))
|
282 |
+
return index
|
283 |
+
|
284 |
+
|
285 |
+
def batch_concat_two_tensor_with_index(
|
286 |
+
data1: torch.Tensor,
|
287 |
+
data1_index: torch.LongTensor,
|
288 |
+
data2: torch.Tensor,
|
289 |
+
data2_index: torch.LongTensor,
|
290 |
+
dim: int,
|
291 |
+
) -> torch.Tensor:
|
292 |
+
return concat_two_tensor_with_index(data1, data1_index, data2, data2_index, dim)
|
293 |
+
|
294 |
+
|
295 |
+
def interwine_two_tensor(
|
296 |
+
data1: torch.Tensor,
|
297 |
+
data2: torch.Tensor,
|
298 |
+
dim: int,
|
299 |
+
return_index: bool = False,
|
300 |
+
) -> torch.Tensor:
|
301 |
+
shape1 = list(data1.shape)
|
302 |
+
shape2 = list(data2.shape)
|
303 |
+
target_shape = list(shape1)
|
304 |
+
target_shape[dim] = shape1[dim] + shape2[dim]
|
305 |
+
target = torch.zeros(target_shape, device=data1.device, dtype=data1.dtype)
|
306 |
+
data1_reshape = torch.swapaxes(data1, 0, dim)
|
307 |
+
data2_reshape = torch.swapaxes(data2, 0, dim)
|
308 |
+
target = torch.swapaxes(target, 0, dim)
|
309 |
+
total_index = set(range(target_shape[dim]))
|
310 |
+
data1_index = range(0, 2 * shape1[dim], 2)
|
311 |
+
data2_index = sorted(list(set(total_index) - set(data1_index)))
|
312 |
+
data1_index = torch.LongTensor(data1_index)
|
313 |
+
data2_index = torch.LongTensor(data2_index)
|
314 |
+
target[data1_index, ...] = data1_reshape
|
315 |
+
target[data2_index, ...] = data2_reshape
|
316 |
+
target = torch.swapaxes(target, 0, dim)
|
317 |
+
if return_index:
|
318 |
+
return target, data1_index, data2_index
|
319 |
+
else:
|
320 |
+
return target
|
321 |
+
|
322 |
+
|
323 |
+
def split_index(
|
324 |
+
indexs: torch.Tensor,
|
325 |
+
n_first: int = None,
|
326 |
+
n_last: int = None,
|
327 |
+
method: Literal[
|
328 |
+
"first_in_first_out", "first_in_last_out", "intertwine", "index", "random"
|
329 |
+
] = "first_in_first_out",
|
330 |
+
):
|
331 |
+
"""_summary_
|
332 |
+
|
333 |
+
Args:
|
334 |
+
indexs (List): _description_
|
335 |
+
n_first (int): _description_
|
336 |
+
n_last (int): _description_
|
337 |
+
method (Literal[ "first_in_first_out", "first_in_last_out", "intertwine", "index" ], optional): _description_. Defaults to "first_in_first_out".
|
338 |
+
|
339 |
+
Raises:
|
340 |
+
NotImplementedError: _description_
|
341 |
+
|
342 |
+
Returns:
|
343 |
+
first_index: _description_
|
344 |
+
last_index:
|
345 |
+
"""
|
346 |
+
# assert (
|
347 |
+
# n_first is None and n_last is None
|
348 |
+
# ), "must assign one value for n_first or n_last"
|
349 |
+
n_total = len(indexs)
|
350 |
+
if n_first is None:
|
351 |
+
n_first = n_total - n_last
|
352 |
+
if n_last is None:
|
353 |
+
n_last = n_total - n_first
|
354 |
+
assert len(indexs) == n_first + n_last
|
355 |
+
if method == "first_in_first_out":
|
356 |
+
first_index = indexs[:n_first]
|
357 |
+
last_index = indexs[n_first:]
|
358 |
+
elif method == "first_in_last_out":
|
359 |
+
first_index = indexs[n_last:]
|
360 |
+
last_index = indexs[:n_last]
|
361 |
+
elif method == "intertwine":
|
362 |
+
raise NotImplementedError
|
363 |
+
elif method == "random":
|
364 |
+
idx_ = torch.randperm(len(indexs))
|
365 |
+
first_index = indexs[idx_[:n_first]]
|
366 |
+
last_index = indexs[idx_[n_first:]]
|
367 |
+
return first_index, last_index
|
368 |
+
|
369 |
+
|
370 |
+
def split_tensor(
|
371 |
+
tensor: torch.Tensor,
|
372 |
+
dim: int,
|
373 |
+
n_first=None,
|
374 |
+
n_last=None,
|
375 |
+
method: Literal[
|
376 |
+
"first_in_first_out", "first_in_last_out", "intertwine", "index", "random"
|
377 |
+
] = "first_in_first_out",
|
378 |
+
need_return_index: bool = False,
|
379 |
+
):
|
380 |
+
device = tensor.device
|
381 |
+
total = tensor.shape[dim]
|
382 |
+
if n_first is None:
|
383 |
+
n_first = total - n_last
|
384 |
+
if n_last is None:
|
385 |
+
n_last = total - n_first
|
386 |
+
indexs = torch.arange(
|
387 |
+
total,
|
388 |
+
dtype=torch.long,
|
389 |
+
device=device,
|
390 |
+
)
|
391 |
+
(
|
392 |
+
first_index,
|
393 |
+
last_index,
|
394 |
+
) = split_index(
|
395 |
+
indexs=indexs,
|
396 |
+
n_first=n_first,
|
397 |
+
method=method,
|
398 |
+
)
|
399 |
+
first_tensor = torch.index_select(tensor, dim=dim, index=first_index)
|
400 |
+
last_tensor = torch.index_select(tensor, dim=dim, index=last_index)
|
401 |
+
if need_return_index:
|
402 |
+
return (
|
403 |
+
first_tensor,
|
404 |
+
last_tensor,
|
405 |
+
first_index,
|
406 |
+
last_index,
|
407 |
+
)
|
408 |
+
else:
|
409 |
+
return (first_tensor, last_tensor)
|
410 |
+
|
411 |
+
|
412 |
+
# TODO: 待确定batch_index_select的优化
|
413 |
+
def batch_index_select(
|
414 |
+
tensor: torch.Tensor, index: torch.LongTensor, dim: int
|
415 |
+
) -> torch.Tensor:
|
416 |
+
"""_summary_
|
417 |
+
|
418 |
+
Args:
|
419 |
+
tensor (torch.Tensor): D1*D2*D3*D4...
|
420 |
+
index (torch.LongTensor): D1*N or N, N<= tensor.shape[dim]
|
421 |
+
dim (int): dim to select
|
422 |
+
|
423 |
+
Returns:
|
424 |
+
torch.Tensor: D1*...*N*...
|
425 |
+
"""
|
426 |
+
# TODO: now only support N same for every d1
|
427 |
+
if len(index.shape) == 1:
|
428 |
+
return torch.index_select(tensor, dim=dim, index=index)
|
429 |
+
else:
|
430 |
+
index = repeat_index_to_target_size(index, tensor.shape[0])
|
431 |
+
out = []
|
432 |
+
for i in torch.arange(tensor.shape[0]):
|
433 |
+
sub_tensor = tensor[i]
|
434 |
+
sub_index = index[i]
|
435 |
+
d = torch.index_select(sub_tensor, dim=dim - 1, index=sub_index)
|
436 |
+
out.append(d)
|
437 |
+
return torch.stack(out).to(dtype=tensor.dtype)
|
438 |
+
|
439 |
+
|
440 |
+
def batch_index_copy(
|
441 |
+
tensor: torch.Tensor, dim: int, index: torch.LongTensor, source: torch.Tensor
|
442 |
+
) -> torch.Tensor:
|
443 |
+
"""_summary_
|
444 |
+
|
445 |
+
Args:
|
446 |
+
tensor (torch.Tensor): b*c*h
|
447 |
+
dim (int):
|
448 |
+
index (torch.LongTensor): b*d,
|
449 |
+
source (torch.Tensor):
|
450 |
+
b*d*h*..., if dim=1
|
451 |
+
b*c*d*..., if dim=2
|
452 |
+
|
453 |
+
Returns:
|
454 |
+
torch.Tensor: b*c*d*...
|
455 |
+
"""
|
456 |
+
if len(index.shape) == 1:
|
457 |
+
tensor.index_copy_(dim=dim, index=index, source=source)
|
458 |
+
else:
|
459 |
+
index = repeat_index_to_target_size(index, tensor.shape[0])
|
460 |
+
|
461 |
+
batch_size = tensor.shape[0]
|
462 |
+
for b in torch.arange(batch_size):
|
463 |
+
sub_index = index[b]
|
464 |
+
sub_source = source[b]
|
465 |
+
sub_tensor = tensor[b]
|
466 |
+
sub_tensor.index_copy_(dim=dim - 1, index=sub_index, source=sub_source)
|
467 |
+
tensor[b] = sub_tensor
|
468 |
+
return tensor
|
469 |
+
|
470 |
+
|
471 |
+
def batch_index_fill(
|
472 |
+
tensor: torch.Tensor,
|
473 |
+
dim: int,
|
474 |
+
index: torch.LongTensor,
|
475 |
+
value: Literal[torch.Tensor, torch.float],
|
476 |
+
) -> torch.Tensor:
|
477 |
+
"""_summary_
|
478 |
+
|
479 |
+
Args:
|
480 |
+
tensor (torch.Tensor): b*c*h
|
481 |
+
dim (int):
|
482 |
+
index (torch.LongTensor): b*d,
|
483 |
+
value (torch.Tensor): b
|
484 |
+
|
485 |
+
Returns:
|
486 |
+
torch.Tensor: b*c*d*...
|
487 |
+
"""
|
488 |
+
index = repeat_index_to_target_size(index, tensor.shape[0])
|
489 |
+
batch_size = tensor.shape[0]
|
490 |
+
for b in torch.arange(batch_size):
|
491 |
+
sub_index = index[b]
|
492 |
+
sub_value = value[b] if isinstance(value, torch.Tensor) else value
|
493 |
+
sub_tensor = tensor[b]
|
494 |
+
sub_tensor.index_fill_(dim - 1, sub_index, sub_value)
|
495 |
+
tensor[b] = sub_tensor
|
496 |
+
return tensor
|
497 |
+
|
498 |
+
|
499 |
+
def adaptive_instance_normalization(
|
500 |
+
src: torch.Tensor,
|
501 |
+
dst: torch.Tensor,
|
502 |
+
eps: float = 1e-6,
|
503 |
+
):
|
504 |
+
"""
|
505 |
+
Args:
|
506 |
+
src (torch.Tensor): b c t h w
|
507 |
+
dst (torch.Tensor): b c t h w
|
508 |
+
"""
|
509 |
+
ndim = src.ndim
|
510 |
+
if ndim == 5:
|
511 |
+
dim = (2, 3, 4)
|
512 |
+
elif ndim == 4:
|
513 |
+
dim = (2, 3)
|
514 |
+
elif ndim == 3:
|
515 |
+
dim = 2
|
516 |
+
else:
|
517 |
+
raise ValueError("only support ndim in [3,4,5], but given {ndim}")
|
518 |
+
var, mean = torch.var_mean(src, dim=dim, keepdim=True, correction=0)
|
519 |
+
std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
|
520 |
+
dst = align_repeat_tensor_single_dim(dst, src.shape[0], dim=0)
|
521 |
+
mean_acc, var_acc = torch.var_mean(dst, dim=dim, keepdim=True, correction=0)
|
522 |
+
# mean_acc = sum(mean_acc) / float(len(mean_acc))
|
523 |
+
# var_acc = sum(var_acc) / float(len(var_acc))
|
524 |
+
std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
|
525 |
+
src = (((src - mean) / std) * std_acc) + mean_acc
|
526 |
+
return src
|
527 |
+
|
528 |
+
|
529 |
+
def adaptive_instance_normalization_with_ref(
|
530 |
+
src: torch.LongTensor,
|
531 |
+
dst: torch.LongTensor,
|
532 |
+
style_fidelity: float = 0.5,
|
533 |
+
do_classifier_free_guidance: bool = True,
|
534 |
+
):
|
535 |
+
# logger.debug(
|
536 |
+
# f"src={src.shape}, min={src.min()}, max={src.max()}, mean={src.mean()}, \n"
|
537 |
+
# f"dst={src.shape}, min={dst.min()}, max={dst.max()}, mean={dst.mean()}"
|
538 |
+
# )
|
539 |
+
batch_size = src.shape[0] // 2
|
540 |
+
uc_mask = torch.Tensor([1] * batch_size + [0] * batch_size).type_as(src).bool()
|
541 |
+
src_uc = adaptive_instance_normalization(src, dst)
|
542 |
+
src_c = src_uc.clone()
|
543 |
+
# TODO: 该部分默认 do_classifier_free_guidance and style_fidelity > 0 = True
|
544 |
+
if do_classifier_free_guidance and style_fidelity > 0:
|
545 |
+
src_c[uc_mask] = src[uc_mask]
|
546 |
+
src = style_fidelity * src_c + (1.0 - style_fidelity) * src_uc
|
547 |
+
return src
|
548 |
+
|
549 |
+
|
550 |
+
def batch_adain_conditioned_tensor(
|
551 |
+
tensor: torch.Tensor,
|
552 |
+
src_index: torch.LongTensor,
|
553 |
+
dst_index: torch.LongTensor,
|
554 |
+
keep_dim: bool = True,
|
555 |
+
num_frames: int = None,
|
556 |
+
dim: int = 2,
|
557 |
+
style_fidelity: float = 0.5,
|
558 |
+
do_classifier_free_guidance: bool = True,
|
559 |
+
need_style_fidelity: bool = False,
|
560 |
+
):
|
561 |
+
"""_summary_
|
562 |
+
|
563 |
+
Args:
|
564 |
+
tensor (torch.Tensor): b c t h w
|
565 |
+
src_index (torch.LongTensor): _description_
|
566 |
+
dst_index (torch.LongTensor): _description_
|
567 |
+
keep_dim (bool, optional): _description_. Defaults to True.
|
568 |
+
|
569 |
+
Returns:
|
570 |
+
_type_: _description_
|
571 |
+
"""
|
572 |
+
ndim = tensor.ndim
|
573 |
+
dtype = tensor.dtype
|
574 |
+
if ndim == 4 and num_frames is not None:
|
575 |
+
tensor = rearrange(tensor, "(b t) c h w-> b c t h w ", t=num_frames)
|
576 |
+
src = batch_index_select(tensor, dim=dim, index=src_index).contiguous()
|
577 |
+
dst = batch_index_select(tensor, dim=dim, index=dst_index).contiguous()
|
578 |
+
if need_style_fidelity:
|
579 |
+
src = adaptive_instance_normalization_with_ref(
|
580 |
+
src=src,
|
581 |
+
dst=dst,
|
582 |
+
style_fidelity=style_fidelity,
|
583 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
584 |
+
need_style_fidelity=need_style_fidelity,
|
585 |
+
)
|
586 |
+
else:
|
587 |
+
src = adaptive_instance_normalization(
|
588 |
+
src=src,
|
589 |
+
dst=dst,
|
590 |
+
)
|
591 |
+
if keep_dim:
|
592 |
+
src = batch_concat_two_tensor_with_index(
|
593 |
+
src.to(dtype=dtype),
|
594 |
+
src_index,
|
595 |
+
dst.to(dtype=dtype),
|
596 |
+
dst_index,
|
597 |
+
dim=dim,
|
598 |
+
)
|
599 |
+
|
600 |
+
if ndim == 4 and num_frames is not None:
|
601 |
+
src = rearrange(tensor, "b c t h w ->(b t) c h w")
|
602 |
+
return src
|
603 |
+
|
604 |
+
|
605 |
+
def align_repeat_tensor_single_dim(
|
606 |
+
src: torch.Tensor,
|
607 |
+
target_length: int,
|
608 |
+
dim: int = 0,
|
609 |
+
n_src_base_length: int = 1,
|
610 |
+
src_base_index: List[int] = None,
|
611 |
+
) -> torch.Tensor:
|
612 |
+
"""沿着 dim 纬度, 补齐 src 的长度到目标 target_length。
|
613 |
+
当 src 长度不如 target_length 时, 取其中 前 n_src_base_length 然后 repeat 到 target_length
|
614 |
+
|
615 |
+
align length of src to target_length along dim
|
616 |
+
when src length is less than target_length, take the first n_src_base_length and repeat to target_length
|
617 |
+
|
618 |
+
Args:
|
619 |
+
src (torch.Tensor): 输入 tensor, input tensor
|
620 |
+
target_length (int): 目标长度, target_length
|
621 |
+
dim (int, optional): 处理纬度, target dim . Defaults to 0.
|
622 |
+
n_src_base_length (int, optional): src 的基本单元长度, basic length of src. Defaults to 1.
|
623 |
+
|
624 |
+
Returns:
|
625 |
+
torch.Tensor: _description_
|
626 |
+
"""
|
627 |
+
src_dim_length = src.shape[dim]
|
628 |
+
if target_length > src_dim_length:
|
629 |
+
if target_length % src_dim_length == 0:
|
630 |
+
new = src.repeat_interleave(
|
631 |
+
repeats=target_length // src_dim_length, dim=dim
|
632 |
+
)
|
633 |
+
else:
|
634 |
+
if src_base_index is None and n_src_base_length is not None:
|
635 |
+
src_base_index = torch.arange(n_src_base_length)
|
636 |
+
|
637 |
+
new = src.index_select(
|
638 |
+
dim=dim,
|
639 |
+
index=torch.LongTensor(src_base_index).to(device=src.device),
|
640 |
+
)
|
641 |
+
new = new.repeat_interleave(
|
642 |
+
repeats=target_length // len(src_base_index),
|
643 |
+
dim=dim,
|
644 |
+
)
|
645 |
+
elif target_length < src_dim_length:
|
646 |
+
new = src.index_select(
|
647 |
+
dim=dim,
|
648 |
+
index=torch.LongTensor(torch.arange(target_length)).to(device=src.device),
|
649 |
+
)
|
650 |
+
else:
|
651 |
+
new = src
|
652 |
+
return new
|
653 |
+
|
654 |
+
|
655 |
+
def fuse_part_tensor(
|
656 |
+
src: torch.Tensor,
|
657 |
+
dst: torch.Tensor,
|
658 |
+
overlap: int,
|
659 |
+
weight: float = 0.5,
|
660 |
+
skip_step: int = 0,
|
661 |
+
) -> torch.Tensor:
|
662 |
+
"""fuse overstep tensor with weight of src into dst
|
663 |
+
out = src_fused_part * weight + dst * (1-weight) for overlap
|
664 |
+
|
665 |
+
Args:
|
666 |
+
src (torch.Tensor): b c t h w
|
667 |
+
dst (torch.Tensor): b c t h w
|
668 |
+
overlap (int): 1
|
669 |
+
weight (float, optional): weight of src tensor part. Defaults to 0.5.
|
670 |
+
|
671 |
+
Returns:
|
672 |
+
torch.Tensor: fused tensor
|
673 |
+
"""
|
674 |
+
if overlap == 0:
|
675 |
+
return dst
|
676 |
+
else:
|
677 |
+
dst[:, :, skip_step : skip_step + overlap] = (
|
678 |
+
weight * src[:, :, -overlap:]
|
679 |
+
+ (1 - weight) * dst[:, :, skip_step : skip_step + overlap]
|
680 |
+
)
|
681 |
+
return dst
|
musev/logging.conf
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[loggers]
|
2 |
+
keys=root,musev
|
3 |
+
|
4 |
+
[handlers]
|
5 |
+
keys=consoleHandler
|
6 |
+
|
7 |
+
[formatters]
|
8 |
+
keys=musevFormatter
|
9 |
+
|
10 |
+
[logger_root]
|
11 |
+
level=INFO
|
12 |
+
handlers=consoleHandler
|
13 |
+
|
14 |
+
# logger level 尽量设置低一点
|
15 |
+
[logger_musev]
|
16 |
+
level=DEBUG
|
17 |
+
handlers=consoleHandler
|
18 |
+
qualname=musev
|
19 |
+
propagate=0
|
20 |
+
|
21 |
+
# handler level 设置比 logger level高
|
22 |
+
[handler_consoleHandler]
|
23 |
+
class=StreamHandler
|
24 |
+
level=DEBUG
|
25 |
+
# level=INFO
|
26 |
+
|
27 |
+
formatter=musevFormatter
|
28 |
+
args=(sys.stdout,)
|
29 |
+
|
30 |
+
[formatter_musevFormatter]
|
31 |
+
format=%(asctime)s- %(name)s:%(lineno)d- %(levelname)s- %(message)s
|
32 |
+
datefmt=
|
musev/models/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from ..utils.register import Register
|
2 |
+
|
3 |
+
Model_Register = Register(registry_name="torch_model")
|
musev/models/attention.py
ADDED
@@ -0,0 +1,431 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# Adapted from https://github.com/huggingface/diffusers/blob/64bf5d33b7ef1b1deac256bed7bd99b55020c4e0/src/diffusers/models/attention.py
|
16 |
+
from __future__ import annotations
|
17 |
+
from copy import deepcopy
|
18 |
+
|
19 |
+
from typing import Any, Dict, List, Literal, Optional, Callable, Tuple
|
20 |
+
import logging
|
21 |
+
from einops import rearrange
|
22 |
+
|
23 |
+
import torch
|
24 |
+
import torch.nn.functional as F
|
25 |
+
from torch import nn
|
26 |
+
|
27 |
+
from diffusers.models.embeddings import CombinedTimestepLabelEmbeddings
|
28 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
29 |
+
from diffusers.models.attention_processor import Attention as DiffusersAttention
|
30 |
+
from diffusers.models.attention import (
|
31 |
+
BasicTransformerBlock as DiffusersBasicTransformerBlock,
|
32 |
+
AdaLayerNormZero,
|
33 |
+
AdaLayerNorm,
|
34 |
+
FeedForward,
|
35 |
+
)
|
36 |
+
from diffusers.models.attention_processor import AttnProcessor
|
37 |
+
|
38 |
+
from .attention_processor import IPAttention, BaseIPAttnProcessor
|
39 |
+
|
40 |
+
|
41 |
+
logger = logging.getLogger(__name__)
|
42 |
+
|
43 |
+
|
44 |
+
def not_use_xformers_anyway(
|
45 |
+
use_memory_efficient_attention_xformers: bool,
|
46 |
+
attention_op: Optional[Callable] = None,
|
47 |
+
):
|
48 |
+
return None
|
49 |
+
|
50 |
+
|
51 |
+
@maybe_allow_in_graph
|
52 |
+
class BasicTransformerBlock(DiffusersBasicTransformerBlock):
|
53 |
+
print_idx = 0
|
54 |
+
|
55 |
+
def __init__(
|
56 |
+
self,
|
57 |
+
dim: int,
|
58 |
+
num_attention_heads: int,
|
59 |
+
attention_head_dim: int,
|
60 |
+
dropout=0,
|
61 |
+
cross_attention_dim: int | None = None,
|
62 |
+
activation_fn: str = "geglu",
|
63 |
+
num_embeds_ada_norm: int | None = None,
|
64 |
+
attention_bias: bool = False,
|
65 |
+
only_cross_attention: bool = False,
|
66 |
+
double_self_attention: bool = False,
|
67 |
+
upcast_attention: bool = False,
|
68 |
+
norm_elementwise_affine: bool = True,
|
69 |
+
norm_type: str = "layer_norm",
|
70 |
+
final_dropout: bool = False,
|
71 |
+
attention_type: str = "default",
|
72 |
+
allow_xformers: bool = True,
|
73 |
+
cross_attn_temporal_cond: bool = False,
|
74 |
+
image_scale: float = 1.0,
|
75 |
+
processor: AttnProcessor | None = None,
|
76 |
+
ip_adapter_cross_attn: bool = False,
|
77 |
+
need_t2i_facein: bool = False,
|
78 |
+
need_t2i_ip_adapter_face: bool = False,
|
79 |
+
):
|
80 |
+
if not only_cross_attention and double_self_attention:
|
81 |
+
cross_attention_dim = None
|
82 |
+
super().__init__(
|
83 |
+
dim,
|
84 |
+
num_attention_heads,
|
85 |
+
attention_head_dim,
|
86 |
+
dropout,
|
87 |
+
cross_attention_dim,
|
88 |
+
activation_fn,
|
89 |
+
num_embeds_ada_norm,
|
90 |
+
attention_bias,
|
91 |
+
only_cross_attention,
|
92 |
+
double_self_attention,
|
93 |
+
upcast_attention,
|
94 |
+
norm_elementwise_affine,
|
95 |
+
norm_type,
|
96 |
+
final_dropout,
|
97 |
+
attention_type,
|
98 |
+
)
|
99 |
+
|
100 |
+
self.attn1 = IPAttention(
|
101 |
+
query_dim=dim,
|
102 |
+
heads=num_attention_heads,
|
103 |
+
dim_head=attention_head_dim,
|
104 |
+
dropout=dropout,
|
105 |
+
bias=attention_bias,
|
106 |
+
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
107 |
+
upcast_attention=upcast_attention,
|
108 |
+
cross_attn_temporal_cond=cross_attn_temporal_cond,
|
109 |
+
image_scale=image_scale,
|
110 |
+
ip_adapter_dim=cross_attention_dim
|
111 |
+
if only_cross_attention
|
112 |
+
else attention_head_dim,
|
113 |
+
facein_dim=cross_attention_dim
|
114 |
+
if only_cross_attention
|
115 |
+
else attention_head_dim,
|
116 |
+
processor=processor,
|
117 |
+
)
|
118 |
+
# 2. Cross-Attn
|
119 |
+
if cross_attention_dim is not None or double_self_attention:
|
120 |
+
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
121 |
+
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
122 |
+
# the second cross attention block.
|
123 |
+
self.norm2 = (
|
124 |
+
AdaLayerNorm(dim, num_embeds_ada_norm)
|
125 |
+
if self.use_ada_layer_norm
|
126 |
+
else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
127 |
+
)
|
128 |
+
|
129 |
+
self.attn2 = IPAttention(
|
130 |
+
query_dim=dim,
|
131 |
+
cross_attention_dim=cross_attention_dim
|
132 |
+
if not double_self_attention
|
133 |
+
else None,
|
134 |
+
heads=num_attention_heads,
|
135 |
+
dim_head=attention_head_dim,
|
136 |
+
dropout=dropout,
|
137 |
+
bias=attention_bias,
|
138 |
+
upcast_attention=upcast_attention,
|
139 |
+
cross_attn_temporal_cond=ip_adapter_cross_attn,
|
140 |
+
need_t2i_facein=need_t2i_facein,
|
141 |
+
need_t2i_ip_adapter_face=need_t2i_ip_adapter_face,
|
142 |
+
image_scale=image_scale,
|
143 |
+
ip_adapter_dim=cross_attention_dim
|
144 |
+
if not double_self_attention
|
145 |
+
else attention_head_dim,
|
146 |
+
facein_dim=cross_attention_dim
|
147 |
+
if not double_self_attention
|
148 |
+
else attention_head_dim,
|
149 |
+
ip_adapter_face_dim=cross_attention_dim
|
150 |
+
if not double_self_attention
|
151 |
+
else attention_head_dim,
|
152 |
+
processor=processor,
|
153 |
+
) # is self-attn if encoder_hidden_states is none
|
154 |
+
else:
|
155 |
+
self.norm2 = None
|
156 |
+
self.attn2 = None
|
157 |
+
if self.attn1 is not None:
|
158 |
+
if not allow_xformers:
|
159 |
+
self.attn1.set_use_memory_efficient_attention_xformers = (
|
160 |
+
not_use_xformers_anyway
|
161 |
+
)
|
162 |
+
if self.attn2 is not None:
|
163 |
+
if not allow_xformers:
|
164 |
+
self.attn2.set_use_memory_efficient_attention_xformers = (
|
165 |
+
not_use_xformers_anyway
|
166 |
+
)
|
167 |
+
self.double_self_attention = double_self_attention
|
168 |
+
self.only_cross_attention = only_cross_attention
|
169 |
+
self.cross_attn_temporal_cond = cross_attn_temporal_cond
|
170 |
+
self.image_scale = image_scale
|
171 |
+
|
172 |
+
def forward(
|
173 |
+
self,
|
174 |
+
hidden_states: torch.FloatTensor,
|
175 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
176 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
177 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
178 |
+
timestep: Optional[torch.LongTensor] = None,
|
179 |
+
cross_attention_kwargs: Dict[str, Any] = None,
|
180 |
+
class_labels: Optional[torch.LongTensor] = None,
|
181 |
+
self_attn_block_embs: Optional[Tuple[List[torch.Tensor], List[None]]] = None,
|
182 |
+
self_attn_block_embs_mode: Literal["read", "write"] = "write",
|
183 |
+
) -> torch.FloatTensor:
|
184 |
+
# Notice that normalization is always applied before the real computation in the following blocks.
|
185 |
+
# 0. Self-Attention
|
186 |
+
if self.use_ada_layer_norm:
|
187 |
+
norm_hidden_states = self.norm1(hidden_states, timestep)
|
188 |
+
elif self.use_ada_layer_norm_zero:
|
189 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
190 |
+
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
|
191 |
+
)
|
192 |
+
else:
|
193 |
+
norm_hidden_states = self.norm1(hidden_states)
|
194 |
+
|
195 |
+
# 1. Retrieve lora scale.
|
196 |
+
lora_scale = (
|
197 |
+
cross_attention_kwargs.get("scale", 1.0)
|
198 |
+
if cross_attention_kwargs is not None
|
199 |
+
else 1.0
|
200 |
+
)
|
201 |
+
|
202 |
+
if cross_attention_kwargs is None:
|
203 |
+
cross_attention_kwargs = {}
|
204 |
+
# 特殊AttnProcessor需要的入参 在 cross_attention_kwargs 准备
|
205 |
+
# special AttnProcessor needs input parameters in cross_attention_kwargs
|
206 |
+
original_cross_attention_kwargs = {
|
207 |
+
k: v
|
208 |
+
for k, v in cross_attention_kwargs.items()
|
209 |
+
if k
|
210 |
+
not in [
|
211 |
+
"num_frames",
|
212 |
+
"sample_index",
|
213 |
+
"vision_conditon_frames_sample_index",
|
214 |
+
"vision_cond",
|
215 |
+
"vision_clip_emb",
|
216 |
+
"ip_adapter_scale",
|
217 |
+
"face_emb",
|
218 |
+
"facein_scale",
|
219 |
+
"ip_adapter_face_emb",
|
220 |
+
"ip_adapter_face_scale",
|
221 |
+
"do_classifier_free_guidance",
|
222 |
+
]
|
223 |
+
}
|
224 |
+
|
225 |
+
if "do_classifier_free_guidance" in cross_attention_kwargs:
|
226 |
+
do_classifier_free_guidance = cross_attention_kwargs[
|
227 |
+
"do_classifier_free_guidance"
|
228 |
+
]
|
229 |
+
else:
|
230 |
+
do_classifier_free_guidance = False
|
231 |
+
|
232 |
+
# 2. Prepare GLIGEN inputs
|
233 |
+
original_cross_attention_kwargs = (
|
234 |
+
original_cross_attention_kwargs.copy()
|
235 |
+
if original_cross_attention_kwargs is not None
|
236 |
+
else {}
|
237 |
+
)
|
238 |
+
gligen_kwargs = original_cross_attention_kwargs.pop("gligen", None)
|
239 |
+
|
240 |
+
# 返回self_attn的结果,适用于referencenet的输出给其他Unet来使用
|
241 |
+
# return the result of self_attn, which is suitable for the output of referencenet to be used by other Unet
|
242 |
+
if (
|
243 |
+
self_attn_block_embs is not None
|
244 |
+
and self_attn_block_embs_mode.lower() == "write"
|
245 |
+
):
|
246 |
+
# self_attn_block_emb = self.attn1.head_to_batch_dim(attn_output, out_dim=4)
|
247 |
+
self_attn_block_emb = norm_hidden_states
|
248 |
+
if not hasattr(self, "spatial_self_attn_idx"):
|
249 |
+
raise ValueError(
|
250 |
+
"must call unet.insert_spatial_self_attn_idx to generate spatial attn index"
|
251 |
+
)
|
252 |
+
basick_transformer_idx = self.spatial_self_attn_idx
|
253 |
+
if self.print_idx == 0:
|
254 |
+
logger.debug(
|
255 |
+
f"self_attn_block_embs, self_attn_block_embs_mode={self_attn_block_embs_mode}, "
|
256 |
+
f"basick_transformer_idx={basick_transformer_idx}, length={len(self_attn_block_embs)}, shape={self_attn_block_emb.shape}, "
|
257 |
+
# f"attn1 processor, {type(self.attn1.processor)}"
|
258 |
+
)
|
259 |
+
self_attn_block_embs[basick_transformer_idx] = self_attn_block_emb
|
260 |
+
|
261 |
+
# read and put referencenet emb into cross_attention_kwargs, which would be fused into attn_processor
|
262 |
+
if (
|
263 |
+
self_attn_block_embs is not None
|
264 |
+
and self_attn_block_embs_mode.lower() == "read"
|
265 |
+
):
|
266 |
+
basick_transformer_idx = self.spatial_self_attn_idx
|
267 |
+
if not hasattr(self, "spatial_self_attn_idx"):
|
268 |
+
raise ValueError(
|
269 |
+
"must call unet.insert_spatial_self_attn_idx to generate spatial attn index"
|
270 |
+
)
|
271 |
+
if self.print_idx == 0:
|
272 |
+
logger.debug(
|
273 |
+
f"refer_self_attn_emb: , self_attn_block_embs_mode={self_attn_block_embs_mode}, "
|
274 |
+
f"length={len(self_attn_block_embs)}, idx={basick_transformer_idx}, "
|
275 |
+
# f"attn1 processor, {type(self.attn1.processor)}, "
|
276 |
+
)
|
277 |
+
ref_emb = self_attn_block_embs[basick_transformer_idx]
|
278 |
+
cross_attention_kwargs["refer_emb"] = ref_emb
|
279 |
+
if self.print_idx == 0:
|
280 |
+
logger.debug(
|
281 |
+
f"unet attention read, {self.spatial_self_attn_idx}",
|
282 |
+
)
|
283 |
+
# ------------------------------warning-----------------------
|
284 |
+
# 这两行由于使用了ref_emb会导致和checkpoint_train相关的训练错误,具体未知,留在这里作为警示
|
285 |
+
# bellow annoated code will cause training error, keep it here as a warning
|
286 |
+
# logger.debug(f"ref_emb shape,{ref_emb.shape}, {ref_emb.mean()}")
|
287 |
+
# logger.debug(
|
288 |
+
# f"norm_hidden_states shape, {norm_hidden_states.shape}, {norm_hidden_states.mean()}",
|
289 |
+
# )
|
290 |
+
if self.attn1 is None:
|
291 |
+
self.print_idx += 1
|
292 |
+
return norm_hidden_states
|
293 |
+
attn_output = self.attn1(
|
294 |
+
norm_hidden_states,
|
295 |
+
encoder_hidden_states=encoder_hidden_states
|
296 |
+
if self.only_cross_attention
|
297 |
+
else None,
|
298 |
+
attention_mask=attention_mask,
|
299 |
+
**(
|
300 |
+
cross_attention_kwargs
|
301 |
+
if isinstance(self.attn1.processor, BaseIPAttnProcessor)
|
302 |
+
else original_cross_attention_kwargs
|
303 |
+
),
|
304 |
+
)
|
305 |
+
|
306 |
+
if self.use_ada_layer_norm_zero:
|
307 |
+
attn_output = gate_msa.unsqueeze(1) * attn_output
|
308 |
+
hidden_states = attn_output + hidden_states
|
309 |
+
|
310 |
+
# 推断的时候,对于uncondition_部分独立生成,排除掉 refer_emb,
|
311 |
+
# 首帧等的影响,避免生成参考了refer_emb、首帧等,又在uncond上去除了
|
312 |
+
# in inference stage, eliminate influence of refer_emb, vis_cond on unconditionpart
|
313 |
+
# to avoid use that, and then eliminate in pipeline
|
314 |
+
# refer to moore-animate anyone
|
315 |
+
|
316 |
+
# do_classifier_free_guidance = False
|
317 |
+
if self.print_idx == 0:
|
318 |
+
logger.debug(f"do_classifier_free_guidance={do_classifier_free_guidance},")
|
319 |
+
if do_classifier_free_guidance:
|
320 |
+
hidden_states_c = attn_output.clone()
|
321 |
+
_uc_mask = (
|
322 |
+
torch.Tensor(
|
323 |
+
[1] * (norm_hidden_states.shape[0] // 2)
|
324 |
+
+ [0] * (norm_hidden_states.shape[0] // 2)
|
325 |
+
)
|
326 |
+
.to(norm_hidden_states.device)
|
327 |
+
.bool()
|
328 |
+
)
|
329 |
+
hidden_states_c[_uc_mask] = self.attn1(
|
330 |
+
norm_hidden_states[_uc_mask],
|
331 |
+
encoder_hidden_states=norm_hidden_states[_uc_mask],
|
332 |
+
attention_mask=attention_mask,
|
333 |
+
)
|
334 |
+
attn_output = hidden_states_c.clone()
|
335 |
+
|
336 |
+
if "refer_emb" in cross_attention_kwargs:
|
337 |
+
del cross_attention_kwargs["refer_emb"]
|
338 |
+
|
339 |
+
# 2.5 GLIGEN Control
|
340 |
+
if gligen_kwargs is not None:
|
341 |
+
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
|
342 |
+
# 2.5 ends
|
343 |
+
|
344 |
+
# 3. Cross-Attention
|
345 |
+
if self.attn2 is not None:
|
346 |
+
norm_hidden_states = (
|
347 |
+
self.norm2(hidden_states, timestep)
|
348 |
+
if self.use_ada_layer_norm
|
349 |
+
else self.norm2(hidden_states)
|
350 |
+
)
|
351 |
+
|
352 |
+
# 特殊AttnProcessor需要的入参 在 cross_attention_kwargs 准备
|
353 |
+
# special AttnProcessor needs input parameters in cross_attention_kwargs
|
354 |
+
attn_output = self.attn2(
|
355 |
+
norm_hidden_states,
|
356 |
+
encoder_hidden_states=encoder_hidden_states
|
357 |
+
if not self.double_self_attention
|
358 |
+
else None,
|
359 |
+
attention_mask=encoder_attention_mask,
|
360 |
+
**(
|
361 |
+
original_cross_attention_kwargs
|
362 |
+
if not isinstance(self.attn2.processor, BaseIPAttnProcessor)
|
363 |
+
else cross_attention_kwargs
|
364 |
+
),
|
365 |
+
)
|
366 |
+
if self.print_idx == 0:
|
367 |
+
logger.debug(
|
368 |
+
f"encoder_hidden_states, type={type(encoder_hidden_states)}"
|
369 |
+
)
|
370 |
+
if encoder_hidden_states is not None:
|
371 |
+
logger.debug(
|
372 |
+
f"encoder_hidden_states, ={encoder_hidden_states.shape}"
|
373 |
+
)
|
374 |
+
|
375 |
+
# encoder_hidden_states_tmp = (
|
376 |
+
# encoder_hidden_states
|
377 |
+
# if not self.double_self_attention
|
378 |
+
# else norm_hidden_states
|
379 |
+
# )
|
380 |
+
# if do_classifier_free_guidance:
|
381 |
+
# hidden_states_c = attn_output.clone()
|
382 |
+
# _uc_mask = (
|
383 |
+
# torch.Tensor(
|
384 |
+
# [1] * (norm_hidden_states.shape[0] // 2)
|
385 |
+
# + [0] * (norm_hidden_states.shape[0] // 2)
|
386 |
+
# )
|
387 |
+
# .to(norm_hidden_states.device)
|
388 |
+
# .bool()
|
389 |
+
# )
|
390 |
+
# hidden_states_c[_uc_mask] = self.attn2(
|
391 |
+
# norm_hidden_states[_uc_mask],
|
392 |
+
# encoder_hidden_states=encoder_hidden_states_tmp[_uc_mask],
|
393 |
+
# attention_mask=attention_mask,
|
394 |
+
# )
|
395 |
+
# attn_output = hidden_states_c.clone()
|
396 |
+
hidden_states = attn_output + hidden_states
|
397 |
+
# 4. Feed-forward
|
398 |
+
if self.norm3 is not None and self.ff is not None:
|
399 |
+
norm_hidden_states = self.norm3(hidden_states)
|
400 |
+
if self.use_ada_layer_norm_zero:
|
401 |
+
norm_hidden_states = (
|
402 |
+
norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
403 |
+
)
|
404 |
+
if self._chunk_size is not None:
|
405 |
+
# "feed_forward_chunk_size" can be used to save memory
|
406 |
+
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
|
407 |
+
raise ValueError(
|
408 |
+
f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
|
409 |
+
)
|
410 |
+
|
411 |
+
num_chunks = (
|
412 |
+
norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
|
413 |
+
)
|
414 |
+
ff_output = torch.cat(
|
415 |
+
[
|
416 |
+
self.ff(hid_slice, scale=lora_scale)
|
417 |
+
for hid_slice in norm_hidden_states.chunk(
|
418 |
+
num_chunks, dim=self._chunk_dim
|
419 |
+
)
|
420 |
+
],
|
421 |
+
dim=self._chunk_dim,
|
422 |
+
)
|
423 |
+
else:
|
424 |
+
ff_output = self.ff(norm_hidden_states, scale=lora_scale)
|
425 |
+
|
426 |
+
if self.use_ada_layer_norm_zero:
|
427 |
+
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
428 |
+
|
429 |
+
hidden_states = ff_output + hidden_states
|
430 |
+
self.print_idx += 1
|
431 |
+
return hidden_states
|
musev/models/attention_processor.py
ADDED
@@ -0,0 +1,750 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""该模型是自定义的attn_processor,实现特殊功能的 Attn功能。
|
16 |
+
相对而言,开源代码经常会重新定义Attention 类,
|
17 |
+
|
18 |
+
This module implements special AttnProcessor function with custom attn_processor class.
|
19 |
+
While other open source code always modify Attention class.
|
20 |
+
"""
|
21 |
+
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
|
22 |
+
from __future__ import annotations
|
23 |
+
|
24 |
+
import time
|
25 |
+
from typing import Any, Callable, Optional
|
26 |
+
import logging
|
27 |
+
|
28 |
+
from einops import rearrange, repeat
|
29 |
+
import torch
|
30 |
+
import torch.nn as nn
|
31 |
+
import torch.nn.functional as F
|
32 |
+
import xformers
|
33 |
+
from diffusers.models.lora import LoRACompatibleLinear
|
34 |
+
|
35 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
36 |
+
from diffusers.models.attention_processor import (
|
37 |
+
Attention as DiffusersAttention,
|
38 |
+
AttnProcessor,
|
39 |
+
AttnProcessor2_0,
|
40 |
+
)
|
41 |
+
from ..data.data_util import (
|
42 |
+
batch_concat_two_tensor_with_index,
|
43 |
+
batch_index_select,
|
44 |
+
align_repeat_tensor_single_dim,
|
45 |
+
batch_adain_conditioned_tensor,
|
46 |
+
)
|
47 |
+
|
48 |
+
from . import Model_Register
|
49 |
+
|
50 |
+
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
51 |
+
|
52 |
+
|
53 |
+
@maybe_allow_in_graph
|
54 |
+
class IPAttention(DiffusersAttention):
|
55 |
+
r"""
|
56 |
+
Modified Attention class which has special layer, like ip_apadapter_to_k, ip_apadapter_to_v,
|
57 |
+
"""
|
58 |
+
|
59 |
+
def __init__(
|
60 |
+
self,
|
61 |
+
query_dim: int,
|
62 |
+
cross_attention_dim: int | None = None,
|
63 |
+
heads: int = 8,
|
64 |
+
dim_head: int = 64,
|
65 |
+
dropout: float = 0,
|
66 |
+
bias=False,
|
67 |
+
upcast_attention: bool = False,
|
68 |
+
upcast_softmax: bool = False,
|
69 |
+
cross_attention_norm: str | None = None,
|
70 |
+
cross_attention_norm_num_groups: int = 32,
|
71 |
+
added_kv_proj_dim: int | None = None,
|
72 |
+
norm_num_groups: int | None = None,
|
73 |
+
spatial_norm_dim: int | None = None,
|
74 |
+
out_bias: bool = True,
|
75 |
+
scale_qk: bool = True,
|
76 |
+
only_cross_attention: bool = False,
|
77 |
+
eps: float = 0.00001,
|
78 |
+
rescale_output_factor: float = 1,
|
79 |
+
residual_connection: bool = False,
|
80 |
+
_from_deprecated_attn_block=False,
|
81 |
+
processor: AttnProcessor | None = None,
|
82 |
+
cross_attn_temporal_cond: bool = False,
|
83 |
+
image_scale: float = 1.0,
|
84 |
+
ip_adapter_dim: int = None,
|
85 |
+
need_t2i_facein: bool = False,
|
86 |
+
facein_dim: int = None,
|
87 |
+
need_t2i_ip_adapter_face: bool = False,
|
88 |
+
ip_adapter_face_dim: int = None,
|
89 |
+
):
|
90 |
+
super().__init__(
|
91 |
+
query_dim,
|
92 |
+
cross_attention_dim,
|
93 |
+
heads,
|
94 |
+
dim_head,
|
95 |
+
dropout,
|
96 |
+
bias,
|
97 |
+
upcast_attention,
|
98 |
+
upcast_softmax,
|
99 |
+
cross_attention_norm,
|
100 |
+
cross_attention_norm_num_groups,
|
101 |
+
added_kv_proj_dim,
|
102 |
+
norm_num_groups,
|
103 |
+
spatial_norm_dim,
|
104 |
+
out_bias,
|
105 |
+
scale_qk,
|
106 |
+
only_cross_attention,
|
107 |
+
eps,
|
108 |
+
rescale_output_factor,
|
109 |
+
residual_connection,
|
110 |
+
_from_deprecated_attn_block,
|
111 |
+
processor,
|
112 |
+
)
|
113 |
+
self.cross_attn_temporal_cond = cross_attn_temporal_cond
|
114 |
+
self.image_scale = image_scale
|
115 |
+
# 面向首帧的 ip_adapter
|
116 |
+
# ip_apdater
|
117 |
+
if cross_attn_temporal_cond:
|
118 |
+
self.to_k_ip = LoRACompatibleLinear(ip_adapter_dim, query_dim, bias=False)
|
119 |
+
self.to_v_ip = LoRACompatibleLinear(ip_adapter_dim, query_dim, bias=False)
|
120 |
+
# facein
|
121 |
+
self.need_t2i_facein = need_t2i_facein
|
122 |
+
self.facein_dim = facein_dim
|
123 |
+
if need_t2i_facein:
|
124 |
+
raise NotImplementedError("facein")
|
125 |
+
|
126 |
+
# ip_adapter_face
|
127 |
+
self.need_t2i_ip_adapter_face = need_t2i_ip_adapter_face
|
128 |
+
self.ip_adapter_face_dim = ip_adapter_face_dim
|
129 |
+
if need_t2i_ip_adapter_face:
|
130 |
+
self.ip_adapter_face_to_k_ip = LoRACompatibleLinear(
|
131 |
+
ip_adapter_face_dim, query_dim, bias=False
|
132 |
+
)
|
133 |
+
self.ip_adapter_face_to_v_ip = LoRACompatibleLinear(
|
134 |
+
ip_adapter_face_dim, query_dim, bias=False
|
135 |
+
)
|
136 |
+
|
137 |
+
def set_use_memory_efficient_attention_xformers(
|
138 |
+
self,
|
139 |
+
use_memory_efficient_attention_xformers: bool,
|
140 |
+
attention_op: Callable[..., Any] | None = None,
|
141 |
+
):
|
142 |
+
if (
|
143 |
+
"XFormers" in self.processor.__class__.__name__
|
144 |
+
or "IP" in self.processor.__class__.__name__
|
145 |
+
):
|
146 |
+
pass
|
147 |
+
else:
|
148 |
+
return super().set_use_memory_efficient_attention_xformers(
|
149 |
+
use_memory_efficient_attention_xformers, attention_op
|
150 |
+
)
|
151 |
+
|
152 |
+
|
153 |
+
@Model_Register.register
|
154 |
+
class BaseIPAttnProcessor(nn.Module):
|
155 |
+
print_idx = 0
|
156 |
+
|
157 |
+
def __init__(self, *args, **kwargs) -> None:
|
158 |
+
super().__init__(*args, **kwargs)
|
159 |
+
|
160 |
+
|
161 |
+
@Model_Register.register
|
162 |
+
class T2IReferencenetIPAdapterXFormersAttnProcessor(BaseIPAttnProcessor):
|
163 |
+
r"""
|
164 |
+
面向 ref_image的 self_attn的 IPAdapter
|
165 |
+
"""
|
166 |
+
print_idx = 0
|
167 |
+
|
168 |
+
def __init__(
|
169 |
+
self,
|
170 |
+
attention_op: Optional[Callable] = None,
|
171 |
+
):
|
172 |
+
super().__init__()
|
173 |
+
|
174 |
+
self.attention_op = attention_op
|
175 |
+
|
176 |
+
def __call__(
|
177 |
+
self,
|
178 |
+
attn: IPAttention,
|
179 |
+
hidden_states: torch.FloatTensor,
|
180 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
181 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
182 |
+
temb: Optional[torch.FloatTensor] = None,
|
183 |
+
scale: float = 1.0,
|
184 |
+
num_frames: int = None,
|
185 |
+
sample_index: torch.LongTensor = None,
|
186 |
+
vision_conditon_frames_sample_index: torch.LongTensor = None,
|
187 |
+
refer_emb: torch.Tensor = None,
|
188 |
+
vision_clip_emb: torch.Tensor = None,
|
189 |
+
ip_adapter_scale: float = 1.0,
|
190 |
+
face_emb: torch.Tensor = None,
|
191 |
+
facein_scale: float = 1.0,
|
192 |
+
ip_adapter_face_emb: torch.Tensor = None,
|
193 |
+
ip_adapter_face_scale: float = 1.0,
|
194 |
+
do_classifier_free_guidance: bool = False,
|
195 |
+
):
|
196 |
+
residual = hidden_states
|
197 |
+
|
198 |
+
if attn.spatial_norm is not None:
|
199 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
200 |
+
|
201 |
+
input_ndim = hidden_states.ndim
|
202 |
+
|
203 |
+
if input_ndim == 4:
|
204 |
+
batch_size, channel, height, width = hidden_states.shape
|
205 |
+
hidden_states = hidden_states.view(
|
206 |
+
batch_size, channel, height * width
|
207 |
+
).transpose(1, 2)
|
208 |
+
|
209 |
+
batch_size, key_tokens, _ = (
|
210 |
+
hidden_states.shape
|
211 |
+
if encoder_hidden_states is None
|
212 |
+
else encoder_hidden_states.shape
|
213 |
+
)
|
214 |
+
|
215 |
+
attention_mask = attn.prepare_attention_mask(
|
216 |
+
attention_mask, key_tokens, batch_size
|
217 |
+
)
|
218 |
+
if attention_mask is not None:
|
219 |
+
# expand our mask's singleton query_tokens dimension:
|
220 |
+
# [batch*heads, 1, key_tokens] ->
|
221 |
+
# [batch*heads, query_tokens, key_tokens]
|
222 |
+
# so that it can be added as a bias onto the attention scores that xformers computes:
|
223 |
+
# [batch*heads, query_tokens, key_tokens]
|
224 |
+
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
|
225 |
+
_, query_tokens, _ = hidden_states.shape
|
226 |
+
attention_mask = attention_mask.expand(-1, query_tokens, -1)
|
227 |
+
|
228 |
+
if attn.group_norm is not None:
|
229 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
|
230 |
+
1, 2
|
231 |
+
)
|
232 |
+
|
233 |
+
query = attn.to_q(hidden_states, scale=scale)
|
234 |
+
|
235 |
+
if encoder_hidden_states is None:
|
236 |
+
encoder_hidden_states = hidden_states
|
237 |
+
elif attn.norm_cross:
|
238 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(
|
239 |
+
encoder_hidden_states
|
240 |
+
)
|
241 |
+
encoder_hidden_states = align_repeat_tensor_single_dim(
|
242 |
+
encoder_hidden_states, target_length=hidden_states.shape[0], dim=0
|
243 |
+
)
|
244 |
+
key = attn.to_k(encoder_hidden_states, scale=scale)
|
245 |
+
value = attn.to_v(encoder_hidden_states, scale=scale)
|
246 |
+
|
247 |
+
# for facein
|
248 |
+
if self.print_idx == 0:
|
249 |
+
logger.debug(
|
250 |
+
f"T2IReferencenetIPAdapterXFormersAttnProcessor,type(face_emb)={type(face_emb)}, facein_scale={facein_scale}"
|
251 |
+
)
|
252 |
+
if facein_scale > 0 and face_emb is not None:
|
253 |
+
raise NotImplementedError("facein")
|
254 |
+
|
255 |
+
query = attn.head_to_batch_dim(query).contiguous()
|
256 |
+
key = attn.head_to_batch_dim(key).contiguous()
|
257 |
+
value = attn.head_to_batch_dim(value).contiguous()
|
258 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
259 |
+
query,
|
260 |
+
key,
|
261 |
+
value,
|
262 |
+
attn_bias=attention_mask,
|
263 |
+
op=self.attention_op,
|
264 |
+
scale=attn.scale,
|
265 |
+
)
|
266 |
+
|
267 |
+
# ip-adapter start
|
268 |
+
if self.print_idx == 0:
|
269 |
+
logger.debug(
|
270 |
+
f"T2IReferencenetIPAdapterXFormersAttnProcessor,type(vision_clip_emb)={type(vision_clip_emb)}"
|
271 |
+
)
|
272 |
+
if ip_adapter_scale > 0 and vision_clip_emb is not None:
|
273 |
+
if self.print_idx == 0:
|
274 |
+
logger.debug(
|
275 |
+
f"T2I cross_attn, ipadapter, vision_clip_emb={vision_clip_emb.shape}, hidden_states={hidden_states.shape}, batch_size={batch_size}"
|
276 |
+
)
|
277 |
+
ip_key = attn.to_k_ip(vision_clip_emb)
|
278 |
+
ip_value = attn.to_v_ip(vision_clip_emb)
|
279 |
+
ip_key = align_repeat_tensor_single_dim(
|
280 |
+
ip_key, target_length=batch_size, dim=0
|
281 |
+
)
|
282 |
+
ip_value = align_repeat_tensor_single_dim(
|
283 |
+
ip_value, target_length=batch_size, dim=0
|
284 |
+
)
|
285 |
+
ip_key = attn.head_to_batch_dim(ip_key).contiguous()
|
286 |
+
ip_value = attn.head_to_batch_dim(ip_value).contiguous()
|
287 |
+
if self.print_idx == 0:
|
288 |
+
logger.debug(
|
289 |
+
f"query={query.shape}, ip_key={ip_key.shape}, ip_value={ip_value.shape}"
|
290 |
+
)
|
291 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
292 |
+
hidden_states_from_ip = xformers.ops.memory_efficient_attention(
|
293 |
+
query,
|
294 |
+
ip_key,
|
295 |
+
ip_value,
|
296 |
+
attn_bias=attention_mask,
|
297 |
+
op=self.attention_op,
|
298 |
+
scale=attn.scale,
|
299 |
+
)
|
300 |
+
hidden_states = hidden_states + ip_adapter_scale * hidden_states_from_ip
|
301 |
+
# ip-adapter end
|
302 |
+
|
303 |
+
# ip-adapter face start
|
304 |
+
if self.print_idx == 0:
|
305 |
+
logger.debug(
|
306 |
+
f"T2IReferencenetIPAdapterXFormersAttnProcessor,type(ip_adapter_face_emb)={type(ip_adapter_face_emb)}"
|
307 |
+
)
|
308 |
+
if ip_adapter_face_scale > 0 and ip_adapter_face_emb is not None:
|
309 |
+
if self.print_idx == 0:
|
310 |
+
logger.debug(
|
311 |
+
f"T2I cross_attn, ipadapter face, ip_adapter_face_emb={vision_clip_emb.shape}, hidden_states={hidden_states.shape}, batch_size={batch_size}"
|
312 |
+
)
|
313 |
+
ip_key = attn.ip_adapter_face_to_k_ip(ip_adapter_face_emb)
|
314 |
+
ip_value = attn.ip_adapter_face_to_v_ip(ip_adapter_face_emb)
|
315 |
+
ip_key = align_repeat_tensor_single_dim(
|
316 |
+
ip_key, target_length=batch_size, dim=0
|
317 |
+
)
|
318 |
+
ip_value = align_repeat_tensor_single_dim(
|
319 |
+
ip_value, target_length=batch_size, dim=0
|
320 |
+
)
|
321 |
+
ip_key = attn.head_to_batch_dim(ip_key).contiguous()
|
322 |
+
ip_value = attn.head_to_batch_dim(ip_value).contiguous()
|
323 |
+
if self.print_idx == 0:
|
324 |
+
logger.debug(
|
325 |
+
f"query={query.shape}, ip_key={ip_key.shape}, ip_value={ip_value.shape}"
|
326 |
+
)
|
327 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
328 |
+
hidden_states_from_ip = xformers.ops.memory_efficient_attention(
|
329 |
+
query,
|
330 |
+
ip_key,
|
331 |
+
ip_value,
|
332 |
+
attn_bias=attention_mask,
|
333 |
+
op=self.attention_op,
|
334 |
+
scale=attn.scale,
|
335 |
+
)
|
336 |
+
hidden_states = (
|
337 |
+
hidden_states + ip_adapter_face_scale * hidden_states_from_ip
|
338 |
+
)
|
339 |
+
# ip-adapter face end
|
340 |
+
|
341 |
+
hidden_states = hidden_states.to(query.dtype)
|
342 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
343 |
+
|
344 |
+
# linear proj
|
345 |
+
hidden_states = attn.to_out[0](hidden_states, scale=scale)
|
346 |
+
# dropout
|
347 |
+
hidden_states = attn.to_out[1](hidden_states)
|
348 |
+
|
349 |
+
if input_ndim == 4:
|
350 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
351 |
+
batch_size, channel, height, width
|
352 |
+
)
|
353 |
+
|
354 |
+
if attn.residual_connection:
|
355 |
+
hidden_states = hidden_states + residual
|
356 |
+
|
357 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
358 |
+
self.print_idx += 1
|
359 |
+
return hidden_states
|
360 |
+
|
361 |
+
|
362 |
+
@Model_Register.register
|
363 |
+
class NonParamT2ISelfReferenceXFormersAttnProcessor(BaseIPAttnProcessor):
|
364 |
+
r"""
|
365 |
+
面向首帧的 referenceonly attn,适用于 T2I的 self_attn
|
366 |
+
referenceonly with vis_cond as key, value, in t2i self_attn.
|
367 |
+
"""
|
368 |
+
print_idx = 0
|
369 |
+
|
370 |
+
def __init__(
|
371 |
+
self,
|
372 |
+
attention_op: Optional[Callable] = None,
|
373 |
+
):
|
374 |
+
super().__init__()
|
375 |
+
|
376 |
+
self.attention_op = attention_op
|
377 |
+
|
378 |
+
def __call__(
|
379 |
+
self,
|
380 |
+
attn: IPAttention,
|
381 |
+
hidden_states: torch.FloatTensor,
|
382 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
383 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
384 |
+
temb: Optional[torch.FloatTensor] = None,
|
385 |
+
scale: float = 1.0,
|
386 |
+
num_frames: int = None,
|
387 |
+
sample_index: torch.LongTensor = None,
|
388 |
+
vision_conditon_frames_sample_index: torch.LongTensor = None,
|
389 |
+
refer_emb: torch.Tensor = None,
|
390 |
+
face_emb: torch.Tensor = None,
|
391 |
+
vision_clip_emb: torch.Tensor = None,
|
392 |
+
ip_adapter_scale: float = 1.0,
|
393 |
+
facein_scale: float = 1.0,
|
394 |
+
ip_adapter_face_emb: torch.Tensor = None,
|
395 |
+
ip_adapter_face_scale: float = 1.0,
|
396 |
+
do_classifier_free_guidance: bool = False,
|
397 |
+
):
|
398 |
+
residual = hidden_states
|
399 |
+
|
400 |
+
if attn.spatial_norm is not None:
|
401 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
402 |
+
|
403 |
+
input_ndim = hidden_states.ndim
|
404 |
+
|
405 |
+
if input_ndim == 4:
|
406 |
+
batch_size, channel, height, width = hidden_states.shape
|
407 |
+
hidden_states = hidden_states.view(
|
408 |
+
batch_size, channel, height * width
|
409 |
+
).transpose(1, 2)
|
410 |
+
|
411 |
+
batch_size, key_tokens, _ = (
|
412 |
+
hidden_states.shape
|
413 |
+
if encoder_hidden_states is None
|
414 |
+
else encoder_hidden_states.shape
|
415 |
+
)
|
416 |
+
|
417 |
+
attention_mask = attn.prepare_attention_mask(
|
418 |
+
attention_mask, key_tokens, batch_size
|
419 |
+
)
|
420 |
+
if attention_mask is not None:
|
421 |
+
# expand our mask's singleton query_tokens dimension:
|
422 |
+
# [batch*heads, 1, key_tokens] ->
|
423 |
+
# [batch*heads, query_tokens, key_tokens]
|
424 |
+
# so that it can be added as a bias onto the attention scores that xformers computes:
|
425 |
+
# [batch*heads, query_tokens, key_tokens]
|
426 |
+
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
|
427 |
+
_, query_tokens, _ = hidden_states.shape
|
428 |
+
attention_mask = attention_mask.expand(-1, query_tokens, -1)
|
429 |
+
|
430 |
+
# vision_cond in same unet attn start
|
431 |
+
if (
|
432 |
+
vision_conditon_frames_sample_index is not None and num_frames > 1
|
433 |
+
) or refer_emb is not None:
|
434 |
+
batchsize_timesize = hidden_states.shape[0]
|
435 |
+
if self.print_idx == 0:
|
436 |
+
logger.debug(
|
437 |
+
f"NonParamT2ISelfReferenceXFormersAttnProcessor 0, hidden_states={hidden_states.shape}, vision_conditon_frames_sample_index={vision_conditon_frames_sample_index}"
|
438 |
+
)
|
439 |
+
encoder_hidden_states = rearrange(
|
440 |
+
hidden_states, "(b t) hw c -> b t hw c", t=num_frames
|
441 |
+
)
|
442 |
+
# if False:
|
443 |
+
if vision_conditon_frames_sample_index is not None and num_frames > 1:
|
444 |
+
ip_hidden_states = batch_index_select(
|
445 |
+
encoder_hidden_states,
|
446 |
+
dim=1,
|
447 |
+
index=vision_conditon_frames_sample_index,
|
448 |
+
).contiguous()
|
449 |
+
if self.print_idx == 0:
|
450 |
+
logger.debug(
|
451 |
+
f"NonParamT2ISelfReferenceXFormersAttnProcessor 1, vis_cond referenceonly, encoder_hidden_states={encoder_hidden_states.shape}, ip_hidden_states={ip_hidden_states.shape}"
|
452 |
+
)
|
453 |
+
#
|
454 |
+
ip_hidden_states = rearrange(
|
455 |
+
ip_hidden_states, "b t hw c -> b 1 (t hw) c"
|
456 |
+
)
|
457 |
+
ip_hidden_states = align_repeat_tensor_single_dim(
|
458 |
+
ip_hidden_states,
|
459 |
+
dim=1,
|
460 |
+
target_length=num_frames,
|
461 |
+
)
|
462 |
+
# b t hw c -> b t hw + hw c
|
463 |
+
if self.print_idx == 0:
|
464 |
+
logger.debug(
|
465 |
+
f"NonParamT2ISelfReferenceXFormersAttnProcessor 2, vis_cond referenceonly, encoder_hidden_states={encoder_hidden_states.shape}, ip_hidden_states={ip_hidden_states.shape}"
|
466 |
+
)
|
467 |
+
encoder_hidden_states = torch.concat(
|
468 |
+
[encoder_hidden_states, ip_hidden_states], dim=2
|
469 |
+
)
|
470 |
+
if self.print_idx == 0:
|
471 |
+
logger.debug(
|
472 |
+
f"NonParamT2ISelfReferenceXFormersAttnProcessor 3, hidden_states={hidden_states.shape}, ip_hidden_states={ip_hidden_states.shape}"
|
473 |
+
)
|
474 |
+
# if False:
|
475 |
+
if refer_emb is not None: # and num_frames > 1:
|
476 |
+
refer_emb = rearrange(refer_emb, "b c t h w->b 1 (t h w) c")
|
477 |
+
refer_emb = align_repeat_tensor_single_dim(
|
478 |
+
refer_emb, target_length=num_frames, dim=1
|
479 |
+
)
|
480 |
+
if self.print_idx == 0:
|
481 |
+
logger.debug(
|
482 |
+
f"NonParamT2ISelfReferenceXFormersAttnProcessor4, referencenet, encoder_hidden_states={encoder_hidden_states.shape}, refer_emb={refer_emb.shape}"
|
483 |
+
)
|
484 |
+
encoder_hidden_states = torch.concat(
|
485 |
+
[encoder_hidden_states, refer_emb], dim=2
|
486 |
+
)
|
487 |
+
if self.print_idx == 0:
|
488 |
+
logger.debug(
|
489 |
+
f"NonParamT2ISelfReferenceXFormersAttnProcessor5, referencenet, encoder_hidden_states={encoder_hidden_states.shape}, refer_emb={refer_emb.shape}"
|
490 |
+
)
|
491 |
+
encoder_hidden_states = rearrange(
|
492 |
+
encoder_hidden_states, "b t hw c -> (b t) hw c"
|
493 |
+
)
|
494 |
+
# vision_cond in same unet attn end
|
495 |
+
|
496 |
+
if attn.group_norm is not None:
|
497 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
|
498 |
+
1, 2
|
499 |
+
)
|
500 |
+
|
501 |
+
query = attn.to_q(hidden_states, scale=scale)
|
502 |
+
|
503 |
+
if encoder_hidden_states is None:
|
504 |
+
encoder_hidden_states = hidden_states
|
505 |
+
elif attn.norm_cross:
|
506 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(
|
507 |
+
encoder_hidden_states
|
508 |
+
)
|
509 |
+
encoder_hidden_states = align_repeat_tensor_single_dim(
|
510 |
+
encoder_hidden_states, target_length=hidden_states.shape[0], dim=0
|
511 |
+
)
|
512 |
+
key = attn.to_k(encoder_hidden_states, scale=scale)
|
513 |
+
value = attn.to_v(encoder_hidden_states, scale=scale)
|
514 |
+
|
515 |
+
query = attn.head_to_batch_dim(query).contiguous()
|
516 |
+
key = attn.head_to_batch_dim(key).contiguous()
|
517 |
+
value = attn.head_to_batch_dim(value).contiguous()
|
518 |
+
|
519 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
520 |
+
query,
|
521 |
+
key,
|
522 |
+
value,
|
523 |
+
attn_bias=attention_mask,
|
524 |
+
op=self.attention_op,
|
525 |
+
scale=attn.scale,
|
526 |
+
)
|
527 |
+
hidden_states = hidden_states.to(query.dtype)
|
528 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
529 |
+
|
530 |
+
# linear proj
|
531 |
+
hidden_states = attn.to_out[0](hidden_states, scale=scale)
|
532 |
+
# dropout
|
533 |
+
hidden_states = attn.to_out[1](hidden_states)
|
534 |
+
|
535 |
+
if input_ndim == 4:
|
536 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
537 |
+
batch_size, channel, height, width
|
538 |
+
)
|
539 |
+
|
540 |
+
if attn.residual_connection:
|
541 |
+
hidden_states = hidden_states + residual
|
542 |
+
|
543 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
544 |
+
self.print_idx += 1
|
545 |
+
|
546 |
+
return hidden_states
|
547 |
+
|
548 |
+
|
549 |
+
@Model_Register.register
|
550 |
+
class NonParamReferenceIPXFormersAttnProcessor(
|
551 |
+
NonParamT2ISelfReferenceXFormersAttnProcessor
|
552 |
+
):
|
553 |
+
def __init__(self, attention_op: Callable[..., Any] | None = None):
|
554 |
+
super().__init__(attention_op)
|
555 |
+
|
556 |
+
|
557 |
+
@maybe_allow_in_graph
|
558 |
+
class ReferEmbFuseAttention(IPAttention):
|
559 |
+
"""使用 attention 融合 refernet 中的 emb 到 unet 对应的 latens 中
|
560 |
+
# TODO: 目前只支持 bt hw c 的融合,后续考虑增加对 视频 bhw t c、b thw c的融合
|
561 |
+
residual_connection: bool = True, 默认, 从不产生影响开始学习
|
562 |
+
|
563 |
+
use attention to fuse referencenet emb into unet latents
|
564 |
+
# TODO: by now, only support bt hw c, later consider to support bhw t c, b thw c
|
565 |
+
residual_connection: bool = True, default, start from no effect
|
566 |
+
|
567 |
+
Args:
|
568 |
+
IPAttention (_type_): _description_
|
569 |
+
"""
|
570 |
+
|
571 |
+
print_idx = 0
|
572 |
+
|
573 |
+
def __init__(
|
574 |
+
self,
|
575 |
+
query_dim: int,
|
576 |
+
cross_attention_dim: int | None = None,
|
577 |
+
heads: int = 8,
|
578 |
+
dim_head: int = 64,
|
579 |
+
dropout: float = 0,
|
580 |
+
bias=False,
|
581 |
+
upcast_attention: bool = False,
|
582 |
+
upcast_softmax: bool = False,
|
583 |
+
cross_attention_norm: str | None = None,
|
584 |
+
cross_attention_norm_num_groups: int = 32,
|
585 |
+
added_kv_proj_dim: int | None = None,
|
586 |
+
norm_num_groups: int | None = None,
|
587 |
+
spatial_norm_dim: int | None = None,
|
588 |
+
out_bias: bool = True,
|
589 |
+
scale_qk: bool = True,
|
590 |
+
only_cross_attention: bool = False,
|
591 |
+
eps: float = 0.00001,
|
592 |
+
rescale_output_factor: float = 1,
|
593 |
+
residual_connection: bool = True,
|
594 |
+
_from_deprecated_attn_block=False,
|
595 |
+
processor: AttnProcessor | None = None,
|
596 |
+
cross_attn_temporal_cond: bool = False,
|
597 |
+
image_scale: float = 1,
|
598 |
+
):
|
599 |
+
super().__init__(
|
600 |
+
query_dim,
|
601 |
+
cross_attention_dim,
|
602 |
+
heads,
|
603 |
+
dim_head,
|
604 |
+
dropout,
|
605 |
+
bias,
|
606 |
+
upcast_attention,
|
607 |
+
upcast_softmax,
|
608 |
+
cross_attention_norm,
|
609 |
+
cross_attention_norm_num_groups,
|
610 |
+
added_kv_proj_dim,
|
611 |
+
norm_num_groups,
|
612 |
+
spatial_norm_dim,
|
613 |
+
out_bias,
|
614 |
+
scale_qk,
|
615 |
+
only_cross_attention,
|
616 |
+
eps,
|
617 |
+
rescale_output_factor,
|
618 |
+
residual_connection,
|
619 |
+
_from_deprecated_attn_block,
|
620 |
+
processor,
|
621 |
+
cross_attn_temporal_cond,
|
622 |
+
image_scale,
|
623 |
+
)
|
624 |
+
self.processor = None
|
625 |
+
# 配合residual,使一开始不影响之前结果
|
626 |
+
nn.init.zeros_(self.to_out[0].weight)
|
627 |
+
nn.init.zeros_(self.to_out[0].bias)
|
628 |
+
|
629 |
+
def forward(
|
630 |
+
self,
|
631 |
+
hidden_states: torch.FloatTensor,
|
632 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
633 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
634 |
+
temb: Optional[torch.FloatTensor] = None,
|
635 |
+
scale: float = 1.0,
|
636 |
+
num_frames: int = None,
|
637 |
+
) -> torch.Tensor:
|
638 |
+
"""fuse referencenet emb b c t2 h2 w2 into unet latents b c t1 h1 w1 with attn
|
639 |
+
refer to musev/models/attention_processor.py::NonParamT2ISelfReferenceXFormersAttnProcessor
|
640 |
+
|
641 |
+
Args:
|
642 |
+
hidden_states (torch.FloatTensor): unet latents, (b t1) c h1 w1
|
643 |
+
encoder_hidden_states (Optional[torch.FloatTensor], optional): referencenet emb b c2 t2 h2 w2. Defaults to None.
|
644 |
+
attention_mask (Optional[torch.FloatTensor], optional): _description_. Defaults to None.
|
645 |
+
temb (Optional[torch.FloatTensor], optional): _description_. Defaults to None.
|
646 |
+
scale (float, optional): _description_. Defaults to 1.0.
|
647 |
+
num_frames (int, optional): _description_. Defaults to None.
|
648 |
+
|
649 |
+
Returns:
|
650 |
+
torch.Tensor: _description_
|
651 |
+
"""
|
652 |
+
residual = hidden_states
|
653 |
+
# start
|
654 |
+
hidden_states = rearrange(
|
655 |
+
hidden_states, "(b t) c h w -> b c t h w", t=num_frames
|
656 |
+
)
|
657 |
+
batch_size, channel, t1, height, width = hidden_states.shape
|
658 |
+
if self.print_idx == 0:
|
659 |
+
logger.debug(
|
660 |
+
f"hidden_states={hidden_states.shape},encoder_hidden_states={encoder_hidden_states.shape}"
|
661 |
+
)
|
662 |
+
# concat with hidden_states b c t1 h1 w1 in hw channel into bt (t2 + 1)hw c
|
663 |
+
encoder_hidden_states = rearrange(
|
664 |
+
encoder_hidden_states, " b c t2 h w-> b (t2 h w) c"
|
665 |
+
)
|
666 |
+
encoder_hidden_states = repeat(
|
667 |
+
encoder_hidden_states, " b t2hw c -> (b t) t2hw c", t=t1
|
668 |
+
)
|
669 |
+
hidden_states = rearrange(hidden_states, " b c t h w-> (b t) (h w) c")
|
670 |
+
# bt (t2+1)hw d
|
671 |
+
encoder_hidden_states = torch.concat(
|
672 |
+
[encoder_hidden_states, hidden_states], dim=1
|
673 |
+
)
|
674 |
+
# encoder_hidden_states = align_repeat_tensor_single_dim(
|
675 |
+
# encoder_hidden_states, target_length=hidden_states.shape[0], dim=0
|
676 |
+
# )
|
677 |
+
# end
|
678 |
+
|
679 |
+
if self.spatial_norm is not None:
|
680 |
+
hidden_states = self.spatial_norm(hidden_states, temb)
|
681 |
+
|
682 |
+
_, key_tokens, _ = (
|
683 |
+
hidden_states.shape
|
684 |
+
if encoder_hidden_states is None
|
685 |
+
else encoder_hidden_states.shape
|
686 |
+
)
|
687 |
+
|
688 |
+
attention_mask = self.prepare_attention_mask(
|
689 |
+
attention_mask, key_tokens, batch_size
|
690 |
+
)
|
691 |
+
if attention_mask is not None:
|
692 |
+
# expand our mask's singleton query_tokens dimension:
|
693 |
+
# [batch*heads, 1, key_tokens] ->
|
694 |
+
# [batch*heads, query_tokens, key_tokens]
|
695 |
+
# so that it can be added as a bias onto the attention scores that xformers computes:
|
696 |
+
# [batch*heads, query_tokens, key_tokens]
|
697 |
+
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
|
698 |
+
_, query_tokens, _ = hidden_states.shape
|
699 |
+
attention_mask = attention_mask.expand(-1, query_tokens, -1)
|
700 |
+
|
701 |
+
if self.group_norm is not None:
|
702 |
+
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(
|
703 |
+
1, 2
|
704 |
+
)
|
705 |
+
|
706 |
+
query = self.to_q(hidden_states, scale=scale)
|
707 |
+
|
708 |
+
if encoder_hidden_states is None:
|
709 |
+
encoder_hidden_states = hidden_states
|
710 |
+
elif self.norm_cross:
|
711 |
+
encoder_hidden_states = self.norm_encoder_hidden_states(
|
712 |
+
encoder_hidden_states
|
713 |
+
)
|
714 |
+
|
715 |
+
key = self.to_k(encoder_hidden_states, scale=scale)
|
716 |
+
value = self.to_v(encoder_hidden_states, scale=scale)
|
717 |
+
|
718 |
+
query = self.head_to_batch_dim(query).contiguous()
|
719 |
+
key = self.head_to_batch_dim(key).contiguous()
|
720 |
+
value = self.head_to_batch_dim(value).contiguous()
|
721 |
+
|
722 |
+
# query: b t hw d
|
723 |
+
# key/value: bt (t1+1)hw d
|
724 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
725 |
+
query,
|
726 |
+
key,
|
727 |
+
value,
|
728 |
+
attn_bias=attention_mask,
|
729 |
+
scale=self.scale,
|
730 |
+
)
|
731 |
+
hidden_states = hidden_states.to(query.dtype)
|
732 |
+
hidden_states = self.batch_to_head_dim(hidden_states)
|
733 |
+
|
734 |
+
# linear proj
|
735 |
+
hidden_states = self.to_out[0](hidden_states, scale=scale)
|
736 |
+
# dropout
|
737 |
+
hidden_states = self.to_out[1](hidden_states)
|
738 |
+
|
739 |
+
hidden_states = rearrange(
|
740 |
+
hidden_states,
|
741 |
+
"bt (h w) c-> bt c h w",
|
742 |
+
h=height,
|
743 |
+
w=width,
|
744 |
+
)
|
745 |
+
if self.residual_connection:
|
746 |
+
hidden_states = hidden_states + residual
|
747 |
+
|
748 |
+
hidden_states = hidden_states / self.rescale_output_factor
|
749 |
+
self.print_idx += 1
|
750 |
+
return hidden_states
|
musev/models/controlnet.py
ADDED
@@ -0,0 +1,399 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
|
2 |
+
import warnings
|
3 |
+
import os
|
4 |
+
|
5 |
+
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from diffusers.models.modeling_utils import ModelMixin
|
9 |
+
import PIL
|
10 |
+
from einops import rearrange, repeat
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import torch.nn.init as init
|
14 |
+
from diffusers.models.controlnet import ControlNetModel
|
15 |
+
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
16 |
+
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers
|
17 |
+
from diffusers.utils.torch_utils import is_compiled_module
|
18 |
+
|
19 |
+
|
20 |
+
class ControlnetPredictor(object):
|
21 |
+
def __init__(self, controlnet_model_path: str, *args, **kwargs):
|
22 |
+
"""Controlnet 推断函数,用于提取 controlnet backbone的emb,避免训练时重复抽取
|
23 |
+
Controlnet inference predictor, used to extract the emb of the controlnet backbone to avoid repeated extraction during training
|
24 |
+
Args:
|
25 |
+
controlnet_model_path (str): controlnet 模型路径. controlnet model path.
|
26 |
+
"""
|
27 |
+
super(ControlnetPredictor, self).__init__(*args, **kwargs)
|
28 |
+
self.controlnet = ControlNetModel.from_pretrained(
|
29 |
+
controlnet_model_path,
|
30 |
+
)
|
31 |
+
|
32 |
+
def prepare_image(
|
33 |
+
self,
|
34 |
+
image, # b c t h w
|
35 |
+
width,
|
36 |
+
height,
|
37 |
+
batch_size,
|
38 |
+
num_images_per_prompt,
|
39 |
+
device,
|
40 |
+
dtype,
|
41 |
+
do_classifier_free_guidance=False,
|
42 |
+
guess_mode=False,
|
43 |
+
):
|
44 |
+
if height is None:
|
45 |
+
height = image.shape[-2]
|
46 |
+
if width is None:
|
47 |
+
width = image.shape[-1]
|
48 |
+
width, height = (
|
49 |
+
x - x % self.control_image_processor.vae_scale_factor
|
50 |
+
for x in (width, height)
|
51 |
+
)
|
52 |
+
image = rearrange(image, "b c t h w-> (b t) c h w")
|
53 |
+
image = torch.from_numpy(image).to(dtype=torch.float32) / 255.0
|
54 |
+
|
55 |
+
image = (
|
56 |
+
torch.nn.functional.interpolate(
|
57 |
+
image,
|
58 |
+
size=(height, width),
|
59 |
+
mode="bilinear",
|
60 |
+
),
|
61 |
+
)
|
62 |
+
|
63 |
+
do_normalize = self.control_image_processor.config.do_normalize
|
64 |
+
if image.min() < 0:
|
65 |
+
warnings.warn(
|
66 |
+
"Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
|
67 |
+
f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
|
68 |
+
FutureWarning,
|
69 |
+
)
|
70 |
+
do_normalize = False
|
71 |
+
|
72 |
+
if do_normalize:
|
73 |
+
image = self.control_image_processor.normalize(image)
|
74 |
+
|
75 |
+
image_batch_size = image.shape[0]
|
76 |
+
|
77 |
+
if image_batch_size == 1:
|
78 |
+
repeat_by = batch_size
|
79 |
+
else:
|
80 |
+
# image batch size is the same as prompt batch size
|
81 |
+
repeat_by = num_images_per_prompt
|
82 |
+
|
83 |
+
image = image.repeat_interleave(repeat_by, dim=0)
|
84 |
+
|
85 |
+
image = image.to(device=device, dtype=dtype)
|
86 |
+
|
87 |
+
if do_classifier_free_guidance and not guess_mode:
|
88 |
+
image = torch.cat([image] * 2)
|
89 |
+
|
90 |
+
return image
|
91 |
+
|
92 |
+
@torch.no_grad()
|
93 |
+
def __call__(
|
94 |
+
self,
|
95 |
+
batch_size: int,
|
96 |
+
device: str,
|
97 |
+
dtype: torch.dtype,
|
98 |
+
timesteps: List[float],
|
99 |
+
i: int,
|
100 |
+
scheduler: KarrasDiffusionSchedulers,
|
101 |
+
prompt_embeds: torch.Tensor,
|
102 |
+
do_classifier_free_guidance: bool = False,
|
103 |
+
# 2b co t ho wo
|
104 |
+
latent_model_input: torch.Tensor = None,
|
105 |
+
# b co t ho wo
|
106 |
+
latents: torch.Tensor = None,
|
107 |
+
# b c t h w
|
108 |
+
image: Union[
|
109 |
+
torch.FloatTensor,
|
110 |
+
PIL.Image.Image,
|
111 |
+
np.ndarray,
|
112 |
+
List[torch.FloatTensor],
|
113 |
+
List[PIL.Image.Image],
|
114 |
+
List[np.ndarray],
|
115 |
+
] = None,
|
116 |
+
# b c t(1) hi wi
|
117 |
+
controlnet_condition_frames: Optional[torch.FloatTensor] = None,
|
118 |
+
# b c t ho wo
|
119 |
+
controlnet_latents: Union[torch.FloatTensor, np.ndarray] = None,
|
120 |
+
# b c t(1) ho wo
|
121 |
+
controlnet_condition_latents: Optional[torch.FloatTensor] = None,
|
122 |
+
height: Optional[int] = None,
|
123 |
+
width: Optional[int] = None,
|
124 |
+
num_videos_per_prompt: Optional[int] = 1,
|
125 |
+
return_dict: bool = True,
|
126 |
+
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
127 |
+
guess_mode: bool = False,
|
128 |
+
control_guidance_start: Union[float, List[float]] = 0.0,
|
129 |
+
control_guidance_end: Union[float, List[float]] = 1.0,
|
130 |
+
latent_index: torch.LongTensor = None,
|
131 |
+
vision_condition_latent_index: torch.LongTensor = None,
|
132 |
+
**kwargs,
|
133 |
+
):
|
134 |
+
assert (
|
135 |
+
image is None and controlnet_latents is None
|
136 |
+
), "should set one of image and controlnet_latents"
|
137 |
+
|
138 |
+
controlnet = (
|
139 |
+
self.controlnet._orig_mod
|
140 |
+
if is_compiled_module(self.controlnet)
|
141 |
+
else self.controlnet
|
142 |
+
)
|
143 |
+
|
144 |
+
# align format for control guidance
|
145 |
+
if not isinstance(control_guidance_start, list) and isinstance(
|
146 |
+
control_guidance_end, list
|
147 |
+
):
|
148 |
+
control_guidance_start = len(control_guidance_end) * [
|
149 |
+
control_guidance_start
|
150 |
+
]
|
151 |
+
elif not isinstance(control_guidance_end, list) and isinstance(
|
152 |
+
control_guidance_start, list
|
153 |
+
):
|
154 |
+
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
|
155 |
+
elif not isinstance(control_guidance_start, list) and not isinstance(
|
156 |
+
control_guidance_end, list
|
157 |
+
):
|
158 |
+
mult = (
|
159 |
+
len(controlnet.nets)
|
160 |
+
if isinstance(controlnet, MultiControlNetModel)
|
161 |
+
else 1
|
162 |
+
)
|
163 |
+
control_guidance_start, control_guidance_end = mult * [
|
164 |
+
control_guidance_start
|
165 |
+
], mult * [control_guidance_end]
|
166 |
+
|
167 |
+
if isinstance(controlnet, MultiControlNetModel) and isinstance(
|
168 |
+
controlnet_conditioning_scale, float
|
169 |
+
):
|
170 |
+
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(
|
171 |
+
controlnet.nets
|
172 |
+
)
|
173 |
+
|
174 |
+
global_pool_conditions = (
|
175 |
+
controlnet.config.global_pool_conditions
|
176 |
+
if isinstance(controlnet, ControlNetModel)
|
177 |
+
else controlnet.nets[0].config.global_pool_conditions
|
178 |
+
)
|
179 |
+
guess_mode = guess_mode or global_pool_conditions
|
180 |
+
|
181 |
+
# 4. Prepare image
|
182 |
+
if isinstance(controlnet, ControlNetModel):
|
183 |
+
if (
|
184 |
+
controlnet_latents is not None
|
185 |
+
and controlnet_condition_latents is not None
|
186 |
+
):
|
187 |
+
if isinstance(controlnet_latents, np.ndarray):
|
188 |
+
controlnet_latents = torch.from_numpy(controlnet_latents)
|
189 |
+
if isinstance(controlnet_condition_latents, np.ndarray):
|
190 |
+
controlnet_condition_latents = torch.from_numpy(
|
191 |
+
controlnet_condition_latents
|
192 |
+
)
|
193 |
+
# TODO:使用index进行concat
|
194 |
+
controlnet_latents = torch.concat(
|
195 |
+
[controlnet_condition_latents, controlnet_latents], dim=2
|
196 |
+
)
|
197 |
+
if not guess_mode and do_classifier_free_guidance:
|
198 |
+
controlnet_latents = torch.concat([controlnet_latents] * 2, dim=0)
|
199 |
+
controlnet_latents = rearrange(
|
200 |
+
controlnet_latents, "b c t h w->(b t) c h w"
|
201 |
+
)
|
202 |
+
controlnet_latents = controlnet_latents.to(device=device, dtype=dtype)
|
203 |
+
else:
|
204 |
+
# TODO:使用index进行concat
|
205 |
+
# TODO: concat with index
|
206 |
+
if controlnet_condition_frames is not None:
|
207 |
+
if isinstance(controlnet_condition_frames, np.ndarray):
|
208 |
+
image = np.concatenate(
|
209 |
+
[controlnet_condition_frames, image], axis=2
|
210 |
+
)
|
211 |
+
image = self.prepare_image(
|
212 |
+
image=image,
|
213 |
+
width=width,
|
214 |
+
height=height,
|
215 |
+
batch_size=batch_size * num_videos_per_prompt,
|
216 |
+
num_images_per_prompt=num_videos_per_prompt,
|
217 |
+
device=device,
|
218 |
+
dtype=controlnet.dtype,
|
219 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
220 |
+
guess_mode=guess_mode,
|
221 |
+
)
|
222 |
+
height, width = image.shape[-2:]
|
223 |
+
elif isinstance(controlnet, MultiControlNetModel):
|
224 |
+
images = []
|
225 |
+
# TODO: 支持直接使用controlnet_latent而不是frames
|
226 |
+
# TODO: support using controlnet_latent directly instead of frames
|
227 |
+
if controlnet_latents is not None:
|
228 |
+
raise NotImplementedError
|
229 |
+
else:
|
230 |
+
for i, image_ in enumerate(image):
|
231 |
+
if controlnet_condition_frames is not None and isinstance(
|
232 |
+
controlnet_condition_frames, list
|
233 |
+
):
|
234 |
+
if isinstance(controlnet_condition_frames[i], np.ndarray):
|
235 |
+
image_ = np.concatenate(
|
236 |
+
[controlnet_condition_frames[i], image_], axis=2
|
237 |
+
)
|
238 |
+
image_ = self.prepare_image(
|
239 |
+
image=image_,
|
240 |
+
width=width,
|
241 |
+
height=height,
|
242 |
+
batch_size=batch_size * num_videos_per_prompt,
|
243 |
+
num_images_per_prompt=num_videos_per_prompt,
|
244 |
+
device=device,
|
245 |
+
dtype=controlnet.dtype,
|
246 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
247 |
+
guess_mode=guess_mode,
|
248 |
+
)
|
249 |
+
|
250 |
+
images.append(image_)
|
251 |
+
|
252 |
+
image = images
|
253 |
+
height, width = image[0].shape[-2:]
|
254 |
+
else:
|
255 |
+
assert False
|
256 |
+
|
257 |
+
# 7.1 Create tensor stating which controlnets to keep
|
258 |
+
controlnet_keep = []
|
259 |
+
for i in range(len(timesteps)):
|
260 |
+
keeps = [
|
261 |
+
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
|
262 |
+
for s, e in zip(control_guidance_start, control_guidance_end)
|
263 |
+
]
|
264 |
+
controlnet_keep.append(
|
265 |
+
keeps[0] if isinstance(controlnet, ControlNetModel) else keeps
|
266 |
+
)
|
267 |
+
|
268 |
+
t = timesteps[i]
|
269 |
+
|
270 |
+
# controlnet(s) inference
|
271 |
+
if guess_mode and do_classifier_free_guidance:
|
272 |
+
# Infer ControlNet only for the conditional batch.
|
273 |
+
control_model_input = latents
|
274 |
+
control_model_input = scheduler.scale_model_input(control_model_input, t)
|
275 |
+
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
|
276 |
+
else:
|
277 |
+
control_model_input = latent_model_input
|
278 |
+
controlnet_prompt_embeds = prompt_embeds
|
279 |
+
if isinstance(controlnet_keep[i], list):
|
280 |
+
cond_scale = [
|
281 |
+
c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])
|
282 |
+
]
|
283 |
+
else:
|
284 |
+
cond_scale = controlnet_conditioning_scale * controlnet_keep[i]
|
285 |
+
control_model_input_reshape = rearrange(
|
286 |
+
control_model_input, "b c t h w -> (b t) c h w"
|
287 |
+
)
|
288 |
+
encoder_hidden_states_repeat = repeat(
|
289 |
+
controlnet_prompt_embeds,
|
290 |
+
"b n q->(b t) n q",
|
291 |
+
t=control_model_input.shape[2],
|
292 |
+
)
|
293 |
+
|
294 |
+
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
295 |
+
control_model_input_reshape,
|
296 |
+
t,
|
297 |
+
encoder_hidden_states_repeat,
|
298 |
+
controlnet_cond=image,
|
299 |
+
controlnet_cond_latents=controlnet_latents,
|
300 |
+
conditioning_scale=cond_scale,
|
301 |
+
guess_mode=guess_mode,
|
302 |
+
return_dict=False,
|
303 |
+
)
|
304 |
+
|
305 |
+
return down_block_res_samples, mid_block_res_sample
|
306 |
+
|
307 |
+
|
308 |
+
class InflatedConv3d(nn.Conv2d):
|
309 |
+
def forward(self, x):
|
310 |
+
video_length = x.shape[2]
|
311 |
+
|
312 |
+
x = rearrange(x, "b c f h w -> (b f) c h w")
|
313 |
+
x = super().forward(x)
|
314 |
+
x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
|
315 |
+
|
316 |
+
return x
|
317 |
+
|
318 |
+
|
319 |
+
def zero_module(module):
|
320 |
+
# Zero out the parameters of a module and return it.
|
321 |
+
for p in module.parameters():
|
322 |
+
p.detach().zero_()
|
323 |
+
return module
|
324 |
+
|
325 |
+
|
326 |
+
class PoseGuider(ModelMixin):
|
327 |
+
def __init__(
|
328 |
+
self,
|
329 |
+
conditioning_embedding_channels: int,
|
330 |
+
conditioning_channels: int = 3,
|
331 |
+
block_out_channels: Tuple[int] = (16, 32, 64, 128),
|
332 |
+
):
|
333 |
+
super().__init__()
|
334 |
+
self.conv_in = InflatedConv3d(
|
335 |
+
conditioning_channels, block_out_channels[0], kernel_size=3, padding=1
|
336 |
+
)
|
337 |
+
|
338 |
+
self.blocks = nn.ModuleList([])
|
339 |
+
|
340 |
+
for i in range(len(block_out_channels) - 1):
|
341 |
+
channel_in = block_out_channels[i]
|
342 |
+
channel_out = block_out_channels[i + 1]
|
343 |
+
self.blocks.append(
|
344 |
+
InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)
|
345 |
+
)
|
346 |
+
self.blocks.append(
|
347 |
+
InflatedConv3d(
|
348 |
+
channel_in, channel_out, kernel_size=3, padding=1, stride=2
|
349 |
+
)
|
350 |
+
)
|
351 |
+
|
352 |
+
self.conv_out = zero_module(
|
353 |
+
InflatedConv3d(
|
354 |
+
block_out_channels[-1],
|
355 |
+
conditioning_embedding_channels,
|
356 |
+
kernel_size=3,
|
357 |
+
padding=1,
|
358 |
+
)
|
359 |
+
)
|
360 |
+
|
361 |
+
def forward(self, conditioning):
|
362 |
+
embedding = self.conv_in(conditioning)
|
363 |
+
embedding = F.silu(embedding)
|
364 |
+
|
365 |
+
for block in self.blocks:
|
366 |
+
embedding = block(embedding)
|
367 |
+
embedding = F.silu(embedding)
|
368 |
+
|
369 |
+
embedding = self.conv_out(embedding)
|
370 |
+
|
371 |
+
return embedding
|
372 |
+
|
373 |
+
@classmethod
|
374 |
+
def from_pretrained(
|
375 |
+
cls,
|
376 |
+
pretrained_model_path,
|
377 |
+
conditioning_embedding_channels: int,
|
378 |
+
conditioning_channels: int = 3,
|
379 |
+
block_out_channels: Tuple[int] = (16, 32, 64, 128),
|
380 |
+
):
|
381 |
+
if not os.path.exists(pretrained_model_path):
|
382 |
+
print(f"There is no model file in {pretrained_model_path}")
|
383 |
+
print(
|
384 |
+
f"loaded PoseGuider's pretrained weights from {pretrained_model_path} ..."
|
385 |
+
)
|
386 |
+
|
387 |
+
state_dict = torch.load(pretrained_model_path, map_location="cpu")
|
388 |
+
model = PoseGuider(
|
389 |
+
conditioning_embedding_channels=conditioning_embedding_channels,
|
390 |
+
conditioning_channels=conditioning_channels,
|
391 |
+
block_out_channels=block_out_channels,
|
392 |
+
)
|
393 |
+
|
394 |
+
m, u = model.load_state_dict(state_dict, strict=False)
|
395 |
+
# print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
|
396 |
+
params = [p.numel() for n, p in model.named_parameters()]
|
397 |
+
print(f"### PoseGuider's Parameters: {sum(params) / 1e6} M")
|
398 |
+
|
399 |
+
return model
|
musev/models/embeddings.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from einops import rearrange
|
16 |
+
import torch
|
17 |
+
from torch.nn import functional as F
|
18 |
+
import numpy as np
|
19 |
+
|
20 |
+
from diffusers.models.embeddings import get_2d_sincos_pos_embed_from_grid
|
21 |
+
|
22 |
+
|
23 |
+
# ref diffusers.models.embeddings.get_2d_sincos_pos_embed
|
24 |
+
def get_2d_sincos_pos_embed(
|
25 |
+
embed_dim,
|
26 |
+
grid_size_w,
|
27 |
+
grid_size_h,
|
28 |
+
cls_token=False,
|
29 |
+
extra_tokens=0,
|
30 |
+
norm_length: bool = False,
|
31 |
+
max_length: float = 2048,
|
32 |
+
):
|
33 |
+
"""
|
34 |
+
grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
|
35 |
+
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
36 |
+
"""
|
37 |
+
if norm_length and grid_size_h <= max_length and grid_size_w <= max_length:
|
38 |
+
grid_h = np.linspace(0, max_length, grid_size_h)
|
39 |
+
grid_w = np.linspace(0, max_length, grid_size_w)
|
40 |
+
else:
|
41 |
+
grid_h = np.arange(grid_size_h, dtype=np.float32)
|
42 |
+
grid_w = np.arange(grid_size_w, dtype=np.float32)
|
43 |
+
grid = np.meshgrid(grid_h, grid_w) # here h goes first
|
44 |
+
grid = np.stack(grid, axis=0)
|
45 |
+
|
46 |
+
grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
|
47 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
48 |
+
if cls_token and extra_tokens > 0:
|
49 |
+
pos_embed = np.concatenate(
|
50 |
+
[np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
|
51 |
+
)
|
52 |
+
return pos_embed
|
53 |
+
|
54 |
+
|
55 |
+
def resize_spatial_position_emb(
|
56 |
+
emb: torch.Tensor,
|
57 |
+
height: int,
|
58 |
+
width: int,
|
59 |
+
scale: float = None,
|
60 |
+
target_height: int = None,
|
61 |
+
target_width: int = None,
|
62 |
+
) -> torch.Tensor:
|
63 |
+
"""_summary_
|
64 |
+
|
65 |
+
Args:
|
66 |
+
emb (torch.Tensor): b ( h w) d
|
67 |
+
height (int): _description_
|
68 |
+
width (int): _description_
|
69 |
+
scale (float, optional): _description_. Defaults to None.
|
70 |
+
target_height (int, optional): _description_. Defaults to None.
|
71 |
+
target_width (int, optional): _description_. Defaults to None.
|
72 |
+
|
73 |
+
Returns:
|
74 |
+
torch.Tensor: b (target_height target_width) d
|
75 |
+
"""
|
76 |
+
if scale is not None:
|
77 |
+
target_height = int(height * scale)
|
78 |
+
target_width = int(width * scale)
|
79 |
+
emb = rearrange(emb, "(h w) (b d) ->b d h w", h=height, b=1)
|
80 |
+
emb = F.interpolate(
|
81 |
+
emb,
|
82 |
+
size=(target_height, target_width),
|
83 |
+
mode="bicubic",
|
84 |
+
align_corners=False,
|
85 |
+
)
|
86 |
+
emb = rearrange(emb, "b d h w-> (h w) (b d)")
|
87 |
+
return emb
|
musev/models/facein_loader.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
from typing import Any, Callable, Dict, Iterable, Union
|
3 |
+
import PIL
|
4 |
+
import cv2
|
5 |
+
import torch
|
6 |
+
import argparse
|
7 |
+
import datetime
|
8 |
+
import logging
|
9 |
+
import inspect
|
10 |
+
import math
|
11 |
+
import os
|
12 |
+
import shutil
|
13 |
+
from typing import Dict, List, Optional, Tuple
|
14 |
+
from pprint import pprint
|
15 |
+
from collections import OrderedDict
|
16 |
+
from dataclasses import dataclass
|
17 |
+
import gc
|
18 |
+
import time
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
from omegaconf import OmegaConf
|
22 |
+
from omegaconf import SCMode
|
23 |
+
import torch
|
24 |
+
from torch import nn
|
25 |
+
import torch.nn.functional as F
|
26 |
+
import torch.utils.checkpoint
|
27 |
+
from einops import rearrange, repeat
|
28 |
+
import pandas as pd
|
29 |
+
import h5py
|
30 |
+
from diffusers.models.modeling_utils import load_state_dict
|
31 |
+
from diffusers.utils import (
|
32 |
+
logging,
|
33 |
+
)
|
34 |
+
from diffusers.utils.import_utils import is_xformers_available
|
35 |
+
|
36 |
+
from mmcm.vision.feature_extractor.clip_vision_extractor import (
|
37 |
+
ImageClipVisionFeatureExtractor,
|
38 |
+
ImageClipVisionFeatureExtractorV2,
|
39 |
+
)
|
40 |
+
from mmcm.vision.feature_extractor.insight_face_extractor import InsightFaceExtractor
|
41 |
+
|
42 |
+
from ip_adapter.resampler import Resampler
|
43 |
+
from ip_adapter.ip_adapter import ImageProjModel
|
44 |
+
|
45 |
+
from .unet_loader import update_unet_with_sd
|
46 |
+
from .unet_3d_condition import UNet3DConditionModel
|
47 |
+
from .ip_adapter_loader import ip_adapter_keys_list
|
48 |
+
|
49 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
50 |
+
|
51 |
+
|
52 |
+
# refer https://github.com/tencent-ailab/IP-Adapter/issues/168#issuecomment-1846771651
|
53 |
+
unet_keys_list = [
|
54 |
+
"down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
|
55 |
+
"down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
|
56 |
+
"down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
|
57 |
+
"down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
|
58 |
+
"down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
|
59 |
+
"down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
|
60 |
+
"down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
|
61 |
+
"down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
|
62 |
+
"down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
|
63 |
+
"down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
|
64 |
+
"down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
|
65 |
+
"down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
|
66 |
+
"up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
|
67 |
+
"up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
|
68 |
+
"up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
|
69 |
+
"up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
|
70 |
+
"up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
|
71 |
+
"up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
|
72 |
+
"up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
|
73 |
+
"up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
|
74 |
+
"up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
|
75 |
+
"up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
|
76 |
+
"up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
|
77 |
+
"up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
|
78 |
+
"up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
|
79 |
+
"up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
|
80 |
+
"up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
|
81 |
+
"up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
|
82 |
+
"up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
|
83 |
+
"up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
|
84 |
+
"mid_block.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
|
85 |
+
"mid_block.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
|
86 |
+
]
|
87 |
+
|
88 |
+
|
89 |
+
UNET2IPAadapter_Keys_MAPIING = {
|
90 |
+
k: v for k, v in zip(unet_keys_list, ip_adapter_keys_list)
|
91 |
+
}
|
92 |
+
|
93 |
+
|
94 |
+
def load_facein_extractor_and_proj_by_name(
|
95 |
+
model_name: str,
|
96 |
+
ip_ckpt: Tuple[str, nn.Module],
|
97 |
+
ip_image_encoder: Tuple[str, nn.Module] = None,
|
98 |
+
cross_attention_dim: int = 768,
|
99 |
+
clip_embeddings_dim: int = 512,
|
100 |
+
clip_extra_context_tokens: int = 1,
|
101 |
+
ip_scale: float = 0.0,
|
102 |
+
dtype: torch.dtype = torch.float16,
|
103 |
+
device: str = "cuda",
|
104 |
+
unet: nn.Module = None,
|
105 |
+
) -> nn.Module:
|
106 |
+
pass
|
107 |
+
|
108 |
+
|
109 |
+
def update_unet_facein_cross_attn_param(
|
110 |
+
unet: UNet3DConditionModel, ip_adapter_state_dict: Dict
|
111 |
+
) -> None:
|
112 |
+
"""use independent ip_adapter attn 中的 to_k, to_v in unet
|
113 |
+
ip_adapter: like ['1.to_k_ip.weight', '1.to_v_ip.weight', '3.to_k_ip.weight']的字典
|
114 |
+
|
115 |
+
|
116 |
+
Args:
|
117 |
+
unet (UNet3DConditionModel): _description_
|
118 |
+
ip_adapter_state_dict (Dict): _description_
|
119 |
+
"""
|
120 |
+
pass
|
musev/models/ip_adapter_face_loader.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
from typing import Any, Callable, Dict, Iterable, Union
|
3 |
+
import PIL
|
4 |
+
import cv2
|
5 |
+
import torch
|
6 |
+
import argparse
|
7 |
+
import datetime
|
8 |
+
import logging
|
9 |
+
import inspect
|
10 |
+
import math
|
11 |
+
import os
|
12 |
+
import shutil
|
13 |
+
from typing import Dict, List, Optional, Tuple
|
14 |
+
from pprint import pprint
|
15 |
+
from collections import OrderedDict
|
16 |
+
from dataclasses import dataclass
|
17 |
+
import gc
|
18 |
+
import time
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
from omegaconf import OmegaConf
|
22 |
+
from omegaconf import SCMode
|
23 |
+
import torch
|
24 |
+
from torch import nn
|
25 |
+
import torch.nn.functional as F
|
26 |
+
import torch.utils.checkpoint
|
27 |
+
from einops import rearrange, repeat
|
28 |
+
import pandas as pd
|
29 |
+
import h5py
|
30 |
+
from diffusers.models.modeling_utils import load_state_dict
|
31 |
+
from diffusers.utils import (
|
32 |
+
logging,
|
33 |
+
)
|
34 |
+
from diffusers.utils.import_utils import is_xformers_available
|
35 |
+
|
36 |
+
from ip_adapter.resampler import Resampler
|
37 |
+
from ip_adapter.ip_adapter import ImageProjModel
|
38 |
+
from ip_adapter.ip_adapter_faceid import ProjPlusModel, MLPProjModel
|
39 |
+
|
40 |
+
from mmcm.vision.feature_extractor.clip_vision_extractor import (
|
41 |
+
ImageClipVisionFeatureExtractor,
|
42 |
+
ImageClipVisionFeatureExtractorV2,
|
43 |
+
)
|
44 |
+
from mmcm.vision.feature_extractor.insight_face_extractor import (
|
45 |
+
InsightFaceExtractorNormEmb,
|
46 |
+
)
|
47 |
+
|
48 |
+
|
49 |
+
from .unet_loader import update_unet_with_sd
|
50 |
+
from .unet_3d_condition import UNet3DConditionModel
|
51 |
+
from .ip_adapter_loader import ip_adapter_keys_list
|
52 |
+
|
53 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
54 |
+
|
55 |
+
|
56 |
+
# refer https://github.com/tencent-ailab/IP-Adapter/issues/168#issuecomment-1846771651
|
57 |
+
unet_keys_list = [
|
58 |
+
"down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
|
59 |
+
"down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
|
60 |
+
"down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
|
61 |
+
"down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
|
62 |
+
"down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
|
63 |
+
"down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
|
64 |
+
"down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
|
65 |
+
"down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
|
66 |
+
"down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
|
67 |
+
"down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
|
68 |
+
"down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
|
69 |
+
"down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
|
70 |
+
"up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
|
71 |
+
"up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
|
72 |
+
"up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
|
73 |
+
"up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
|
74 |
+
"up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
|
75 |
+
"up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
|
76 |
+
"up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
|
77 |
+
"up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
|
78 |
+
"up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
|
79 |
+
"up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
|
80 |
+
"up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
|
81 |
+
"up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
|
82 |
+
"up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
|
83 |
+
"up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
|
84 |
+
"up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
|
85 |
+
"up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
|
86 |
+
"up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
|
87 |
+
"up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
|
88 |
+
"mid_block.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
|
89 |
+
"mid_block.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
|
90 |
+
]
|
91 |
+
|
92 |
+
|
93 |
+
UNET2IPAadapter_Keys_MAPIING = {
|
94 |
+
k: v for k, v in zip(unet_keys_list, ip_adapter_keys_list)
|
95 |
+
}
|
96 |
+
|
97 |
+
|
98 |
+
def load_ip_adapter_face_extractor_and_proj_by_name(
|
99 |
+
model_name: str,
|
100 |
+
ip_ckpt: Tuple[str, nn.Module],
|
101 |
+
ip_image_encoder: Tuple[str, nn.Module] = None,
|
102 |
+
cross_attention_dim: int = 768,
|
103 |
+
clip_embeddings_dim: int = 1024,
|
104 |
+
clip_extra_context_tokens: int = 4,
|
105 |
+
ip_scale: float = 0.0,
|
106 |
+
dtype: torch.dtype = torch.float16,
|
107 |
+
device: str = "cuda",
|
108 |
+
unet: nn.Module = None,
|
109 |
+
) -> nn.Module:
|
110 |
+
if model_name == "IPAdapterFaceID":
|
111 |
+
if ip_image_encoder is not None:
|
112 |
+
ip_adapter_face_emb_extractor = InsightFaceExtractorNormEmb(
|
113 |
+
pretrained_model_name_or_path=ip_image_encoder,
|
114 |
+
dtype=dtype,
|
115 |
+
device=device,
|
116 |
+
)
|
117 |
+
else:
|
118 |
+
ip_adapter_face_emb_extractor = None
|
119 |
+
ip_adapter_image_proj = MLPProjModel(
|
120 |
+
cross_attention_dim=cross_attention_dim,
|
121 |
+
id_embeddings_dim=clip_embeddings_dim,
|
122 |
+
num_tokens=clip_extra_context_tokens,
|
123 |
+
).to(device, dtype=dtype)
|
124 |
+
else:
|
125 |
+
raise ValueError(
|
126 |
+
f"unsupport model_name={model_name}, only support IPAdapter, IPAdapterPlus, IPAdapterFaceID"
|
127 |
+
)
|
128 |
+
ip_adapter_state_dict = torch.load(
|
129 |
+
ip_ckpt,
|
130 |
+
map_location="cpu",
|
131 |
+
)
|
132 |
+
ip_adapter_image_proj.load_state_dict(ip_adapter_state_dict["image_proj"])
|
133 |
+
if unet is not None and "ip_adapter" in ip_adapter_state_dict:
|
134 |
+
update_unet_ip_adapter_cross_attn_param(
|
135 |
+
unet,
|
136 |
+
ip_adapter_state_dict["ip_adapter"],
|
137 |
+
)
|
138 |
+
logger.info(
|
139 |
+
f"update unet.spatial_cross_attn_ip_adapter parameter with {ip_ckpt}"
|
140 |
+
)
|
141 |
+
return (
|
142 |
+
ip_adapter_face_emb_extractor,
|
143 |
+
ip_adapter_image_proj,
|
144 |
+
)
|
145 |
+
|
146 |
+
|
147 |
+
def update_unet_ip_adapter_cross_attn_param(
|
148 |
+
unet: UNet3DConditionModel, ip_adapter_state_dict: Dict
|
149 |
+
) -> None:
|
150 |
+
"""use independent ip_adapter attn 中的 to_k, to_v in unet
|
151 |
+
ip_adapter: like ['1.to_k_ip.weight', '1.to_v_ip.weight', '3.to_k_ip.weight']
|
152 |
+
|
153 |
+
|
154 |
+
Args:
|
155 |
+
unet (UNet3DConditionModel): _description_
|
156 |
+
ip_adapter_state_dict (Dict): _description_
|
157 |
+
"""
|
158 |
+
unet_spatial_cross_atnns = unet.spatial_cross_attns[0]
|
159 |
+
unet_spatial_cross_atnns_dct = {k: v for k, v in unet_spatial_cross_atnns}
|
160 |
+
for i, (unet_key_more, ip_adapter_key) in enumerate(
|
161 |
+
UNET2IPAadapter_Keys_MAPIING.items()
|
162 |
+
):
|
163 |
+
ip_adapter_value = ip_adapter_state_dict[ip_adapter_key]
|
164 |
+
unet_key_more_spit = unet_key_more.split(".")
|
165 |
+
unet_key = ".".join(unet_key_more_spit[:-3])
|
166 |
+
suffix = ".".join(unet_key_more_spit[-3:])
|
167 |
+
logger.debug(
|
168 |
+
f"{i}: unet_key_more = {unet_key_more}, {unet_key}=unet_key, suffix={suffix}",
|
169 |
+
)
|
170 |
+
if ".ip_adapter_face_to_k" in suffix:
|
171 |
+
with torch.no_grad():
|
172 |
+
unet_spatial_cross_atnns_dct[
|
173 |
+
unet_key
|
174 |
+
].ip_adapter_face_to_k_ip.weight.copy_(ip_adapter_value.data)
|
175 |
+
else:
|
176 |
+
with torch.no_grad():
|
177 |
+
unet_spatial_cross_atnns_dct[
|
178 |
+
unet_key
|
179 |
+
].ip_adapter_face_to_v_ip.weight.copy_(ip_adapter_value.data)
|
musev/models/ip_adapter_loader.py
ADDED
@@ -0,0 +1,340 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
from typing import Any, Callable, Dict, Iterable, Union
|
3 |
+
import PIL
|
4 |
+
import cv2
|
5 |
+
import torch
|
6 |
+
import argparse
|
7 |
+
import datetime
|
8 |
+
import logging
|
9 |
+
import inspect
|
10 |
+
import math
|
11 |
+
import os
|
12 |
+
import shutil
|
13 |
+
from typing import Dict, List, Optional, Tuple
|
14 |
+
from pprint import pprint
|
15 |
+
from collections import OrderedDict
|
16 |
+
from dataclasses import dataclass
|
17 |
+
import gc
|
18 |
+
import time
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
from omegaconf import OmegaConf
|
22 |
+
from omegaconf import SCMode
|
23 |
+
import torch
|
24 |
+
from torch import nn
|
25 |
+
import torch.nn.functional as F
|
26 |
+
import torch.utils.checkpoint
|
27 |
+
from einops import rearrange, repeat
|
28 |
+
import pandas as pd
|
29 |
+
import h5py
|
30 |
+
from diffusers.models.modeling_utils import load_state_dict
|
31 |
+
from diffusers.utils import (
|
32 |
+
logging,
|
33 |
+
)
|
34 |
+
from diffusers.utils.import_utils import is_xformers_available
|
35 |
+
|
36 |
+
from mmcm.vision.feature_extractor import clip_vision_extractor
|
37 |
+
from mmcm.vision.feature_extractor.clip_vision_extractor import (
|
38 |
+
ImageClipVisionFeatureExtractor,
|
39 |
+
ImageClipVisionFeatureExtractorV2,
|
40 |
+
VerstailSDLastHiddenState2ImageEmb,
|
41 |
+
)
|
42 |
+
|
43 |
+
from ip_adapter.resampler import Resampler
|
44 |
+
from ip_adapter.ip_adapter import ImageProjModel
|
45 |
+
|
46 |
+
from .unet_loader import update_unet_with_sd
|
47 |
+
from .unet_3d_condition import UNet3DConditionModel
|
48 |
+
|
49 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
50 |
+
|
51 |
+
|
52 |
+
def load_vision_clip_encoder_by_name(
|
53 |
+
ip_image_encoder: Tuple[str, nn.Module] = None,
|
54 |
+
dtype: torch.dtype = torch.float16,
|
55 |
+
device: str = "cuda",
|
56 |
+
vision_clip_extractor_class_name: str = None,
|
57 |
+
) -> nn.Module:
|
58 |
+
if vision_clip_extractor_class_name is not None:
|
59 |
+
vision_clip_extractor = getattr(
|
60 |
+
clip_vision_extractor, vision_clip_extractor_class_name
|
61 |
+
)(
|
62 |
+
pretrained_model_name_or_path=ip_image_encoder,
|
63 |
+
dtype=dtype,
|
64 |
+
device=device,
|
65 |
+
)
|
66 |
+
else:
|
67 |
+
vision_clip_extractor = None
|
68 |
+
return vision_clip_extractor
|
69 |
+
|
70 |
+
|
71 |
+
def load_ip_adapter_image_proj_by_name(
|
72 |
+
model_name: str,
|
73 |
+
ip_ckpt: Tuple[str, nn.Module] = None,
|
74 |
+
cross_attention_dim: int = 768,
|
75 |
+
clip_embeddings_dim: int = 1024,
|
76 |
+
clip_extra_context_tokens: int = 4,
|
77 |
+
ip_scale: float = 0.0,
|
78 |
+
dtype: torch.dtype = torch.float16,
|
79 |
+
device: str = "cuda",
|
80 |
+
unet: nn.Module = None,
|
81 |
+
vision_clip_extractor_class_name: str = None,
|
82 |
+
ip_image_encoder: Tuple[str, nn.Module] = None,
|
83 |
+
) -> nn.Module:
|
84 |
+
if model_name in [
|
85 |
+
"IPAdapter",
|
86 |
+
"musev_referencenet",
|
87 |
+
"musev_referencenet_pose",
|
88 |
+
]:
|
89 |
+
ip_adapter_image_proj = ImageProjModel(
|
90 |
+
cross_attention_dim=cross_attention_dim,
|
91 |
+
clip_embeddings_dim=clip_embeddings_dim,
|
92 |
+
clip_extra_context_tokens=clip_extra_context_tokens,
|
93 |
+
)
|
94 |
+
|
95 |
+
elif model_name == "IPAdapterPlus":
|
96 |
+
vision_clip_extractor = ImageClipVisionFeatureExtractorV2(
|
97 |
+
pretrained_model_name_or_path=ip_image_encoder,
|
98 |
+
dtype=dtype,
|
99 |
+
device=device,
|
100 |
+
)
|
101 |
+
ip_adapter_image_proj = Resampler(
|
102 |
+
dim=cross_attention_dim,
|
103 |
+
depth=4,
|
104 |
+
dim_head=64,
|
105 |
+
heads=12,
|
106 |
+
num_queries=clip_extra_context_tokens,
|
107 |
+
embedding_dim=vision_clip_extractor.image_encoder.config.hidden_size,
|
108 |
+
output_dim=cross_attention_dim,
|
109 |
+
ff_mult=4,
|
110 |
+
)
|
111 |
+
elif model_name in [
|
112 |
+
"VerstailSDLastHiddenState2ImageEmb",
|
113 |
+
"OriginLastHiddenState2ImageEmbd",
|
114 |
+
"OriginLastHiddenState2Poolout",
|
115 |
+
]:
|
116 |
+
ip_adapter_image_proj = getattr(
|
117 |
+
clip_vision_extractor, model_name
|
118 |
+
).from_pretrained(ip_image_encoder)
|
119 |
+
else:
|
120 |
+
raise ValueError(
|
121 |
+
f"unsupport model_name={model_name}, only support IPAdapter, IPAdapterPlus, VerstailSDLastHiddenState2ImageEmb"
|
122 |
+
)
|
123 |
+
if ip_ckpt is not None:
|
124 |
+
ip_adapter_state_dict = torch.load(
|
125 |
+
ip_ckpt,
|
126 |
+
map_location="cpu",
|
127 |
+
)
|
128 |
+
ip_adapter_image_proj.load_state_dict(ip_adapter_state_dict["image_proj"])
|
129 |
+
if (
|
130 |
+
unet is not None
|
131 |
+
and unet.ip_adapter_cross_attn
|
132 |
+
and "ip_adapter" in ip_adapter_state_dict
|
133 |
+
):
|
134 |
+
update_unet_ip_adapter_cross_attn_param(
|
135 |
+
unet, ip_adapter_state_dict["ip_adapter"]
|
136 |
+
)
|
137 |
+
logger.info(
|
138 |
+
f"update unet.spatial_cross_attn_ip_adapter parameter with {ip_ckpt}"
|
139 |
+
)
|
140 |
+
return ip_adapter_image_proj
|
141 |
+
|
142 |
+
|
143 |
+
def load_ip_adapter_vision_clip_encoder_by_name(
|
144 |
+
model_name: str,
|
145 |
+
ip_ckpt: Tuple[str, nn.Module],
|
146 |
+
ip_image_encoder: Tuple[str, nn.Module] = None,
|
147 |
+
cross_attention_dim: int = 768,
|
148 |
+
clip_embeddings_dim: int = 1024,
|
149 |
+
clip_extra_context_tokens: int = 4,
|
150 |
+
ip_scale: float = 0.0,
|
151 |
+
dtype: torch.dtype = torch.float16,
|
152 |
+
device: str = "cuda",
|
153 |
+
unet: nn.Module = None,
|
154 |
+
vision_clip_extractor_class_name: str = None,
|
155 |
+
) -> nn.Module:
|
156 |
+
if vision_clip_extractor_class_name is not None:
|
157 |
+
vision_clip_extractor = getattr(
|
158 |
+
clip_vision_extractor, vision_clip_extractor_class_name
|
159 |
+
)(
|
160 |
+
pretrained_model_name_or_path=ip_image_encoder,
|
161 |
+
dtype=dtype,
|
162 |
+
device=device,
|
163 |
+
)
|
164 |
+
else:
|
165 |
+
vision_clip_extractor = None
|
166 |
+
if model_name in [
|
167 |
+
"IPAdapter",
|
168 |
+
"musev_referencenet",
|
169 |
+
]:
|
170 |
+
if ip_image_encoder is not None:
|
171 |
+
if vision_clip_extractor_class_name is None:
|
172 |
+
vision_clip_extractor = ImageClipVisionFeatureExtractor(
|
173 |
+
pretrained_model_name_or_path=ip_image_encoder,
|
174 |
+
dtype=dtype,
|
175 |
+
device=device,
|
176 |
+
)
|
177 |
+
else:
|
178 |
+
vision_clip_extractor = None
|
179 |
+
ip_adapter_image_proj = ImageProjModel(
|
180 |
+
cross_attention_dim=cross_attention_dim,
|
181 |
+
clip_embeddings_dim=clip_embeddings_dim,
|
182 |
+
clip_extra_context_tokens=clip_extra_context_tokens,
|
183 |
+
)
|
184 |
+
|
185 |
+
elif model_name == "IPAdapterPlus":
|
186 |
+
if ip_image_encoder is not None:
|
187 |
+
if vision_clip_extractor_class_name is None:
|
188 |
+
vision_clip_extractor = ImageClipVisionFeatureExtractorV2(
|
189 |
+
pretrained_model_name_or_path=ip_image_encoder,
|
190 |
+
dtype=dtype,
|
191 |
+
device=device,
|
192 |
+
)
|
193 |
+
else:
|
194 |
+
vision_clip_extractor = None
|
195 |
+
ip_adapter_image_proj = Resampler(
|
196 |
+
dim=cross_attention_dim,
|
197 |
+
depth=4,
|
198 |
+
dim_head=64,
|
199 |
+
heads=12,
|
200 |
+
num_queries=clip_extra_context_tokens,
|
201 |
+
embedding_dim=vision_clip_extractor.image_encoder.config.hidden_size,
|
202 |
+
output_dim=cross_attention_dim,
|
203 |
+
ff_mult=4,
|
204 |
+
).to(dtype=torch.float16)
|
205 |
+
else:
|
206 |
+
raise ValueError(
|
207 |
+
f"unsupport model_name={model_name}, only support IPAdapter, IPAdapterPlus"
|
208 |
+
)
|
209 |
+
ip_adapter_state_dict = torch.load(
|
210 |
+
ip_ckpt,
|
211 |
+
map_location="cpu",
|
212 |
+
)
|
213 |
+
ip_adapter_image_proj.load_state_dict(ip_adapter_state_dict["image_proj"])
|
214 |
+
if (
|
215 |
+
unet is not None
|
216 |
+
and unet.ip_adapter_cross_attn
|
217 |
+
and "ip_adapter" in ip_adapter_state_dict
|
218 |
+
):
|
219 |
+
update_unet_ip_adapter_cross_attn_param(
|
220 |
+
unet, ip_adapter_state_dict["ip_adapter"]
|
221 |
+
)
|
222 |
+
logger.info(
|
223 |
+
f"update unet.spatial_cross_attn_ip_adapter parameter with {ip_ckpt}"
|
224 |
+
)
|
225 |
+
return (
|
226 |
+
vision_clip_extractor,
|
227 |
+
ip_adapter_image_proj,
|
228 |
+
)
|
229 |
+
|
230 |
+
|
231 |
+
# refer https://github.com/tencent-ailab/IP-Adapter/issues/168#issuecomment-1846771651
|
232 |
+
unet_keys_list = [
|
233 |
+
"down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight",
|
234 |
+
"down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight",
|
235 |
+
"down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight",
|
236 |
+
"down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight",
|
237 |
+
"down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight",
|
238 |
+
"down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight",
|
239 |
+
"down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight",
|
240 |
+
"down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight",
|
241 |
+
"down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight",
|
242 |
+
"down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight",
|
243 |
+
"down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight",
|
244 |
+
"down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight",
|
245 |
+
"up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight",
|
246 |
+
"up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight",
|
247 |
+
"up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight",
|
248 |
+
"up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight",
|
249 |
+
"up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.to_k_ip.weight",
|
250 |
+
"up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.to_v_ip.weight",
|
251 |
+
"up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight",
|
252 |
+
"up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight",
|
253 |
+
"up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight",
|
254 |
+
"up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight",
|
255 |
+
"up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.to_k_ip.weight",
|
256 |
+
"up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.to_v_ip.weight",
|
257 |
+
"up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight",
|
258 |
+
"up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight",
|
259 |
+
"up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight",
|
260 |
+
"up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight",
|
261 |
+
"up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.to_k_ip.weight",
|
262 |
+
"up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.to_v_ip.weight",
|
263 |
+
"mid_block.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight",
|
264 |
+
"mid_block.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight",
|
265 |
+
]
|
266 |
+
|
267 |
+
|
268 |
+
ip_adapter_keys_list = [
|
269 |
+
"1.to_k_ip.weight",
|
270 |
+
"1.to_v_ip.weight",
|
271 |
+
"3.to_k_ip.weight",
|
272 |
+
"3.to_v_ip.weight",
|
273 |
+
"5.to_k_ip.weight",
|
274 |
+
"5.to_v_ip.weight",
|
275 |
+
"7.to_k_ip.weight",
|
276 |
+
"7.to_v_ip.weight",
|
277 |
+
"9.to_k_ip.weight",
|
278 |
+
"9.to_v_ip.weight",
|
279 |
+
"11.to_k_ip.weight",
|
280 |
+
"11.to_v_ip.weight",
|
281 |
+
"13.to_k_ip.weight",
|
282 |
+
"13.to_v_ip.weight",
|
283 |
+
"15.to_k_ip.weight",
|
284 |
+
"15.to_v_ip.weight",
|
285 |
+
"17.to_k_ip.weight",
|
286 |
+
"17.to_v_ip.weight",
|
287 |
+
"19.to_k_ip.weight",
|
288 |
+
"19.to_v_ip.weight",
|
289 |
+
"21.to_k_ip.weight",
|
290 |
+
"21.to_v_ip.weight",
|
291 |
+
"23.to_k_ip.weight",
|
292 |
+
"23.to_v_ip.weight",
|
293 |
+
"25.to_k_ip.weight",
|
294 |
+
"25.to_v_ip.weight",
|
295 |
+
"27.to_k_ip.weight",
|
296 |
+
"27.to_v_ip.weight",
|
297 |
+
"29.to_k_ip.weight",
|
298 |
+
"29.to_v_ip.weight",
|
299 |
+
"31.to_k_ip.weight",
|
300 |
+
"31.to_v_ip.weight",
|
301 |
+
]
|
302 |
+
|
303 |
+
UNET2IPAadapter_Keys_MAPIING = {
|
304 |
+
k: v for k, v in zip(unet_keys_list, ip_adapter_keys_list)
|
305 |
+
}
|
306 |
+
|
307 |
+
|
308 |
+
def update_unet_ip_adapter_cross_attn_param(
|
309 |
+
unet: UNet3DConditionModel, ip_adapter_state_dict: Dict
|
310 |
+
) -> None:
|
311 |
+
"""use independent ip_adapter attn 中的 to_k, to_v in unet
|
312 |
+
ip_adapter: dict whose keys are ['1.to_k_ip.weight', '1.to_v_ip.weight', '3.to_k_ip.weight']
|
313 |
+
|
314 |
+
|
315 |
+
Args:
|
316 |
+
unet (UNet3DConditionModel): _description_
|
317 |
+
ip_adapter_state_dict (Dict): _description_
|
318 |
+
"""
|
319 |
+
unet_spatial_cross_atnns = unet.spatial_cross_attns[0]
|
320 |
+
unet_spatial_cross_atnns_dct = {k: v for k, v in unet_spatial_cross_atnns}
|
321 |
+
for i, (unet_key_more, ip_adapter_key) in enumerate(
|
322 |
+
UNET2IPAadapter_Keys_MAPIING.items()
|
323 |
+
):
|
324 |
+
ip_adapter_value = ip_adapter_state_dict[ip_adapter_key]
|
325 |
+
unet_key_more_spit = unet_key_more.split(".")
|
326 |
+
unet_key = ".".join(unet_key_more_spit[:-3])
|
327 |
+
suffix = ".".join(unet_key_more_spit[-3:])
|
328 |
+
logger.debug(
|
329 |
+
f"{i}: unet_key_more = {unet_key_more}, {unet_key}=unet_key, suffix={suffix}",
|
330 |
+
)
|
331 |
+
if "to_k" in suffix:
|
332 |
+
with torch.no_grad():
|
333 |
+
unet_spatial_cross_atnns_dct[unet_key].to_k_ip.weight.copy_(
|
334 |
+
ip_adapter_value.data
|
335 |
+
)
|
336 |
+
else:
|
337 |
+
with torch.no_grad():
|
338 |
+
unet_spatial_cross_atnns_dct[unet_key].to_v_ip.weight.copy_(
|
339 |
+
ip_adapter_value.data
|
340 |
+
)
|
musev/models/referencenet.py
ADDED
@@ -0,0 +1,1216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import annotations
|
16 |
+
|
17 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
18 |
+
import logging
|
19 |
+
|
20 |
+
import torch
|
21 |
+
from diffusers.models.attention_processor import Attention, AttnProcessor
|
22 |
+
from einops import rearrange, repeat
|
23 |
+
import torch.nn as nn
|
24 |
+
import torch.nn.functional as F
|
25 |
+
import xformers
|
26 |
+
from diffusers.models.lora import LoRACompatibleLinear
|
27 |
+
from diffusers.models.unet_2d_condition import (
|
28 |
+
UNet2DConditionModel,
|
29 |
+
UNet2DConditionOutput,
|
30 |
+
)
|
31 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
32 |
+
from diffusers.utils.constants import USE_PEFT_BACKEND
|
33 |
+
from diffusers.utils.deprecation_utils import deprecate
|
34 |
+
from diffusers.utils.peft_utils import scale_lora_layers, unscale_lora_layers
|
35 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
36 |
+
from diffusers.models.modeling_utils import ModelMixin, load_state_dict
|
37 |
+
from diffusers.loaders import UNet2DConditionLoadersMixin
|
38 |
+
from diffusers.utils import (
|
39 |
+
USE_PEFT_BACKEND,
|
40 |
+
BaseOutput,
|
41 |
+
deprecate,
|
42 |
+
scale_lora_layers,
|
43 |
+
unscale_lora_layers,
|
44 |
+
)
|
45 |
+
from diffusers.models.activations import get_activation
|
46 |
+
from diffusers.models.attention_processor import (
|
47 |
+
ADDED_KV_ATTENTION_PROCESSORS,
|
48 |
+
CROSS_ATTENTION_PROCESSORS,
|
49 |
+
AttentionProcessor,
|
50 |
+
AttnAddedKVProcessor,
|
51 |
+
AttnProcessor,
|
52 |
+
)
|
53 |
+
from diffusers.models.embeddings import (
|
54 |
+
GaussianFourierProjection,
|
55 |
+
ImageHintTimeEmbedding,
|
56 |
+
ImageProjection,
|
57 |
+
ImageTimeEmbedding,
|
58 |
+
PositionNet,
|
59 |
+
TextImageProjection,
|
60 |
+
TextImageTimeEmbedding,
|
61 |
+
TextTimeEmbedding,
|
62 |
+
TimestepEmbedding,
|
63 |
+
Timesteps,
|
64 |
+
)
|
65 |
+
from diffusers.models.modeling_utils import ModelMixin
|
66 |
+
|
67 |
+
|
68 |
+
from ..data.data_util import align_repeat_tensor_single_dim
|
69 |
+
from .unet_3d_condition import UNet3DConditionModel
|
70 |
+
from .attention import BasicTransformerBlock, IPAttention
|
71 |
+
from .unet_2d_blocks import (
|
72 |
+
UNetMidBlock2D,
|
73 |
+
UNetMidBlock2DCrossAttn,
|
74 |
+
UNetMidBlock2DSimpleCrossAttn,
|
75 |
+
get_down_block,
|
76 |
+
get_up_block,
|
77 |
+
)
|
78 |
+
|
79 |
+
from . import Model_Register
|
80 |
+
|
81 |
+
|
82 |
+
logger = logging.getLogger(__name__)
|
83 |
+
|
84 |
+
|
85 |
+
@Model_Register.register
|
86 |
+
class ReferenceNet2D(UNet2DConditionModel, nn.Module):
|
87 |
+
"""继承 UNet2DConditionModel. 新增功能,类似controlnet 返回模型中间特征,用于后续作用
|
88 |
+
Inherit Unet2DConditionModel. Add new functions, similar to controlnet, return the intermediate features of the model for subsequent effects
|
89 |
+
Args:
|
90 |
+
UNet2DConditionModel (_type_): _description_
|
91 |
+
"""
|
92 |
+
|
93 |
+
_supports_gradient_checkpointing = True
|
94 |
+
print_idx = 0
|
95 |
+
|
96 |
+
@register_to_config
|
97 |
+
def __init__(
|
98 |
+
self,
|
99 |
+
sample_size: int | None = None,
|
100 |
+
in_channels: int = 4,
|
101 |
+
out_channels: int = 4,
|
102 |
+
center_input_sample: bool = False,
|
103 |
+
flip_sin_to_cos: bool = True,
|
104 |
+
freq_shift: int = 0,
|
105 |
+
down_block_types: Tuple[str] = (
|
106 |
+
"CrossAttnDownBlock2D",
|
107 |
+
"CrossAttnDownBlock2D",
|
108 |
+
"CrossAttnDownBlock2D",
|
109 |
+
"DownBlock2D",
|
110 |
+
),
|
111 |
+
mid_block_type: str | None = "UNetMidBlock2DCrossAttn",
|
112 |
+
up_block_types: Tuple[str] = (
|
113 |
+
"UpBlock2D",
|
114 |
+
"CrossAttnUpBlock2D",
|
115 |
+
"CrossAttnUpBlock2D",
|
116 |
+
"CrossAttnUpBlock2D",
|
117 |
+
),
|
118 |
+
only_cross_attention: bool | Tuple[bool] = False,
|
119 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
120 |
+
layers_per_block: int | Tuple[int] = 2,
|
121 |
+
downsample_padding: int = 1,
|
122 |
+
mid_block_scale_factor: float = 1,
|
123 |
+
dropout: float = 0,
|
124 |
+
act_fn: str = "silu",
|
125 |
+
norm_num_groups: int | None = 32,
|
126 |
+
norm_eps: float = 0.00001,
|
127 |
+
cross_attention_dim: int | Tuple[int] = 1280,
|
128 |
+
transformer_layers_per_block: int | Tuple[int] | Tuple[Tuple] = 1,
|
129 |
+
reverse_transformer_layers_per_block: Tuple[Tuple[int]] | None = None,
|
130 |
+
encoder_hid_dim: int | None = None,
|
131 |
+
encoder_hid_dim_type: str | None = None,
|
132 |
+
attention_head_dim: int | Tuple[int] = 8,
|
133 |
+
num_attention_heads: int | Tuple[int] | None = None,
|
134 |
+
dual_cross_attention: bool = False,
|
135 |
+
use_linear_projection: bool = False,
|
136 |
+
class_embed_type: str | None = None,
|
137 |
+
addition_embed_type: str | None = None,
|
138 |
+
addition_time_embed_dim: int | None = None,
|
139 |
+
num_class_embeds: int | None = None,
|
140 |
+
upcast_attention: bool = False,
|
141 |
+
resnet_time_scale_shift: str = "default",
|
142 |
+
resnet_skip_time_act: bool = False,
|
143 |
+
resnet_out_scale_factor: int = 1,
|
144 |
+
time_embedding_type: str = "positional",
|
145 |
+
time_embedding_dim: int | None = None,
|
146 |
+
time_embedding_act_fn: str | None = None,
|
147 |
+
timestep_post_act: str | None = None,
|
148 |
+
time_cond_proj_dim: int | None = None,
|
149 |
+
conv_in_kernel: int = 3,
|
150 |
+
conv_out_kernel: int = 3,
|
151 |
+
projection_class_embeddings_input_dim: int | None = None,
|
152 |
+
attention_type: str = "default",
|
153 |
+
class_embeddings_concat: bool = False,
|
154 |
+
mid_block_only_cross_attention: bool | None = None,
|
155 |
+
cross_attention_norm: str | None = None,
|
156 |
+
addition_embed_type_num_heads=64,
|
157 |
+
need_self_attn_block_embs: bool = False,
|
158 |
+
need_block_embs: bool = False,
|
159 |
+
):
|
160 |
+
super().__init__()
|
161 |
+
|
162 |
+
self.sample_size = sample_size
|
163 |
+
|
164 |
+
if num_attention_heads is not None:
|
165 |
+
raise ValueError(
|
166 |
+
"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
|
167 |
+
)
|
168 |
+
|
169 |
+
# If `num_attention_heads` is not defined (which is the case for most models)
|
170 |
+
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
171 |
+
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
172 |
+
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
173 |
+
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
174 |
+
# which is why we correct for the naming here.
|
175 |
+
num_attention_heads = num_attention_heads or attention_head_dim
|
176 |
+
|
177 |
+
# Check inputs
|
178 |
+
if len(down_block_types) != len(up_block_types):
|
179 |
+
raise ValueError(
|
180 |
+
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
181 |
+
)
|
182 |
+
|
183 |
+
if len(block_out_channels) != len(down_block_types):
|
184 |
+
raise ValueError(
|
185 |
+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
186 |
+
)
|
187 |
+
|
188 |
+
if not isinstance(only_cross_attention, bool) and len(
|
189 |
+
only_cross_attention
|
190 |
+
) != len(down_block_types):
|
191 |
+
raise ValueError(
|
192 |
+
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
193 |
+
)
|
194 |
+
|
195 |
+
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(
|
196 |
+
down_block_types
|
197 |
+
):
|
198 |
+
raise ValueError(
|
199 |
+
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
200 |
+
)
|
201 |
+
|
202 |
+
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(
|
203 |
+
down_block_types
|
204 |
+
):
|
205 |
+
raise ValueError(
|
206 |
+
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
|
207 |
+
)
|
208 |
+
|
209 |
+
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(
|
210 |
+
down_block_types
|
211 |
+
):
|
212 |
+
raise ValueError(
|
213 |
+
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
|
214 |
+
)
|
215 |
+
|
216 |
+
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(
|
217 |
+
down_block_types
|
218 |
+
):
|
219 |
+
raise ValueError(
|
220 |
+
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
|
221 |
+
)
|
222 |
+
if (
|
223 |
+
isinstance(transformer_layers_per_block, list)
|
224 |
+
and reverse_transformer_layers_per_block is None
|
225 |
+
):
|
226 |
+
for layer_number_per_block in transformer_layers_per_block:
|
227 |
+
if isinstance(layer_number_per_block, list):
|
228 |
+
raise ValueError(
|
229 |
+
"Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet."
|
230 |
+
)
|
231 |
+
|
232 |
+
# input
|
233 |
+
conv_in_padding = (conv_in_kernel - 1) // 2
|
234 |
+
self.conv_in = nn.Conv2d(
|
235 |
+
in_channels,
|
236 |
+
block_out_channels[0],
|
237 |
+
kernel_size=conv_in_kernel,
|
238 |
+
padding=conv_in_padding,
|
239 |
+
)
|
240 |
+
|
241 |
+
# time
|
242 |
+
if time_embedding_type == "fourier":
|
243 |
+
time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
|
244 |
+
if time_embed_dim % 2 != 0:
|
245 |
+
raise ValueError(
|
246 |
+
f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}."
|
247 |
+
)
|
248 |
+
self.time_proj = GaussianFourierProjection(
|
249 |
+
time_embed_dim // 2,
|
250 |
+
set_W_to_weight=False,
|
251 |
+
log=False,
|
252 |
+
flip_sin_to_cos=flip_sin_to_cos,
|
253 |
+
)
|
254 |
+
timestep_input_dim = time_embed_dim
|
255 |
+
elif time_embedding_type == "positional":
|
256 |
+
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
|
257 |
+
|
258 |
+
self.time_proj = Timesteps(
|
259 |
+
block_out_channels[0], flip_sin_to_cos, freq_shift
|
260 |
+
)
|
261 |
+
timestep_input_dim = block_out_channels[0]
|
262 |
+
else:
|
263 |
+
raise ValueError(
|
264 |
+
f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
|
265 |
+
)
|
266 |
+
|
267 |
+
self.time_embedding = TimestepEmbedding(
|
268 |
+
timestep_input_dim,
|
269 |
+
time_embed_dim,
|
270 |
+
act_fn=act_fn,
|
271 |
+
post_act_fn=timestep_post_act,
|
272 |
+
cond_proj_dim=time_cond_proj_dim,
|
273 |
+
)
|
274 |
+
|
275 |
+
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
|
276 |
+
encoder_hid_dim_type = "text_proj"
|
277 |
+
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
|
278 |
+
logger.info(
|
279 |
+
"encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined."
|
280 |
+
)
|
281 |
+
|
282 |
+
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
|
283 |
+
raise ValueError(
|
284 |
+
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
|
285 |
+
)
|
286 |
+
|
287 |
+
if encoder_hid_dim_type == "text_proj":
|
288 |
+
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
|
289 |
+
elif encoder_hid_dim_type == "text_image_proj":
|
290 |
+
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
291 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
292 |
+
# case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
|
293 |
+
self.encoder_hid_proj = TextImageProjection(
|
294 |
+
text_embed_dim=encoder_hid_dim,
|
295 |
+
image_embed_dim=cross_attention_dim,
|
296 |
+
cross_attention_dim=cross_attention_dim,
|
297 |
+
)
|
298 |
+
elif encoder_hid_dim_type == "image_proj":
|
299 |
+
# Kandinsky 2.2
|
300 |
+
self.encoder_hid_proj = ImageProjection(
|
301 |
+
image_embed_dim=encoder_hid_dim,
|
302 |
+
cross_attention_dim=cross_attention_dim,
|
303 |
+
)
|
304 |
+
elif encoder_hid_dim_type is not None:
|
305 |
+
raise ValueError(
|
306 |
+
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
|
307 |
+
)
|
308 |
+
else:
|
309 |
+
self.encoder_hid_proj = None
|
310 |
+
|
311 |
+
# class embedding
|
312 |
+
if class_embed_type is None and num_class_embeds is not None:
|
313 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
314 |
+
elif class_embed_type == "timestep":
|
315 |
+
self.class_embedding = TimestepEmbedding(
|
316 |
+
timestep_input_dim, time_embed_dim, act_fn=act_fn
|
317 |
+
)
|
318 |
+
elif class_embed_type == "identity":
|
319 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
320 |
+
elif class_embed_type == "projection":
|
321 |
+
if projection_class_embeddings_input_dim is None:
|
322 |
+
raise ValueError(
|
323 |
+
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
324 |
+
)
|
325 |
+
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
326 |
+
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
327 |
+
# 2. it projects from an arbitrary input dimension.
|
328 |
+
#
|
329 |
+
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
330 |
+
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
331 |
+
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
332 |
+
self.class_embedding = TimestepEmbedding(
|
333 |
+
projection_class_embeddings_input_dim, time_embed_dim
|
334 |
+
)
|
335 |
+
elif class_embed_type == "simple_projection":
|
336 |
+
if projection_class_embeddings_input_dim is None:
|
337 |
+
raise ValueError(
|
338 |
+
"`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
|
339 |
+
)
|
340 |
+
self.class_embedding = nn.Linear(
|
341 |
+
projection_class_embeddings_input_dim, time_embed_dim
|
342 |
+
)
|
343 |
+
else:
|
344 |
+
self.class_embedding = None
|
345 |
+
|
346 |
+
if addition_embed_type == "text":
|
347 |
+
if encoder_hid_dim is not None:
|
348 |
+
text_time_embedding_from_dim = encoder_hid_dim
|
349 |
+
else:
|
350 |
+
text_time_embedding_from_dim = cross_attention_dim
|
351 |
+
|
352 |
+
self.add_embedding = TextTimeEmbedding(
|
353 |
+
text_time_embedding_from_dim,
|
354 |
+
time_embed_dim,
|
355 |
+
num_heads=addition_embed_type_num_heads,
|
356 |
+
)
|
357 |
+
elif addition_embed_type == "text_image":
|
358 |
+
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
359 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
360 |
+
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
|
361 |
+
self.add_embedding = TextImageTimeEmbedding(
|
362 |
+
text_embed_dim=cross_attention_dim,
|
363 |
+
image_embed_dim=cross_attention_dim,
|
364 |
+
time_embed_dim=time_embed_dim,
|
365 |
+
)
|
366 |
+
elif addition_embed_type == "text_time":
|
367 |
+
self.add_time_proj = Timesteps(
|
368 |
+
addition_time_embed_dim, flip_sin_to_cos, freq_shift
|
369 |
+
)
|
370 |
+
self.add_embedding = TimestepEmbedding(
|
371 |
+
projection_class_embeddings_input_dim, time_embed_dim
|
372 |
+
)
|
373 |
+
elif addition_embed_type == "image":
|
374 |
+
# Kandinsky 2.2
|
375 |
+
self.add_embedding = ImageTimeEmbedding(
|
376 |
+
image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim
|
377 |
+
)
|
378 |
+
elif addition_embed_type == "image_hint":
|
379 |
+
# Kandinsky 2.2 ControlNet
|
380 |
+
self.add_embedding = ImageHintTimeEmbedding(
|
381 |
+
image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim
|
382 |
+
)
|
383 |
+
elif addition_embed_type is not None:
|
384 |
+
raise ValueError(
|
385 |
+
f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'."
|
386 |
+
)
|
387 |
+
|
388 |
+
if time_embedding_act_fn is None:
|
389 |
+
self.time_embed_act = None
|
390 |
+
else:
|
391 |
+
self.time_embed_act = get_activation(time_embedding_act_fn)
|
392 |
+
|
393 |
+
self.down_blocks = nn.ModuleList([])
|
394 |
+
self.up_blocks = nn.ModuleList([])
|
395 |
+
|
396 |
+
if isinstance(only_cross_attention, bool):
|
397 |
+
if mid_block_only_cross_attention is None:
|
398 |
+
mid_block_only_cross_attention = only_cross_attention
|
399 |
+
|
400 |
+
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
401 |
+
|
402 |
+
if mid_block_only_cross_attention is None:
|
403 |
+
mid_block_only_cross_attention = False
|
404 |
+
|
405 |
+
if isinstance(num_attention_heads, int):
|
406 |
+
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
407 |
+
|
408 |
+
if isinstance(attention_head_dim, int):
|
409 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
410 |
+
|
411 |
+
if isinstance(cross_attention_dim, int):
|
412 |
+
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
|
413 |
+
|
414 |
+
if isinstance(layers_per_block, int):
|
415 |
+
layers_per_block = [layers_per_block] * len(down_block_types)
|
416 |
+
|
417 |
+
if isinstance(transformer_layers_per_block, int):
|
418 |
+
transformer_layers_per_block = [transformer_layers_per_block] * len(
|
419 |
+
down_block_types
|
420 |
+
)
|
421 |
+
|
422 |
+
if class_embeddings_concat:
|
423 |
+
# The time embeddings are concatenated with the class embeddings. The dimension of the
|
424 |
+
# time embeddings passed to the down, middle, and up blocks is twice the dimension of the
|
425 |
+
# regular time embeddings
|
426 |
+
blocks_time_embed_dim = time_embed_dim * 2
|
427 |
+
else:
|
428 |
+
blocks_time_embed_dim = time_embed_dim
|
429 |
+
|
430 |
+
# down
|
431 |
+
output_channel = block_out_channels[0]
|
432 |
+
for i, down_block_type in enumerate(down_block_types):
|
433 |
+
input_channel = output_channel
|
434 |
+
output_channel = block_out_channels[i]
|
435 |
+
is_final_block = i == len(block_out_channels) - 1
|
436 |
+
|
437 |
+
down_block = get_down_block(
|
438 |
+
down_block_type,
|
439 |
+
num_layers=layers_per_block[i],
|
440 |
+
transformer_layers_per_block=transformer_layers_per_block[i],
|
441 |
+
in_channels=input_channel,
|
442 |
+
out_channels=output_channel,
|
443 |
+
temb_channels=blocks_time_embed_dim,
|
444 |
+
add_downsample=not is_final_block,
|
445 |
+
resnet_eps=norm_eps,
|
446 |
+
resnet_act_fn=act_fn,
|
447 |
+
resnet_groups=norm_num_groups,
|
448 |
+
cross_attention_dim=cross_attention_dim[i],
|
449 |
+
num_attention_heads=num_attention_heads[i],
|
450 |
+
downsample_padding=downsample_padding,
|
451 |
+
dual_cross_attention=dual_cross_attention,
|
452 |
+
use_linear_projection=use_linear_projection,
|
453 |
+
only_cross_attention=only_cross_attention[i],
|
454 |
+
upcast_attention=upcast_attention,
|
455 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
456 |
+
attention_type=attention_type,
|
457 |
+
resnet_skip_time_act=resnet_skip_time_act,
|
458 |
+
resnet_out_scale_factor=resnet_out_scale_factor,
|
459 |
+
cross_attention_norm=cross_attention_norm,
|
460 |
+
attention_head_dim=attention_head_dim[i]
|
461 |
+
if attention_head_dim[i] is not None
|
462 |
+
else output_channel,
|
463 |
+
dropout=dropout,
|
464 |
+
)
|
465 |
+
self.down_blocks.append(down_block)
|
466 |
+
|
467 |
+
# mid
|
468 |
+
if mid_block_type == "UNetMidBlock2DCrossAttn":
|
469 |
+
self.mid_block = UNetMidBlock2DCrossAttn(
|
470 |
+
transformer_layers_per_block=transformer_layers_per_block[-1],
|
471 |
+
in_channels=block_out_channels[-1],
|
472 |
+
temb_channels=blocks_time_embed_dim,
|
473 |
+
dropout=dropout,
|
474 |
+
resnet_eps=norm_eps,
|
475 |
+
resnet_act_fn=act_fn,
|
476 |
+
output_scale_factor=mid_block_scale_factor,
|
477 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
478 |
+
cross_attention_dim=cross_attention_dim[-1],
|
479 |
+
num_attention_heads=num_attention_heads[-1],
|
480 |
+
resnet_groups=norm_num_groups,
|
481 |
+
dual_cross_attention=dual_cross_attention,
|
482 |
+
use_linear_projection=use_linear_projection,
|
483 |
+
upcast_attention=upcast_attention,
|
484 |
+
attention_type=attention_type,
|
485 |
+
)
|
486 |
+
elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
|
487 |
+
self.mid_block = UNetMidBlock2DSimpleCrossAttn(
|
488 |
+
in_channels=block_out_channels[-1],
|
489 |
+
temb_channels=blocks_time_embed_dim,
|
490 |
+
dropout=dropout,
|
491 |
+
resnet_eps=norm_eps,
|
492 |
+
resnet_act_fn=act_fn,
|
493 |
+
output_scale_factor=mid_block_scale_factor,
|
494 |
+
cross_attention_dim=cross_attention_dim[-1],
|
495 |
+
attention_head_dim=attention_head_dim[-1],
|
496 |
+
resnet_groups=norm_num_groups,
|
497 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
498 |
+
skip_time_act=resnet_skip_time_act,
|
499 |
+
only_cross_attention=mid_block_only_cross_attention,
|
500 |
+
cross_attention_norm=cross_attention_norm,
|
501 |
+
)
|
502 |
+
elif mid_block_type == "UNetMidBlock2D":
|
503 |
+
self.mid_block = UNetMidBlock2D(
|
504 |
+
in_channels=block_out_channels[-1],
|
505 |
+
temb_channels=blocks_time_embed_dim,
|
506 |
+
dropout=dropout,
|
507 |
+
num_layers=0,
|
508 |
+
resnet_eps=norm_eps,
|
509 |
+
resnet_act_fn=act_fn,
|
510 |
+
output_scale_factor=mid_block_scale_factor,
|
511 |
+
resnet_groups=norm_num_groups,
|
512 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
513 |
+
add_attention=False,
|
514 |
+
)
|
515 |
+
elif mid_block_type is None:
|
516 |
+
self.mid_block = None
|
517 |
+
else:
|
518 |
+
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
|
519 |
+
|
520 |
+
# count how many layers upsample the images
|
521 |
+
self.num_upsamplers = 0
|
522 |
+
|
523 |
+
# up
|
524 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
525 |
+
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
526 |
+
reversed_layers_per_block = list(reversed(layers_per_block))
|
527 |
+
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
|
528 |
+
reversed_transformer_layers_per_block = (
|
529 |
+
list(reversed(transformer_layers_per_block))
|
530 |
+
if reverse_transformer_layers_per_block is None
|
531 |
+
else reverse_transformer_layers_per_block
|
532 |
+
)
|
533 |
+
only_cross_attention = list(reversed(only_cross_attention))
|
534 |
+
|
535 |
+
output_channel = reversed_block_out_channels[0]
|
536 |
+
for i, up_block_type in enumerate(up_block_types):
|
537 |
+
is_final_block = i == len(block_out_channels) - 1
|
538 |
+
|
539 |
+
prev_output_channel = output_channel
|
540 |
+
output_channel = reversed_block_out_channels[i]
|
541 |
+
input_channel = reversed_block_out_channels[
|
542 |
+
min(i + 1, len(block_out_channels) - 1)
|
543 |
+
]
|
544 |
+
|
545 |
+
# add upsample block for all BUT final layer
|
546 |
+
if not is_final_block:
|
547 |
+
add_upsample = True
|
548 |
+
self.num_upsamplers += 1
|
549 |
+
else:
|
550 |
+
add_upsample = False
|
551 |
+
|
552 |
+
up_block = get_up_block(
|
553 |
+
up_block_type,
|
554 |
+
num_layers=reversed_layers_per_block[i] + 1,
|
555 |
+
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
|
556 |
+
in_channels=input_channel,
|
557 |
+
out_channels=output_channel,
|
558 |
+
prev_output_channel=prev_output_channel,
|
559 |
+
temb_channels=blocks_time_embed_dim,
|
560 |
+
add_upsample=add_upsample,
|
561 |
+
resnet_eps=norm_eps,
|
562 |
+
resnet_act_fn=act_fn,
|
563 |
+
resolution_idx=i,
|
564 |
+
resnet_groups=norm_num_groups,
|
565 |
+
cross_attention_dim=reversed_cross_attention_dim[i],
|
566 |
+
num_attention_heads=reversed_num_attention_heads[i],
|
567 |
+
dual_cross_attention=dual_cross_attention,
|
568 |
+
use_linear_projection=use_linear_projection,
|
569 |
+
only_cross_attention=only_cross_attention[i],
|
570 |
+
upcast_attention=upcast_attention,
|
571 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
572 |
+
attention_type=attention_type,
|
573 |
+
resnet_skip_time_act=resnet_skip_time_act,
|
574 |
+
resnet_out_scale_factor=resnet_out_scale_factor,
|
575 |
+
cross_attention_norm=cross_attention_norm,
|
576 |
+
attention_head_dim=attention_head_dim[i]
|
577 |
+
if attention_head_dim[i] is not None
|
578 |
+
else output_channel,
|
579 |
+
dropout=dropout,
|
580 |
+
)
|
581 |
+
self.up_blocks.append(up_block)
|
582 |
+
prev_output_channel = output_channel
|
583 |
+
|
584 |
+
# out
|
585 |
+
if norm_num_groups is not None:
|
586 |
+
self.conv_norm_out = nn.GroupNorm(
|
587 |
+
num_channels=block_out_channels[0],
|
588 |
+
num_groups=norm_num_groups,
|
589 |
+
eps=norm_eps,
|
590 |
+
)
|
591 |
+
|
592 |
+
self.conv_act = get_activation(act_fn)
|
593 |
+
|
594 |
+
else:
|
595 |
+
self.conv_norm_out = None
|
596 |
+
self.conv_act = None
|
597 |
+
|
598 |
+
conv_out_padding = (conv_out_kernel - 1) // 2
|
599 |
+
self.conv_out = nn.Conv2d(
|
600 |
+
block_out_channels[0],
|
601 |
+
out_channels,
|
602 |
+
kernel_size=conv_out_kernel,
|
603 |
+
padding=conv_out_padding,
|
604 |
+
)
|
605 |
+
|
606 |
+
if attention_type in ["gated", "gated-text-image"]:
|
607 |
+
positive_len = 768
|
608 |
+
if isinstance(cross_attention_dim, int):
|
609 |
+
positive_len = cross_attention_dim
|
610 |
+
elif isinstance(cross_attention_dim, tuple) or isinstance(
|
611 |
+
cross_attention_dim, list
|
612 |
+
):
|
613 |
+
positive_len = cross_attention_dim[0]
|
614 |
+
|
615 |
+
feature_type = "text-only" if attention_type == "gated" else "text-image"
|
616 |
+
self.position_net = PositionNet(
|
617 |
+
positive_len=positive_len,
|
618 |
+
out_dim=cross_attention_dim,
|
619 |
+
feature_type=feature_type,
|
620 |
+
)
|
621 |
+
self.need_block_embs = need_block_embs
|
622 |
+
self.need_self_attn_block_embs = need_self_attn_block_embs
|
623 |
+
|
624 |
+
# only use referencenet soma layers, other layers set None
|
625 |
+
self.conv_norm_out = None
|
626 |
+
self.conv_act = None
|
627 |
+
self.conv_out = None
|
628 |
+
|
629 |
+
self.up_blocks[-1].attentions[-1].proj_out = None
|
630 |
+
self.up_blocks[-1].attentions[-1].transformer_blocks[-1].attn1 = None
|
631 |
+
self.up_blocks[-1].attentions[-1].transformer_blocks[-1].attn2 = None
|
632 |
+
self.up_blocks[-1].attentions[-1].transformer_blocks[-1].norm2 = None
|
633 |
+
self.up_blocks[-1].attentions[-1].transformer_blocks[-1].ff = None
|
634 |
+
self.up_blocks[-1].attentions[-1].transformer_blocks[-1].norm3 = None
|
635 |
+
if not self.need_self_attn_block_embs:
|
636 |
+
self.up_blocks = None
|
637 |
+
|
638 |
+
self.insert_spatial_self_attn_idx()
|
639 |
+
|
640 |
+
def forward(
|
641 |
+
self,
|
642 |
+
sample: torch.FloatTensor,
|
643 |
+
timestep: Union[torch.Tensor, float, int],
|
644 |
+
encoder_hidden_states: torch.Tensor,
|
645 |
+
class_labels: Optional[torch.Tensor] = None,
|
646 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
647 |
+
attention_mask: Optional[torch.Tensor] = None,
|
648 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
649 |
+
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
650 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
651 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
652 |
+
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
653 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
654 |
+
return_dict: bool = True,
|
655 |
+
# update new paramestes start
|
656 |
+
num_frames: int = None,
|
657 |
+
return_ndim: int = 5,
|
658 |
+
# update new paramestes end
|
659 |
+
) -> Union[UNet2DConditionOutput, Tuple]:
|
660 |
+
r"""
|
661 |
+
The [`UNet2DConditionModel`] forward method.
|
662 |
+
|
663 |
+
Args:
|
664 |
+
sample (`torch.FloatTensor`):
|
665 |
+
The noisy input tensor with the following shape `(batch, channel, height, width)`.
|
666 |
+
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
667 |
+
encoder_hidden_states (`torch.FloatTensor`):
|
668 |
+
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
|
669 |
+
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
670 |
+
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
671 |
+
timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
|
672 |
+
Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
|
673 |
+
through the `self.time_embedding` layer to obtain the timestep embeddings.
|
674 |
+
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
675 |
+
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
676 |
+
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
677 |
+
negative values to the attention scores corresponding to "discard" tokens.
|
678 |
+
cross_attention_kwargs (`dict`, *optional*):
|
679 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
680 |
+
`self.processor` in
|
681 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
682 |
+
added_cond_kwargs: (`dict`, *optional*):
|
683 |
+
A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
|
684 |
+
are passed along to the UNet blocks.
|
685 |
+
down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
|
686 |
+
A tuple of tensors that if specified are added to the residuals of down unet blocks.
|
687 |
+
mid_block_additional_residual: (`torch.Tensor`, *optional*):
|
688 |
+
A tensor that if specified is added to the residual of the middle unet block.
|
689 |
+
encoder_attention_mask (`torch.Tensor`):
|
690 |
+
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
|
691 |
+
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
|
692 |
+
which adds large negative values to the attention scores corresponding to "discard" tokens.
|
693 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
694 |
+
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
695 |
+
tuple.
|
696 |
+
cross_attention_kwargs (`dict`, *optional*):
|
697 |
+
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
|
698 |
+
added_cond_kwargs: (`dict`, *optional*):
|
699 |
+
A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
|
700 |
+
are passed along to the UNet blocks.
|
701 |
+
down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
|
702 |
+
additional residuals to be added to UNet long skip connections from down blocks to up blocks for
|
703 |
+
example from ControlNet side model(s)
|
704 |
+
mid_block_additional_residual (`torch.Tensor`, *optional*):
|
705 |
+
additional residual to be added to UNet mid block output, for example from ControlNet side model
|
706 |
+
down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
|
707 |
+
additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
|
708 |
+
|
709 |
+
Returns:
|
710 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
711 |
+
If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
|
712 |
+
a `tuple` is returned where the first element is the sample tensor.
|
713 |
+
"""
|
714 |
+
|
715 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
716 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
|
717 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
718 |
+
# on the fly if necessary.
|
719 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
720 |
+
|
721 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
722 |
+
forward_upsample_size = False
|
723 |
+
upsample_size = None
|
724 |
+
|
725 |
+
for dim in sample.shape[-2:]:
|
726 |
+
if dim % default_overall_up_factor != 0:
|
727 |
+
# Forward upsample size to force interpolation output size.
|
728 |
+
forward_upsample_size = True
|
729 |
+
break
|
730 |
+
|
731 |
+
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
|
732 |
+
# expects mask of shape:
|
733 |
+
# [batch, key_tokens]
|
734 |
+
# adds singleton query_tokens dimension:
|
735 |
+
# [batch, 1, key_tokens]
|
736 |
+
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
737 |
+
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
738 |
+
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
739 |
+
if attention_mask is not None:
|
740 |
+
# assume that mask is expressed as:
|
741 |
+
# (1 = keep, 0 = discard)
|
742 |
+
# convert mask into a bias that can be added to attention scores:
|
743 |
+
# (keep = +0, discard = -10000.0)
|
744 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
745 |
+
attention_mask = attention_mask.unsqueeze(1)
|
746 |
+
|
747 |
+
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
748 |
+
if encoder_attention_mask is not None:
|
749 |
+
encoder_attention_mask = (
|
750 |
+
1 - encoder_attention_mask.to(sample.dtype)
|
751 |
+
) * -10000.0
|
752 |
+
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
753 |
+
|
754 |
+
# 0. center input if necessary
|
755 |
+
if self.config.center_input_sample:
|
756 |
+
sample = 2 * sample - 1.0
|
757 |
+
|
758 |
+
# 1. time
|
759 |
+
timesteps = timestep
|
760 |
+
if not torch.is_tensor(timesteps):
|
761 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
762 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
763 |
+
is_mps = sample.device.type == "mps"
|
764 |
+
if isinstance(timestep, float):
|
765 |
+
dtype = torch.float32 if is_mps else torch.float64
|
766 |
+
else:
|
767 |
+
dtype = torch.int32 if is_mps else torch.int64
|
768 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
769 |
+
elif len(timesteps.shape) == 0:
|
770 |
+
timesteps = timesteps[None].to(sample.device)
|
771 |
+
|
772 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
773 |
+
timesteps = timesteps.expand(sample.shape[0])
|
774 |
+
|
775 |
+
t_emb = self.time_proj(timesteps)
|
776 |
+
|
777 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
778 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
779 |
+
# there might be better ways to encapsulate this.
|
780 |
+
t_emb = t_emb.to(dtype=sample.dtype)
|
781 |
+
|
782 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
783 |
+
aug_emb = None
|
784 |
+
|
785 |
+
if self.class_embedding is not None:
|
786 |
+
if class_labels is None:
|
787 |
+
raise ValueError(
|
788 |
+
"class_labels should be provided when num_class_embeds > 0"
|
789 |
+
)
|
790 |
+
|
791 |
+
if self.config.class_embed_type == "timestep":
|
792 |
+
class_labels = self.time_proj(class_labels)
|
793 |
+
|
794 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
795 |
+
# there might be better ways to encapsulate this.
|
796 |
+
class_labels = class_labels.to(dtype=sample.dtype)
|
797 |
+
|
798 |
+
class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
|
799 |
+
|
800 |
+
if self.config.class_embeddings_concat:
|
801 |
+
emb = torch.cat([emb, class_emb], dim=-1)
|
802 |
+
else:
|
803 |
+
emb = emb + class_emb
|
804 |
+
|
805 |
+
if self.config.addition_embed_type == "text":
|
806 |
+
aug_emb = self.add_embedding(encoder_hidden_states)
|
807 |
+
elif self.config.addition_embed_type == "text_image":
|
808 |
+
# Kandinsky 2.1 - style
|
809 |
+
if "image_embeds" not in added_cond_kwargs:
|
810 |
+
raise ValueError(
|
811 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
812 |
+
)
|
813 |
+
|
814 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
815 |
+
text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
|
816 |
+
aug_emb = self.add_embedding(text_embs, image_embs)
|
817 |
+
elif self.config.addition_embed_type == "text_time":
|
818 |
+
# SDXL - style
|
819 |
+
if "text_embeds" not in added_cond_kwargs:
|
820 |
+
raise ValueError(
|
821 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
822 |
+
)
|
823 |
+
text_embeds = added_cond_kwargs.get("text_embeds")
|
824 |
+
if "time_ids" not in added_cond_kwargs:
|
825 |
+
raise ValueError(
|
826 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
827 |
+
)
|
828 |
+
time_ids = added_cond_kwargs.get("time_ids")
|
829 |
+
time_embeds = self.add_time_proj(time_ids.flatten())
|
830 |
+
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
831 |
+
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
832 |
+
add_embeds = add_embeds.to(emb.dtype)
|
833 |
+
aug_emb = self.add_embedding(add_embeds)
|
834 |
+
elif self.config.addition_embed_type == "image":
|
835 |
+
# Kandinsky 2.2 - style
|
836 |
+
if "image_embeds" not in added_cond_kwargs:
|
837 |
+
raise ValueError(
|
838 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
839 |
+
)
|
840 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
841 |
+
aug_emb = self.add_embedding(image_embs)
|
842 |
+
elif self.config.addition_embed_type == "image_hint":
|
843 |
+
# Kandinsky 2.2 - style
|
844 |
+
if (
|
845 |
+
"image_embeds" not in added_cond_kwargs
|
846 |
+
or "hint" not in added_cond_kwargs
|
847 |
+
):
|
848 |
+
raise ValueError(
|
849 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
|
850 |
+
)
|
851 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
852 |
+
hint = added_cond_kwargs.get("hint")
|
853 |
+
aug_emb, hint = self.add_embedding(image_embs, hint)
|
854 |
+
sample = torch.cat([sample, hint], dim=1)
|
855 |
+
|
856 |
+
emb = emb + aug_emb if aug_emb is not None else emb
|
857 |
+
|
858 |
+
if self.time_embed_act is not None:
|
859 |
+
emb = self.time_embed_act(emb)
|
860 |
+
|
861 |
+
if (
|
862 |
+
self.encoder_hid_proj is not None
|
863 |
+
and self.config.encoder_hid_dim_type == "text_proj"
|
864 |
+
):
|
865 |
+
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
|
866 |
+
elif (
|
867 |
+
self.encoder_hid_proj is not None
|
868 |
+
and self.config.encoder_hid_dim_type == "text_image_proj"
|
869 |
+
):
|
870 |
+
# Kadinsky 2.1 - style
|
871 |
+
if "image_embeds" not in added_cond_kwargs:
|
872 |
+
raise ValueError(
|
873 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
874 |
+
)
|
875 |
+
|
876 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
877 |
+
encoder_hidden_states = self.encoder_hid_proj(
|
878 |
+
encoder_hidden_states, image_embeds
|
879 |
+
)
|
880 |
+
elif (
|
881 |
+
self.encoder_hid_proj is not None
|
882 |
+
and self.config.encoder_hid_dim_type == "image_proj"
|
883 |
+
):
|
884 |
+
# Kandinsky 2.2 - style
|
885 |
+
if "image_embeds" not in added_cond_kwargs:
|
886 |
+
raise ValueError(
|
887 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
888 |
+
)
|
889 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
890 |
+
encoder_hidden_states = self.encoder_hid_proj(image_embeds)
|
891 |
+
elif (
|
892 |
+
self.encoder_hid_proj is not None
|
893 |
+
and self.config.encoder_hid_dim_type == "ip_image_proj"
|
894 |
+
):
|
895 |
+
if "image_embeds" not in added_cond_kwargs:
|
896 |
+
raise ValueError(
|
897 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
898 |
+
)
|
899 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
900 |
+
image_embeds = self.encoder_hid_proj(image_embeds).to(
|
901 |
+
encoder_hidden_states.dtype
|
902 |
+
)
|
903 |
+
encoder_hidden_states = torch.cat(
|
904 |
+
[encoder_hidden_states, image_embeds], dim=1
|
905 |
+
)
|
906 |
+
|
907 |
+
# need_self_attn_block_embs
|
908 |
+
# 初始化
|
909 |
+
# 或在unet中运算中会不断 append self_attn_blocks_embs,用完需要清理,
|
910 |
+
if self.need_self_attn_block_embs:
|
911 |
+
self_attn_block_embs = [None] * self.self_attn_num
|
912 |
+
else:
|
913 |
+
self_attn_block_embs = None
|
914 |
+
# 2. pre-process
|
915 |
+
sample = self.conv_in(sample)
|
916 |
+
if self.print_idx == 0:
|
917 |
+
logger.debug(f"after conv in sample={sample.mean()}")
|
918 |
+
# 2.5 GLIGEN position net
|
919 |
+
if (
|
920 |
+
cross_attention_kwargs is not None
|
921 |
+
and cross_attention_kwargs.get("gligen", None) is not None
|
922 |
+
):
|
923 |
+
cross_attention_kwargs = cross_attention_kwargs.copy()
|
924 |
+
gligen_args = cross_attention_kwargs.pop("gligen")
|
925 |
+
cross_attention_kwargs["gligen"] = {
|
926 |
+
"objs": self.position_net(**gligen_args)
|
927 |
+
}
|
928 |
+
|
929 |
+
# 3. down
|
930 |
+
lora_scale = (
|
931 |
+
cross_attention_kwargs.get("scale", 1.0)
|
932 |
+
if cross_attention_kwargs is not None
|
933 |
+
else 1.0
|
934 |
+
)
|
935 |
+
if USE_PEFT_BACKEND:
|
936 |
+
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
937 |
+
scale_lora_layers(self, lora_scale)
|
938 |
+
|
939 |
+
is_controlnet = (
|
940 |
+
mid_block_additional_residual is not None
|
941 |
+
and down_block_additional_residuals is not None
|
942 |
+
)
|
943 |
+
# using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
|
944 |
+
is_adapter = down_intrablock_additional_residuals is not None
|
945 |
+
# maintain backward compatibility for legacy usage, where
|
946 |
+
# T2I-Adapter and ControlNet both use down_block_additional_residuals arg
|
947 |
+
# but can only use one or the other
|
948 |
+
if (
|
949 |
+
not is_adapter
|
950 |
+
and mid_block_additional_residual is None
|
951 |
+
and down_block_additional_residuals is not None
|
952 |
+
):
|
953 |
+
deprecate(
|
954 |
+
"T2I should not use down_block_additional_residuals",
|
955 |
+
"1.3.0",
|
956 |
+
"Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
|
957 |
+
and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
|
958 |
+
for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
|
959 |
+
standard_warn=False,
|
960 |
+
)
|
961 |
+
down_intrablock_additional_residuals = down_block_additional_residuals
|
962 |
+
is_adapter = True
|
963 |
+
|
964 |
+
down_block_res_samples = (sample,)
|
965 |
+
for i_downsample_block, downsample_block in enumerate(self.down_blocks):
|
966 |
+
if (
|
967 |
+
hasattr(downsample_block, "has_cross_attention")
|
968 |
+
and downsample_block.has_cross_attention
|
969 |
+
):
|
970 |
+
# For t2i-adapter CrossAttnDownBlock2D
|
971 |
+
additional_residuals = {}
|
972 |
+
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
973 |
+
additional_residuals[
|
974 |
+
"additional_residuals"
|
975 |
+
] = down_intrablock_additional_residuals.pop(0)
|
976 |
+
if self.print_idx == 0:
|
977 |
+
logger.debug(
|
978 |
+
f"downsample_block {i_downsample_block} sample={sample.mean()}"
|
979 |
+
)
|
980 |
+
sample, res_samples = downsample_block(
|
981 |
+
hidden_states=sample,
|
982 |
+
temb=emb,
|
983 |
+
encoder_hidden_states=encoder_hidden_states,
|
984 |
+
attention_mask=attention_mask,
|
985 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
986 |
+
encoder_attention_mask=encoder_attention_mask,
|
987 |
+
**additional_residuals,
|
988 |
+
self_attn_block_embs=self_attn_block_embs,
|
989 |
+
)
|
990 |
+
else:
|
991 |
+
sample, res_samples = downsample_block(
|
992 |
+
hidden_states=sample,
|
993 |
+
temb=emb,
|
994 |
+
scale=lora_scale,
|
995 |
+
self_attn_block_embs=self_attn_block_embs,
|
996 |
+
)
|
997 |
+
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
998 |
+
sample += down_intrablock_additional_residuals.pop(0)
|
999 |
+
|
1000 |
+
down_block_res_samples += res_samples
|
1001 |
+
|
1002 |
+
if is_controlnet:
|
1003 |
+
new_down_block_res_samples = ()
|
1004 |
+
|
1005 |
+
for down_block_res_sample, down_block_additional_residual in zip(
|
1006 |
+
down_block_res_samples, down_block_additional_residuals
|
1007 |
+
):
|
1008 |
+
down_block_res_sample = (
|
1009 |
+
down_block_res_sample + down_block_additional_residual
|
1010 |
+
)
|
1011 |
+
new_down_block_res_samples = new_down_block_res_samples + (
|
1012 |
+
down_block_res_sample,
|
1013 |
+
)
|
1014 |
+
|
1015 |
+
down_block_res_samples = new_down_block_res_samples
|
1016 |
+
|
1017 |
+
# update code start
|
1018 |
+
def reshape_return_emb(tmp_emb):
|
1019 |
+
if return_ndim == 4:
|
1020 |
+
return tmp_emb
|
1021 |
+
elif return_ndim == 5:
|
1022 |
+
return rearrange(tmp_emb, "(b t) c h w-> b c t h w", t=num_frames)
|
1023 |
+
else:
|
1024 |
+
raise ValueError(
|
1025 |
+
f"reshape_emb only support 4, 5 but given {return_ndim}"
|
1026 |
+
)
|
1027 |
+
|
1028 |
+
if self.need_block_embs:
|
1029 |
+
return_down_block_res_samples = [
|
1030 |
+
reshape_return_emb(tmp_emb) for tmp_emb in down_block_res_samples
|
1031 |
+
]
|
1032 |
+
else:
|
1033 |
+
return_down_block_res_samples = None
|
1034 |
+
# update code end
|
1035 |
+
|
1036 |
+
# 4. mid
|
1037 |
+
if self.mid_block is not None:
|
1038 |
+
if (
|
1039 |
+
hasattr(self.mid_block, "has_cross_attention")
|
1040 |
+
and self.mid_block.has_cross_attention
|
1041 |
+
):
|
1042 |
+
sample = self.mid_block(
|
1043 |
+
sample,
|
1044 |
+
emb,
|
1045 |
+
encoder_hidden_states=encoder_hidden_states,
|
1046 |
+
attention_mask=attention_mask,
|
1047 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1048 |
+
encoder_attention_mask=encoder_attention_mask,
|
1049 |
+
self_attn_block_embs=self_attn_block_embs,
|
1050 |
+
)
|
1051 |
+
else:
|
1052 |
+
sample = self.mid_block(sample, emb)
|
1053 |
+
|
1054 |
+
# To support T2I-Adapter-XL
|
1055 |
+
if (
|
1056 |
+
is_adapter
|
1057 |
+
and len(down_intrablock_additional_residuals) > 0
|
1058 |
+
and sample.shape == down_intrablock_additional_residuals[0].shape
|
1059 |
+
):
|
1060 |
+
sample += down_intrablock_additional_residuals.pop(0)
|
1061 |
+
|
1062 |
+
if is_controlnet:
|
1063 |
+
sample = sample + mid_block_additional_residual
|
1064 |
+
|
1065 |
+
if self.need_block_embs:
|
1066 |
+
return_mid_block_res_samples = reshape_return_emb(sample)
|
1067 |
+
logger.debug(
|
1068 |
+
f"return_mid_block_res_samples, is_leaf={return_mid_block_res_samples.is_leaf}, requires_grad={return_mid_block_res_samples.requires_grad}"
|
1069 |
+
)
|
1070 |
+
else:
|
1071 |
+
return_mid_block_res_samples = None
|
1072 |
+
|
1073 |
+
if self.up_blocks is not None:
|
1074 |
+
# update code end
|
1075 |
+
|
1076 |
+
# 5. up
|
1077 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
1078 |
+
is_final_block = i == len(self.up_blocks) - 1
|
1079 |
+
|
1080 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
1081 |
+
down_block_res_samples = down_block_res_samples[
|
1082 |
+
: -len(upsample_block.resnets)
|
1083 |
+
]
|
1084 |
+
|
1085 |
+
# if we have not reached the final block and need to forward the
|
1086 |
+
# upsample size, we do it here
|
1087 |
+
if not is_final_block and forward_upsample_size:
|
1088 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
1089 |
+
|
1090 |
+
if (
|
1091 |
+
hasattr(upsample_block, "has_cross_attention")
|
1092 |
+
and upsample_block.has_cross_attention
|
1093 |
+
):
|
1094 |
+
sample = upsample_block(
|
1095 |
+
hidden_states=sample,
|
1096 |
+
temb=emb,
|
1097 |
+
res_hidden_states_tuple=res_samples,
|
1098 |
+
encoder_hidden_states=encoder_hidden_states,
|
1099 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1100 |
+
upsample_size=upsample_size,
|
1101 |
+
attention_mask=attention_mask,
|
1102 |
+
encoder_attention_mask=encoder_attention_mask,
|
1103 |
+
self_attn_block_embs=self_attn_block_embs,
|
1104 |
+
)
|
1105 |
+
else:
|
1106 |
+
sample = upsample_block(
|
1107 |
+
hidden_states=sample,
|
1108 |
+
temb=emb,
|
1109 |
+
res_hidden_states_tuple=res_samples,
|
1110 |
+
upsample_size=upsample_size,
|
1111 |
+
scale=lora_scale,
|
1112 |
+
self_attn_block_embs=self_attn_block_embs,
|
1113 |
+
)
|
1114 |
+
|
1115 |
+
# update code start
|
1116 |
+
if self.need_block_embs or self.need_self_attn_block_embs:
|
1117 |
+
if self_attn_block_embs is not None:
|
1118 |
+
self_attn_block_embs = [
|
1119 |
+
reshape_return_emb(tmp_emb=tmp_emb)
|
1120 |
+
for tmp_emb in self_attn_block_embs
|
1121 |
+
]
|
1122 |
+
self.print_idx += 1
|
1123 |
+
return (
|
1124 |
+
return_down_block_res_samples,
|
1125 |
+
return_mid_block_res_samples,
|
1126 |
+
self_attn_block_embs,
|
1127 |
+
)
|
1128 |
+
|
1129 |
+
if not self.need_block_embs and not self.need_self_attn_block_embs:
|
1130 |
+
# 6. post-process
|
1131 |
+
if self.conv_norm_out:
|
1132 |
+
sample = self.conv_norm_out(sample)
|
1133 |
+
sample = self.conv_act(sample)
|
1134 |
+
sample = self.conv_out(sample)
|
1135 |
+
|
1136 |
+
if USE_PEFT_BACKEND:
|
1137 |
+
# remove `lora_scale` from each PEFT layer
|
1138 |
+
unscale_lora_layers(self, lora_scale)
|
1139 |
+
self.print_idx += 1
|
1140 |
+
if not return_dict:
|
1141 |
+
return (sample,)
|
1142 |
+
|
1143 |
+
return UNet2DConditionOutput(sample=sample)
|
1144 |
+
|
1145 |
+
def insert_spatial_self_attn_idx(self):
|
1146 |
+
attns, basic_transformers = self.spatial_self_attns
|
1147 |
+
self.self_attn_num = len(attns)
|
1148 |
+
for i, (name, layer) in enumerate(attns):
|
1149 |
+
logger.debug(f"{self.__class__.__name__}, {i}, {name}, {type(layer)}")
|
1150 |
+
if layer is not None:
|
1151 |
+
layer.spatial_self_attn_idx = i
|
1152 |
+
for i, (name, layer) in enumerate(basic_transformers):
|
1153 |
+
logger.debug(f"{self.__class__.__name__}, {i}, {name}, {type(layer)}")
|
1154 |
+
if layer is not None:
|
1155 |
+
layer.spatial_self_attn_idx = i
|
1156 |
+
|
1157 |
+
@property
|
1158 |
+
def spatial_self_attns(
|
1159 |
+
self,
|
1160 |
+
) -> List[Tuple[str, Attention]]:
|
1161 |
+
attns, spatial_transformers = self.get_self_attns(
|
1162 |
+
include="attentions", exclude="temp_attentions"
|
1163 |
+
)
|
1164 |
+
attns = sorted(attns)
|
1165 |
+
spatial_transformers = sorted(spatial_transformers)
|
1166 |
+
return attns, spatial_transformers
|
1167 |
+
|
1168 |
+
def get_self_attns(
|
1169 |
+
self, include: str = None, exclude: str = None
|
1170 |
+
) -> List[Tuple[str, Attention]]:
|
1171 |
+
r"""
|
1172 |
+
Returns:
|
1173 |
+
`dict` of attention attns: A dictionary containing all attention attns used in the model with
|
1174 |
+
indexed by its weight name.
|
1175 |
+
"""
|
1176 |
+
# set recursively
|
1177 |
+
attns = []
|
1178 |
+
spatial_transformers = []
|
1179 |
+
|
1180 |
+
def fn_recursive_add_attns(
|
1181 |
+
name: str,
|
1182 |
+
module: torch.nn.Module,
|
1183 |
+
attns: List[Tuple[str, Attention]],
|
1184 |
+
spatial_transformers: List[Tuple[str, BasicTransformerBlock]],
|
1185 |
+
):
|
1186 |
+
is_target = False
|
1187 |
+
if isinstance(module, BasicTransformerBlock) and hasattr(module, "attn1"):
|
1188 |
+
is_target = True
|
1189 |
+
if include is not None:
|
1190 |
+
is_target = include in name
|
1191 |
+
if exclude is not None:
|
1192 |
+
is_target = exclude not in name
|
1193 |
+
if is_target:
|
1194 |
+
attns.append([f"{name}.attn1", module.attn1])
|
1195 |
+
spatial_transformers.append([f"{name}", module])
|
1196 |
+
for sub_name, child in module.named_children():
|
1197 |
+
fn_recursive_add_attns(
|
1198 |
+
f"{name}.{sub_name}", child, attns, spatial_transformers
|
1199 |
+
)
|
1200 |
+
|
1201 |
+
return attns
|
1202 |
+
|
1203 |
+
for name, module in self.named_children():
|
1204 |
+
fn_recursive_add_attns(name, module, attns, spatial_transformers)
|
1205 |
+
|
1206 |
+
return attns, spatial_transformers
|
1207 |
+
|
1208 |
+
|
1209 |
+
class ReferenceNet3D(UNet3DConditionModel):
|
1210 |
+
"""继承 UNet3DConditionModel, 用于提取中间emb用于后续作用。
|
1211 |
+
Inherit Unet3DConditionModel, used to extract the middle emb for subsequent actions.
|
1212 |
+
Args:
|
1213 |
+
UNet3DConditionModel (_type_): _description_
|
1214 |
+
"""
|
1215 |
+
|
1216 |
+
pass
|
musev/models/referencenet_loader.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
from typing import Any, Callable, Dict, Iterable, Union
|
3 |
+
import PIL
|
4 |
+
import cv2
|
5 |
+
import torch
|
6 |
+
import argparse
|
7 |
+
import datetime
|
8 |
+
import logging
|
9 |
+
import inspect
|
10 |
+
import math
|
11 |
+
import os
|
12 |
+
import shutil
|
13 |
+
from typing import Dict, List, Optional, Tuple
|
14 |
+
from pprint import pprint
|
15 |
+
from collections import OrderedDict
|
16 |
+
from dataclasses import dataclass
|
17 |
+
import gc
|
18 |
+
import time
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
from omegaconf import OmegaConf
|
22 |
+
from omegaconf import SCMode
|
23 |
+
import torch
|
24 |
+
from torch import nn
|
25 |
+
import torch.nn.functional as F
|
26 |
+
import torch.utils.checkpoint
|
27 |
+
from einops import rearrange, repeat
|
28 |
+
import pandas as pd
|
29 |
+
import h5py
|
30 |
+
from diffusers.models.modeling_utils import load_state_dict
|
31 |
+
from diffusers.utils import (
|
32 |
+
logging,
|
33 |
+
)
|
34 |
+
from diffusers.utils.import_utils import is_xformers_available
|
35 |
+
|
36 |
+
from .referencenet import ReferenceNet2D
|
37 |
+
from .unet_loader import update_unet_with_sd
|
38 |
+
|
39 |
+
|
40 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
41 |
+
|
42 |
+
|
43 |
+
def load_referencenet(
|
44 |
+
sd_referencenet_model: Tuple[str, nn.Module],
|
45 |
+
sd_model: nn.Module = None,
|
46 |
+
need_self_attn_block_embs: bool = False,
|
47 |
+
need_block_embs: bool = False,
|
48 |
+
dtype: torch.dtype = torch.float16,
|
49 |
+
cross_attention_dim: int = 768,
|
50 |
+
subfolder: str = "unet",
|
51 |
+
):
|
52 |
+
"""
|
53 |
+
Loads the ReferenceNet model.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
sd_referencenet_model (Tuple[str, nn.Module] or str): The pretrained ReferenceNet model or the path to the model.
|
57 |
+
sd_model (nn.Module, optional): The sd_model to update the ReferenceNet with. Defaults to None.
|
58 |
+
need_self_attn_block_embs (bool, optional): Whether to compute self-attention block embeddings. Defaults to False.
|
59 |
+
need_block_embs (bool, optional): Whether to compute block embeddings. Defaults to False.
|
60 |
+
dtype (torch.dtype, optional): The data type of the tensors. Defaults to torch.float16.
|
61 |
+
cross_attention_dim (int, optional): The dimension of the cross-attention. Defaults to 768.
|
62 |
+
subfolder (str, optional): The subfolder of the model. Defaults to "unet".
|
63 |
+
|
64 |
+
Returns:
|
65 |
+
nn.Module: The loaded ReferenceNet model.
|
66 |
+
"""
|
67 |
+
|
68 |
+
if isinstance(sd_referencenet_model, str):
|
69 |
+
referencenet = ReferenceNet2D.from_pretrained(
|
70 |
+
sd_referencenet_model,
|
71 |
+
subfolder=subfolder,
|
72 |
+
need_self_attn_block_embs=need_self_attn_block_embs,
|
73 |
+
need_block_embs=need_block_embs,
|
74 |
+
torch_dtype=dtype,
|
75 |
+
cross_attention_dim=cross_attention_dim,
|
76 |
+
)
|
77 |
+
elif isinstance(sd_referencenet_model, nn.Module):
|
78 |
+
referencenet = sd_referencenet_model
|
79 |
+
if sd_model is not None:
|
80 |
+
referencenet = update_unet_with_sd(referencenet, sd_model)
|
81 |
+
return referencenet
|
82 |
+
|
83 |
+
|
84 |
+
def load_referencenet_by_name(
|
85 |
+
model_name: str,
|
86 |
+
sd_referencenet_model: Tuple[str, nn.Module],
|
87 |
+
sd_model: nn.Module = None,
|
88 |
+
cross_attention_dim: int = 768,
|
89 |
+
dtype: torch.dtype = torch.float16,
|
90 |
+
) -> nn.Module:
|
91 |
+
"""通过模型名字 初始化 referencenet,载入预训练参数,
|
92 |
+
如希望后续通过简单名字就可以使用预训练模型,需要在这里完成定义
|
93 |
+
init referencenet with model_name.
|
94 |
+
if you want to use pretrained model with simple name, you need to define it here.
|
95 |
+
Args:
|
96 |
+
model_name (str): _description_
|
97 |
+
sd_unet_model (Tuple[str, nn.Module]): _description_
|
98 |
+
sd_model (Tuple[str, nn.Module]): _description_
|
99 |
+
cross_attention_dim (int, optional): _description_. Defaults to 768.
|
100 |
+
dtype (torch.dtype, optional): _description_. Defaults to torch.float16.
|
101 |
+
|
102 |
+
Raises:
|
103 |
+
ValueError: _description_
|
104 |
+
|
105 |
+
Returns:
|
106 |
+
nn.Module: _description_
|
107 |
+
"""
|
108 |
+
if model_name in [
|
109 |
+
"musev_referencenet",
|
110 |
+
]:
|
111 |
+
unet = load_referencenet(
|
112 |
+
sd_referencenet_model=sd_referencenet_model,
|
113 |
+
sd_model=sd_model,
|
114 |
+
cross_attention_dim=cross_attention_dim,
|
115 |
+
dtype=dtype,
|
116 |
+
need_self_attn_block_embs=False,
|
117 |
+
need_block_embs=True,
|
118 |
+
subfolder="referencenet",
|
119 |
+
)
|
120 |
+
else:
|
121 |
+
raise ValueError(
|
122 |
+
f"unsupport model_name={model_name}, only support ReferenceNet_V0_block13, ReferenceNet_V1_block13, ReferenceNet_V2_block13, ReferenceNet_V0_sefattn16"
|
123 |
+
)
|
124 |
+
return unet
|
musev/models/resnet.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
# `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
# Adapted from https://github.com/huggingface/diffusers/blob/v0.16.1/src/diffusers/models/resnet.py
|
17 |
+
from __future__ import annotations
|
18 |
+
|
19 |
+
from functools import partial
|
20 |
+
from typing import Optional
|
21 |
+
|
22 |
+
import torch
|
23 |
+
import torch.nn as nn
|
24 |
+
import torch.nn.functional as F
|
25 |
+
from einops import rearrange, repeat
|
26 |
+
|
27 |
+
from diffusers.models.resnet import TemporalConvLayer as DiffusersTemporalConvLayer
|
28 |
+
from ..data.data_util import batch_index_fill, batch_index_select
|
29 |
+
from . import Model_Register
|
30 |
+
|
31 |
+
|
32 |
+
@Model_Register.register
|
33 |
+
class TemporalConvLayer(nn.Module):
|
34 |
+
"""
|
35 |
+
Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from:
|
36 |
+
https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
|
37 |
+
"""
|
38 |
+
|
39 |
+
def __init__(
|
40 |
+
self,
|
41 |
+
in_dim,
|
42 |
+
out_dim=None,
|
43 |
+
dropout=0.0,
|
44 |
+
keep_content_condition: bool = False,
|
45 |
+
femb_channels: Optional[int] = None,
|
46 |
+
need_temporal_weight: bool = True,
|
47 |
+
):
|
48 |
+
super().__init__()
|
49 |
+
out_dim = out_dim or in_dim
|
50 |
+
self.in_dim = in_dim
|
51 |
+
self.out_dim = out_dim
|
52 |
+
self.keep_content_condition = keep_content_condition
|
53 |
+
self.femb_channels = femb_channels
|
54 |
+
self.need_temporal_weight = need_temporal_weight
|
55 |
+
# conv layers
|
56 |
+
self.conv1 = nn.Sequential(
|
57 |
+
nn.GroupNorm(32, in_dim),
|
58 |
+
nn.SiLU(),
|
59 |
+
nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)),
|
60 |
+
)
|
61 |
+
self.conv2 = nn.Sequential(
|
62 |
+
nn.GroupNorm(32, out_dim),
|
63 |
+
nn.SiLU(),
|
64 |
+
nn.Dropout(dropout),
|
65 |
+
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
|
66 |
+
)
|
67 |
+
self.conv3 = nn.Sequential(
|
68 |
+
nn.GroupNorm(32, out_dim),
|
69 |
+
nn.SiLU(),
|
70 |
+
nn.Dropout(dropout),
|
71 |
+
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
|
72 |
+
)
|
73 |
+
self.conv4 = nn.Sequential(
|
74 |
+
nn.GroupNorm(32, out_dim),
|
75 |
+
nn.SiLU(),
|
76 |
+
nn.Dropout(dropout),
|
77 |
+
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
|
78 |
+
)
|
79 |
+
|
80 |
+
# zero out the last layer params,so the conv block is identity
|
81 |
+
# nn.init.zeros_(self.conv4[-1].weight)
|
82 |
+
# nn.init.zeros_(self.conv4[-1].bias)
|
83 |
+
self.temporal_weight = nn.Parameter(
|
84 |
+
torch.tensor(
|
85 |
+
[
|
86 |
+
1e-5,
|
87 |
+
]
|
88 |
+
)
|
89 |
+
) # initialize parameter with 0
|
90 |
+
# zero out the last layer params,so the conv block is identity
|
91 |
+
nn.init.zeros_(self.conv4[-1].weight)
|
92 |
+
nn.init.zeros_(self.conv4[-1].bias)
|
93 |
+
self.skip_temporal_layers = False # Whether to skip temporal layer
|
94 |
+
|
95 |
+
def forward(
|
96 |
+
self,
|
97 |
+
hidden_states,
|
98 |
+
num_frames=1,
|
99 |
+
sample_index: torch.LongTensor = None,
|
100 |
+
vision_conditon_frames_sample_index: torch.LongTensor = None,
|
101 |
+
femb: torch.Tensor = None,
|
102 |
+
):
|
103 |
+
if self.skip_temporal_layers is True:
|
104 |
+
return hidden_states
|
105 |
+
hidden_states_dtype = hidden_states.dtype
|
106 |
+
hidden_states = rearrange(
|
107 |
+
hidden_states, "(b t) c h w -> b c t h w", t=num_frames
|
108 |
+
)
|
109 |
+
identity = hidden_states
|
110 |
+
hidden_states = self.conv1(hidden_states)
|
111 |
+
hidden_states = self.conv2(hidden_states)
|
112 |
+
hidden_states = self.conv3(hidden_states)
|
113 |
+
hidden_states = self.conv4(hidden_states)
|
114 |
+
# 保留condition对应的frames,便于保持前序内容帧,提升一致性
|
115 |
+
if self.keep_content_condition:
|
116 |
+
mask = torch.ones_like(hidden_states, device=hidden_states.device)
|
117 |
+
mask = batch_index_fill(
|
118 |
+
mask, dim=2, index=vision_conditon_frames_sample_index, value=0
|
119 |
+
)
|
120 |
+
if self.need_temporal_weight:
|
121 |
+
hidden_states = (
|
122 |
+
identity + torch.abs(self.temporal_weight) * mask * hidden_states
|
123 |
+
)
|
124 |
+
else:
|
125 |
+
hidden_states = identity + mask * hidden_states
|
126 |
+
else:
|
127 |
+
if self.need_temporal_weight:
|
128 |
+
hidden_states = (
|
129 |
+
identity + torch.abs(self.temporal_weight) * hidden_states
|
130 |
+
)
|
131 |
+
else:
|
132 |
+
hidden_states = identity + hidden_states
|
133 |
+
hidden_states = rearrange(hidden_states, " b c t h w -> (b t) c h w")
|
134 |
+
hidden_states = hidden_states.to(dtype=hidden_states_dtype)
|
135 |
+
return hidden_states
|
musev/models/super_model.py
ADDED
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import logging
|
4 |
+
|
5 |
+
from typing import Any, Dict, Tuple, Union, Optional
|
6 |
+
from einops import rearrange, repeat
|
7 |
+
from torch import nn
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
11 |
+
from diffusers.models.modeling_utils import ModelMixin, load_state_dict
|
12 |
+
|
13 |
+
from ..data.data_util import align_repeat_tensor_single_dim
|
14 |
+
|
15 |
+
from .unet_3d_condition import UNet3DConditionModel
|
16 |
+
from .referencenet import ReferenceNet2D
|
17 |
+
from ip_adapter.ip_adapter import ImageProjModel
|
18 |
+
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
|
22 |
+
class SuperUNet3DConditionModel(nn.Module):
|
23 |
+
"""封装了各种子模型的超模型,与 diffusers 的 pipeline 很像,只不过这里是模型定义。
|
24 |
+
主要作用
|
25 |
+
1. 将支持controlnet、referencenet等功能的计算封装起来,简洁些;
|
26 |
+
2. 便于 accelerator 的分布式训练;
|
27 |
+
|
28 |
+
wrap the sub-models, such as unet, referencenet, controlnet, vae, text_encoder, tokenizer, text_emb_extractor, clip_vision_extractor, ip_adapter_image_proj
|
29 |
+
1. support controlnet, referencenet, etc.
|
30 |
+
2. support accelerator distributed training
|
31 |
+
"""
|
32 |
+
|
33 |
+
_supports_gradient_checkpointing = True
|
34 |
+
print_idx = 0
|
35 |
+
|
36 |
+
# @register_to_config
|
37 |
+
def __init__(
|
38 |
+
self,
|
39 |
+
unet: nn.Module,
|
40 |
+
referencenet: nn.Module = None,
|
41 |
+
controlnet: nn.Module = None,
|
42 |
+
vae: nn.Module = None,
|
43 |
+
text_encoder: nn.Module = None,
|
44 |
+
tokenizer: nn.Module = None,
|
45 |
+
text_emb_extractor: nn.Module = None,
|
46 |
+
clip_vision_extractor: nn.Module = None,
|
47 |
+
ip_adapter_image_proj: nn.Module = None,
|
48 |
+
) -> None:
|
49 |
+
"""_summary_
|
50 |
+
|
51 |
+
Args:
|
52 |
+
unet (nn.Module): _description_
|
53 |
+
referencenet (nn.Module, optional): _description_. Defaults to None.
|
54 |
+
controlnet (nn.Module, optional): _description_. Defaults to None.
|
55 |
+
vae (nn.Module, optional): _description_. Defaults to None.
|
56 |
+
text_encoder (nn.Module, optional): _description_. Defaults to None.
|
57 |
+
tokenizer (nn.Module, optional): _description_. Defaults to None.
|
58 |
+
text_emb_extractor (nn.Module, optional): wrap text_encoder and tokenizer for str2emb. Defaults to None.
|
59 |
+
clip_vision_extractor (nn.Module, optional): _description_. Defaults to None.
|
60 |
+
"""
|
61 |
+
super().__init__()
|
62 |
+
self.unet = unet
|
63 |
+
self.referencenet = referencenet
|
64 |
+
self.controlnet = controlnet
|
65 |
+
self.vae = vae
|
66 |
+
self.text_encoder = text_encoder
|
67 |
+
self.tokenizer = tokenizer
|
68 |
+
self.text_emb_extractor = text_emb_extractor
|
69 |
+
self.clip_vision_extractor = clip_vision_extractor
|
70 |
+
self.ip_adapter_image_proj = ip_adapter_image_proj
|
71 |
+
|
72 |
+
def forward(
|
73 |
+
self,
|
74 |
+
unet_params: Dict,
|
75 |
+
encoder_hidden_states: torch.Tensor,
|
76 |
+
referencenet_params: Dict = None,
|
77 |
+
controlnet_params: Dict = None,
|
78 |
+
controlnet_scale: float = 1.0,
|
79 |
+
vision_clip_emb: Union[torch.Tensor, None] = None,
|
80 |
+
prompt_only_use_image_prompt: bool = False,
|
81 |
+
):
|
82 |
+
"""_summary_
|
83 |
+
|
84 |
+
Args:
|
85 |
+
unet_params (Dict): _description_
|
86 |
+
encoder_hidden_states (torch.Tensor): b t n d
|
87 |
+
referencenet_params (Dict, optional): _description_. Defaults to None.
|
88 |
+
controlnet_params (Dict, optional): _description_. Defaults to None.
|
89 |
+
controlnet_scale (float, optional): _description_. Defaults to 1.0.
|
90 |
+
vision_clip_emb (Union[torch.Tensor, None], optional): b t d. Defaults to None.
|
91 |
+
prompt_only_use_image_prompt (bool, optional): _description_. Defaults to False.
|
92 |
+
|
93 |
+
Returns:
|
94 |
+
_type_: _description_
|
95 |
+
"""
|
96 |
+
batch_size = unet_params["sample"].shape[0]
|
97 |
+
time_size = unet_params["sample"].shape[2]
|
98 |
+
|
99 |
+
# ip_adapter_cross_attn, prepare image prompt
|
100 |
+
if vision_clip_emb is not None:
|
101 |
+
# b t n d -> b t n d
|
102 |
+
if self.print_idx == 0:
|
103 |
+
logger.debug(
|
104 |
+
f"vision_clip_emb, before ip_adapter_image_proj, shape={vision_clip_emb.shape} mean={torch.mean(vision_clip_emb)}"
|
105 |
+
)
|
106 |
+
if vision_clip_emb.ndim == 3:
|
107 |
+
vision_clip_emb = rearrange(vision_clip_emb, "b t d-> b t 1 d")
|
108 |
+
if self.ip_adapter_image_proj is not None:
|
109 |
+
vision_clip_emb = rearrange(vision_clip_emb, "b t n d ->(b t) n d")
|
110 |
+
vision_clip_emb = self.ip_adapter_image_proj(vision_clip_emb)
|
111 |
+
if self.print_idx == 0:
|
112 |
+
logger.debug(
|
113 |
+
f"vision_clip_emb, after ip_adapter_image_proj shape={vision_clip_emb.shape} mean={torch.mean(vision_clip_emb)}"
|
114 |
+
)
|
115 |
+
if vision_clip_emb.ndim == 2:
|
116 |
+
vision_clip_emb = rearrange(vision_clip_emb, "b d-> b 1 d")
|
117 |
+
vision_clip_emb = rearrange(
|
118 |
+
vision_clip_emb, "(b t) n d -> b t n d", b=batch_size
|
119 |
+
)
|
120 |
+
vision_clip_emb = align_repeat_tensor_single_dim(
|
121 |
+
vision_clip_emb, target_length=time_size, dim=1
|
122 |
+
)
|
123 |
+
if self.print_idx == 0:
|
124 |
+
logger.debug(
|
125 |
+
f"vision_clip_emb, after reshape shape={vision_clip_emb.shape} mean={torch.mean(vision_clip_emb)}"
|
126 |
+
)
|
127 |
+
|
128 |
+
if vision_clip_emb is None and encoder_hidden_states is not None:
|
129 |
+
vision_clip_emb = encoder_hidden_states
|
130 |
+
if vision_clip_emb is not None and encoder_hidden_states is None:
|
131 |
+
encoder_hidden_states = vision_clip_emb
|
132 |
+
# 当 prompt_only_use_image_prompt 为True时,
|
133 |
+
# 1. referencenet 都使用 vision_clip_emb
|
134 |
+
# 2. unet 如果没有dual_cross_attn,使用vision_clip_emb,有时不更新
|
135 |
+
# 3. controlnet 当前使用 text_prompt
|
136 |
+
|
137 |
+
# when prompt_only_use_image_prompt True,
|
138 |
+
# 1. referencenet use vision_clip_emb
|
139 |
+
# 2. unet use vision_clip_emb if no dual_cross_attn, sometimes not update
|
140 |
+
# 3. controlnet use text_prompt
|
141 |
+
|
142 |
+
# extract referencenet emb
|
143 |
+
if self.referencenet is not None and referencenet_params is not None:
|
144 |
+
referencenet_encoder_hidden_states = align_repeat_tensor_single_dim(
|
145 |
+
vision_clip_emb,
|
146 |
+
target_length=referencenet_params["num_frames"],
|
147 |
+
dim=1,
|
148 |
+
)
|
149 |
+
referencenet_params["encoder_hidden_states"] = rearrange(
|
150 |
+
referencenet_encoder_hidden_states, "b t n d->(b t) n d"
|
151 |
+
)
|
152 |
+
referencenet_out = self.referencenet(**referencenet_params)
|
153 |
+
(
|
154 |
+
down_block_refer_embs,
|
155 |
+
mid_block_refer_emb,
|
156 |
+
refer_self_attn_emb,
|
157 |
+
) = referencenet_out
|
158 |
+
if down_block_refer_embs is not None:
|
159 |
+
if self.print_idx == 0:
|
160 |
+
logger.debug(
|
161 |
+
f"len(down_block_refer_embs)={len(down_block_refer_embs)}"
|
162 |
+
)
|
163 |
+
for i, down_emb in enumerate(down_block_refer_embs):
|
164 |
+
if self.print_idx == 0:
|
165 |
+
logger.debug(
|
166 |
+
f"down_emb, {i}, {down_emb.shape}, mean={down_emb.mean()}"
|
167 |
+
)
|
168 |
+
else:
|
169 |
+
if self.print_idx == 0:
|
170 |
+
logger.debug(f"down_block_refer_embs is None")
|
171 |
+
if mid_block_refer_emb is not None:
|
172 |
+
if self.print_idx == 0:
|
173 |
+
logger.debug(
|
174 |
+
f"mid_block_refer_emb, {mid_block_refer_emb.shape}, mean={mid_block_refer_emb.mean()}"
|
175 |
+
)
|
176 |
+
else:
|
177 |
+
if self.print_idx == 0:
|
178 |
+
logger.debug(f"mid_block_refer_emb is None")
|
179 |
+
if refer_self_attn_emb is not None:
|
180 |
+
if self.print_idx == 0:
|
181 |
+
logger.debug(f"refer_self_attn_emb, num={len(refer_self_attn_emb)}")
|
182 |
+
for i, self_attn_emb in enumerate(refer_self_attn_emb):
|
183 |
+
if self.print_idx == 0:
|
184 |
+
logger.debug(
|
185 |
+
f"referencenet, self_attn_emb, {i}th, shape={self_attn_emb.shape}, mean={self_attn_emb.mean()}"
|
186 |
+
)
|
187 |
+
else:
|
188 |
+
if self.print_idx == 0:
|
189 |
+
logger.debug(f"refer_self_attn_emb is None")
|
190 |
+
else:
|
191 |
+
down_block_refer_embs, mid_block_refer_emb, refer_self_attn_emb = (
|
192 |
+
None,
|
193 |
+
None,
|
194 |
+
None,
|
195 |
+
)
|
196 |
+
|
197 |
+
# extract controlnet emb
|
198 |
+
if self.controlnet is not None and controlnet_params is not None:
|
199 |
+
controlnet_encoder_hidden_states = align_repeat_tensor_single_dim(
|
200 |
+
encoder_hidden_states,
|
201 |
+
target_length=unet_params["sample"].shape[2],
|
202 |
+
dim=1,
|
203 |
+
)
|
204 |
+
controlnet_params["encoder_hidden_states"] = rearrange(
|
205 |
+
controlnet_encoder_hidden_states, " b t n d -> (b t) n d"
|
206 |
+
)
|
207 |
+
(
|
208 |
+
down_block_additional_residuals,
|
209 |
+
mid_block_additional_residual,
|
210 |
+
) = self.controlnet(**controlnet_params)
|
211 |
+
if controlnet_scale != 1.0:
|
212 |
+
down_block_additional_residuals = [
|
213 |
+
x * controlnet_scale for x in down_block_additional_residuals
|
214 |
+
]
|
215 |
+
mid_block_additional_residual = (
|
216 |
+
mid_block_additional_residual * controlnet_scale
|
217 |
+
)
|
218 |
+
for i, down_block_additional_residual in enumerate(
|
219 |
+
down_block_additional_residuals
|
220 |
+
):
|
221 |
+
if self.print_idx == 0:
|
222 |
+
logger.debug(
|
223 |
+
f"{i}, down_block_additional_residual mean={torch.mean(down_block_additional_residual)}"
|
224 |
+
)
|
225 |
+
|
226 |
+
if self.print_idx == 0:
|
227 |
+
logger.debug(
|
228 |
+
f"mid_block_additional_residual mean={torch.mean(mid_block_additional_residual)}"
|
229 |
+
)
|
230 |
+
else:
|
231 |
+
down_block_additional_residuals = None
|
232 |
+
mid_block_additional_residual = None
|
233 |
+
|
234 |
+
if prompt_only_use_image_prompt and vision_clip_emb is not None:
|
235 |
+
encoder_hidden_states = vision_clip_emb
|
236 |
+
|
237 |
+
# run unet
|
238 |
+
out = self.unet(
|
239 |
+
**unet_params,
|
240 |
+
down_block_refer_embs=down_block_refer_embs,
|
241 |
+
mid_block_refer_emb=mid_block_refer_emb,
|
242 |
+
refer_self_attn_emb=refer_self_attn_emb,
|
243 |
+
down_block_additional_residuals=down_block_additional_residuals,
|
244 |
+
mid_block_additional_residual=mid_block_additional_residual,
|
245 |
+
encoder_hidden_states=encoder_hidden_states,
|
246 |
+
vision_clip_emb=vision_clip_emb,
|
247 |
+
)
|
248 |
+
self.print_idx += 1
|
249 |
+
return out
|
250 |
+
|
251 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
252 |
+
if isinstance(module, (UNet3DConditionModel, ReferenceNet2D)):
|
253 |
+
module.gradient_checkpointing = value
|
musev/models/temporal_transformer.py
ADDED
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# Adapted from https://github.com/huggingface/diffusers/blob/v0.16.1/src/diffusers/models/transformer_temporal.py
|
16 |
+
from __future__ import annotations
|
17 |
+
from copy import deepcopy
|
18 |
+
from dataclasses import dataclass
|
19 |
+
from typing import List, Literal, Optional
|
20 |
+
import logging
|
21 |
+
|
22 |
+
import torch
|
23 |
+
from torch import nn
|
24 |
+
from einops import rearrange, repeat
|
25 |
+
|
26 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
27 |
+
from diffusers.utils import BaseOutput
|
28 |
+
from diffusers.models.modeling_utils import ModelMixin
|
29 |
+
from diffusers.models.transformer_temporal import (
|
30 |
+
TransformerTemporalModelOutput,
|
31 |
+
TransformerTemporalModel as DiffusersTransformerTemporalModel,
|
32 |
+
)
|
33 |
+
from diffusers.models.attention_processor import AttnProcessor
|
34 |
+
|
35 |
+
from mmcm.utils.gpu_util import get_gpu_status
|
36 |
+
from ..data.data_util import (
|
37 |
+
batch_concat_two_tensor_with_index,
|
38 |
+
batch_index_fill,
|
39 |
+
batch_index_select,
|
40 |
+
concat_two_tensor,
|
41 |
+
align_repeat_tensor_single_dim,
|
42 |
+
)
|
43 |
+
from ..utils.attention_util import generate_sparse_causcal_attn_mask
|
44 |
+
from .attention import BasicTransformerBlock
|
45 |
+
from .attention_processor import (
|
46 |
+
BaseIPAttnProcessor,
|
47 |
+
)
|
48 |
+
from . import Model_Register
|
49 |
+
|
50 |
+
# https://github.com/facebookresearch/xformers/issues/845
|
51 |
+
# 输入bs*n_frames*w*h太高,xformers报错。因此将transformer_temporal的allow_xformers均关掉
|
52 |
+
# if bs*n_frames*w*h to large, xformers will raise error. So we close the allow_xformers in transformer_temporal
|
53 |
+
logger = logging.getLogger(__name__)
|
54 |
+
|
55 |
+
|
56 |
+
@Model_Register.register
|
57 |
+
class TransformerTemporalModel(ModelMixin, ConfigMixin):
|
58 |
+
"""
|
59 |
+
Transformer model for video-like data.
|
60 |
+
|
61 |
+
Parameters:
|
62 |
+
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
63 |
+
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
64 |
+
in_channels (`int`, *optional*):
|
65 |
+
Pass if the input is continuous. The number of channels in the input and output.
|
66 |
+
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
67 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
68 |
+
cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
|
69 |
+
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
|
70 |
+
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
|
71 |
+
`ImagePositionalEmbeddings`.
|
72 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
73 |
+
attention_bias (`bool`, *optional*):
|
74 |
+
Configure if the TransformerBlocks' attention should contain a bias parameter.
|
75 |
+
double_self_attention (`bool`, *optional*):
|
76 |
+
Configure if each TransformerBlock should contain two self-attention layers
|
77 |
+
"""
|
78 |
+
|
79 |
+
@register_to_config
|
80 |
+
def __init__(
|
81 |
+
self,
|
82 |
+
num_attention_heads: int = 16,
|
83 |
+
attention_head_dim: int = 88,
|
84 |
+
in_channels: Optional[int] = None,
|
85 |
+
out_channels: Optional[int] = None,
|
86 |
+
num_layers: int = 1,
|
87 |
+
femb_channels: Optional[int] = None,
|
88 |
+
dropout: float = 0.0,
|
89 |
+
norm_num_groups: int = 32,
|
90 |
+
cross_attention_dim: Optional[int] = None,
|
91 |
+
attention_bias: bool = False,
|
92 |
+
sample_size: Optional[int] = None,
|
93 |
+
activation_fn: str = "geglu",
|
94 |
+
norm_elementwise_affine: bool = True,
|
95 |
+
double_self_attention: bool = True,
|
96 |
+
allow_xformers: bool = False,
|
97 |
+
only_cross_attention: bool = False,
|
98 |
+
keep_content_condition: bool = False,
|
99 |
+
need_spatial_position_emb: bool = False,
|
100 |
+
need_temporal_weight: bool = True,
|
101 |
+
self_attn_mask: str = None,
|
102 |
+
# TODO: 运行参数,有待改到forward里面去
|
103 |
+
# TODO: running parameters, need to be moved to forward
|
104 |
+
image_scale: float = 1.0,
|
105 |
+
processor: AttnProcessor | None = None,
|
106 |
+
remove_femb_non_linear: bool = False,
|
107 |
+
):
|
108 |
+
super().__init__()
|
109 |
+
|
110 |
+
self.num_attention_heads = num_attention_heads
|
111 |
+
self.attention_head_dim = attention_head_dim
|
112 |
+
|
113 |
+
inner_dim = num_attention_heads * attention_head_dim
|
114 |
+
self.inner_dim = inner_dim
|
115 |
+
self.in_channels = in_channels
|
116 |
+
|
117 |
+
self.norm = torch.nn.GroupNorm(
|
118 |
+
num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
|
119 |
+
)
|
120 |
+
|
121 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
122 |
+
|
123 |
+
# 2. Define temporal positional embedding
|
124 |
+
self.frame_emb_proj = torch.nn.Linear(femb_channels, inner_dim)
|
125 |
+
self.remove_femb_non_linear = remove_femb_non_linear
|
126 |
+
if not remove_femb_non_linear:
|
127 |
+
self.nonlinearity = nn.SiLU()
|
128 |
+
|
129 |
+
# spatial_position_emb 使用femb_的参数配置
|
130 |
+
self.need_spatial_position_emb = need_spatial_position_emb
|
131 |
+
if need_spatial_position_emb:
|
132 |
+
self.spatial_position_emb_proj = torch.nn.Linear(femb_channels, inner_dim)
|
133 |
+
# 3. Define transformers blocks
|
134 |
+
# TODO: 该实现方式不好,待优化
|
135 |
+
# TODO: bad implementation, need to be optimized
|
136 |
+
self.need_ipadapter = False
|
137 |
+
self.cross_attn_temporal_cond = False
|
138 |
+
self.allow_xformers = allow_xformers
|
139 |
+
if processor is not None and isinstance(processor, BaseIPAttnProcessor):
|
140 |
+
self.cross_attn_temporal_cond = True
|
141 |
+
self.allow_xformers = False
|
142 |
+
if "NonParam" not in processor.__class__.__name__:
|
143 |
+
self.need_ipadapter = True
|
144 |
+
|
145 |
+
self.transformer_blocks = nn.ModuleList(
|
146 |
+
[
|
147 |
+
BasicTransformerBlock(
|
148 |
+
inner_dim,
|
149 |
+
num_attention_heads,
|
150 |
+
attention_head_dim,
|
151 |
+
dropout=dropout,
|
152 |
+
cross_attention_dim=cross_attention_dim,
|
153 |
+
activation_fn=activation_fn,
|
154 |
+
attention_bias=attention_bias,
|
155 |
+
double_self_attention=double_self_attention,
|
156 |
+
norm_elementwise_affine=norm_elementwise_affine,
|
157 |
+
allow_xformers=allow_xformers,
|
158 |
+
only_cross_attention=only_cross_attention,
|
159 |
+
cross_attn_temporal_cond=self.need_ipadapter,
|
160 |
+
image_scale=image_scale,
|
161 |
+
processor=processor,
|
162 |
+
)
|
163 |
+
for d in range(num_layers)
|
164 |
+
]
|
165 |
+
)
|
166 |
+
|
167 |
+
self.proj_out = nn.Linear(inner_dim, in_channels)
|
168 |
+
|
169 |
+
self.need_temporal_weight = need_temporal_weight
|
170 |
+
if need_temporal_weight:
|
171 |
+
self.temporal_weight = nn.Parameter(
|
172 |
+
torch.tensor(
|
173 |
+
[
|
174 |
+
1e-5,
|
175 |
+
]
|
176 |
+
)
|
177 |
+
) # initialize parameter with 0
|
178 |
+
self.skip_temporal_layers = False # Whether to skip temporal layer
|
179 |
+
self.keep_content_condition = keep_content_condition
|
180 |
+
self.self_attn_mask = self_attn_mask
|
181 |
+
self.only_cross_attention = only_cross_attention
|
182 |
+
self.double_self_attention = double_self_attention
|
183 |
+
self.cross_attention_dim = cross_attention_dim
|
184 |
+
self.image_scale = image_scale
|
185 |
+
# zero out the last layer params,so the conv block is identity
|
186 |
+
nn.init.zeros_(self.proj_out.weight)
|
187 |
+
nn.init.zeros_(self.proj_out.bias)
|
188 |
+
|
189 |
+
def forward(
|
190 |
+
self,
|
191 |
+
hidden_states,
|
192 |
+
femb,
|
193 |
+
encoder_hidden_states=None,
|
194 |
+
timestep=None,
|
195 |
+
class_labels=None,
|
196 |
+
num_frames=1,
|
197 |
+
cross_attention_kwargs=None,
|
198 |
+
sample_index: torch.LongTensor = None,
|
199 |
+
vision_conditon_frames_sample_index: torch.LongTensor = None,
|
200 |
+
spatial_position_emb: torch.Tensor = None,
|
201 |
+
return_dict: bool = True,
|
202 |
+
):
|
203 |
+
"""
|
204 |
+
Args:
|
205 |
+
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
|
206 |
+
When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
|
207 |
+
hidden_states
|
208 |
+
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
209 |
+
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
210 |
+
self-attention.
|
211 |
+
timestep ( `torch.long`, *optional*):
|
212 |
+
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
|
213 |
+
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
214 |
+
Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels
|
215 |
+
conditioning.
|
216 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
217 |
+
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
218 |
+
|
219 |
+
Returns:
|
220 |
+
[`~models.transformer_2d.TransformerTemporalModelOutput`] or `tuple`:
|
221 |
+
[`~models.transformer_2d.TransformerTemporalModelOutput`] if `return_dict` is True, otherwise a `tuple`.
|
222 |
+
When returning a tuple, the first element is the sample tensor.
|
223 |
+
"""
|
224 |
+
if self.skip_temporal_layers is True:
|
225 |
+
if not return_dict:
|
226 |
+
return (hidden_states,)
|
227 |
+
|
228 |
+
return TransformerTemporalModelOutput(sample=hidden_states)
|
229 |
+
|
230 |
+
# 1. Input
|
231 |
+
batch_frames, channel, height, width = hidden_states.shape
|
232 |
+
batch_size = batch_frames // num_frames
|
233 |
+
|
234 |
+
hidden_states = rearrange(
|
235 |
+
hidden_states, "(b t) c h w -> b c t h w", b=batch_size
|
236 |
+
)
|
237 |
+
residual = hidden_states
|
238 |
+
|
239 |
+
hidden_states = self.norm(hidden_states)
|
240 |
+
|
241 |
+
hidden_states = rearrange(hidden_states, "b c t h w -> (b h w) t c")
|
242 |
+
|
243 |
+
hidden_states = self.proj_in(hidden_states)
|
244 |
+
|
245 |
+
# 2 Positional embedding
|
246 |
+
# adapted from https://github.com/huggingface/diffusers/blob/v0.16.1/src/diffusers/models/resnet.py#L574
|
247 |
+
if not self.remove_femb_non_linear:
|
248 |
+
femb = self.nonlinearity(femb)
|
249 |
+
femb = self.frame_emb_proj(femb)
|
250 |
+
femb = align_repeat_tensor_single_dim(femb, hidden_states.shape[0], dim=0)
|
251 |
+
hidden_states = hidden_states + femb
|
252 |
+
|
253 |
+
# 3. Blocks
|
254 |
+
if (
|
255 |
+
(self.only_cross_attention or not self.double_self_attention)
|
256 |
+
and self.cross_attention_dim is not None
|
257 |
+
and encoder_hidden_states is not None
|
258 |
+
):
|
259 |
+
encoder_hidden_states = align_repeat_tensor_single_dim(
|
260 |
+
encoder_hidden_states,
|
261 |
+
hidden_states.shape[0],
|
262 |
+
dim=0,
|
263 |
+
n_src_base_length=batch_size,
|
264 |
+
)
|
265 |
+
|
266 |
+
for i, block in enumerate(self.transformer_blocks):
|
267 |
+
hidden_states = block(
|
268 |
+
hidden_states,
|
269 |
+
encoder_hidden_states=encoder_hidden_states,
|
270 |
+
timestep=timestep,
|
271 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
272 |
+
class_labels=class_labels,
|
273 |
+
)
|
274 |
+
|
275 |
+
# 4. Output
|
276 |
+
hidden_states = self.proj_out(hidden_states)
|
277 |
+
hidden_states = rearrange(
|
278 |
+
hidden_states, "(b h w) t c -> b c t h w", b=batch_size, h=height, w=width
|
279 |
+
).contiguous()
|
280 |
+
|
281 |
+
# 保留condition对应的frames,便于保持前序内容帧,提升一致性
|
282 |
+
# keep the frames corresponding to the condition to maintain the previous content frames and improve consistency
|
283 |
+
if (
|
284 |
+
vision_conditon_frames_sample_index is not None
|
285 |
+
and self.keep_content_condition
|
286 |
+
):
|
287 |
+
mask = torch.ones_like(hidden_states, device=hidden_states.device)
|
288 |
+
mask = batch_index_fill(
|
289 |
+
mask, dim=2, index=vision_conditon_frames_sample_index, value=0
|
290 |
+
)
|
291 |
+
if self.need_temporal_weight:
|
292 |
+
output = (
|
293 |
+
residual + torch.abs(self.temporal_weight) * mask * hidden_states
|
294 |
+
)
|
295 |
+
else:
|
296 |
+
output = residual + mask * hidden_states
|
297 |
+
else:
|
298 |
+
if self.need_temporal_weight:
|
299 |
+
output = residual + torch.abs(self.temporal_weight) * hidden_states
|
300 |
+
else:
|
301 |
+
output = residual + mask * hidden_states
|
302 |
+
|
303 |
+
# output = torch.abs(self.temporal_weight) * hidden_states + residual
|
304 |
+
output = rearrange(output, "b c t h w -> (b t) c h w")
|
305 |
+
if not return_dict:
|
306 |
+
return (output,)
|
307 |
+
|
308 |
+
return TransformerTemporalModelOutput(sample=output)
|
musev/models/text_model.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Dict
|
2 |
+
from torch import nn
|
3 |
+
|
4 |
+
|
5 |
+
class TextEmbExtractor(nn.Module):
|
6 |
+
def __init__(self, tokenizer, text_encoder) -> None:
|
7 |
+
super(TextEmbExtractor, self).__init__()
|
8 |
+
self.tokenizer = tokenizer
|
9 |
+
self.text_encoder = text_encoder
|
10 |
+
|
11 |
+
def forward(
|
12 |
+
self,
|
13 |
+
texts,
|
14 |
+
text_params: Dict = None,
|
15 |
+
):
|
16 |
+
if text_params is None:
|
17 |
+
text_params = {}
|
18 |
+
special_prompt_input = self.tokenizer(
|
19 |
+
texts,
|
20 |
+
max_length=self.tokenizer.model_max_length,
|
21 |
+
padding="max_length",
|
22 |
+
truncation=True,
|
23 |
+
return_tensors="pt",
|
24 |
+
)
|
25 |
+
if (
|
26 |
+
hasattr(self.text_encoder.config, "use_attention_mask")
|
27 |
+
and self.text_encoder.config.use_attention_mask
|
28 |
+
):
|
29 |
+
attention_mask = special_prompt_input.attention_mask.to(
|
30 |
+
self.text_encoder.device
|
31 |
+
)
|
32 |
+
else:
|
33 |
+
attention_mask = None
|
34 |
+
|
35 |
+
embeddings = self.text_encoder(
|
36 |
+
special_prompt_input.input_ids.to(self.text_encoder.device),
|
37 |
+
attention_mask=attention_mask,
|
38 |
+
**text_params
|
39 |
+
)
|
40 |
+
return embeddings
|
musev/models/transformer_2d.py
ADDED
@@ -0,0 +1,445 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from __future__ import annotations
|
15 |
+
from dataclasses import dataclass
|
16 |
+
from typing import Any, Dict, List, Literal, Optional
|
17 |
+
import logging
|
18 |
+
|
19 |
+
from einops import rearrange
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.nn.functional as F
|
23 |
+
from torch import nn
|
24 |
+
|
25 |
+
from diffusers.models.transformer_2d import (
|
26 |
+
Transformer2DModelOutput,
|
27 |
+
Transformer2DModel as DiffusersTransformer2DModel,
|
28 |
+
)
|
29 |
+
|
30 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
31 |
+
from diffusers.models.embeddings import ImagePositionalEmbeddings
|
32 |
+
from diffusers.utils import BaseOutput, deprecate
|
33 |
+
from diffusers.models.attention import (
|
34 |
+
BasicTransformerBlock as DiffusersBasicTransformerBlock,
|
35 |
+
)
|
36 |
+
from diffusers.models.embeddings import PatchEmbed
|
37 |
+
from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
|
38 |
+
from diffusers.models.modeling_utils import ModelMixin
|
39 |
+
from diffusers.utils.constants import USE_PEFT_BACKEND
|
40 |
+
|
41 |
+
from .attention import BasicTransformerBlock
|
42 |
+
|
43 |
+
logger = logging.getLogger(__name__)
|
44 |
+
|
45 |
+
# 本部分 与 diffusers/models/transformer_2d.py 几乎一样
|
46 |
+
# 更新部分
|
47 |
+
# 1. 替换自定义 BasicTransformerBlock 类
|
48 |
+
# 2. 在forward 里增加了 self_attn_block_embs 用于 提取 self_attn 中的emb
|
49 |
+
|
50 |
+
# this module is same as diffusers/models/transformer_2d.py. The update part is
|
51 |
+
# 1 redefine BasicTransformerBlock
|
52 |
+
# 2. add self_attn_block_embs in forward to extract emb from self_attn
|
53 |
+
|
54 |
+
|
55 |
+
class Transformer2DModel(DiffusersTransformer2DModel):
|
56 |
+
"""
|
57 |
+
A 2D Transformer model for image-like data.
|
58 |
+
|
59 |
+
Parameters:
|
60 |
+
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
61 |
+
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
62 |
+
in_channels (`int`, *optional*):
|
63 |
+
The number of channels in the input and output (specify if the input is **continuous**).
|
64 |
+
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
65 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
66 |
+
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
67 |
+
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
|
68 |
+
This is fixed during training since it is used to learn a number of position embeddings.
|
69 |
+
num_vector_embeds (`int`, *optional*):
|
70 |
+
The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
|
71 |
+
Includes the class for the masked latent pixel.
|
72 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
|
73 |
+
num_embeds_ada_norm ( `int`, *optional*):
|
74 |
+
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
|
75 |
+
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
|
76 |
+
added to the hidden states.
|
77 |
+
|
78 |
+
During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
|
79 |
+
attention_bias (`bool`, *optional*):
|
80 |
+
Configure if the `TransformerBlocks` attention should contain a bias parameter.
|
81 |
+
"""
|
82 |
+
|
83 |
+
@register_to_config
|
84 |
+
def __init__(
|
85 |
+
self,
|
86 |
+
num_attention_heads: int = 16,
|
87 |
+
attention_head_dim: int = 88,
|
88 |
+
in_channels: int | None = None,
|
89 |
+
out_channels: int | None = None,
|
90 |
+
num_layers: int = 1,
|
91 |
+
dropout: float = 0,
|
92 |
+
norm_num_groups: int = 32,
|
93 |
+
cross_attention_dim: int | None = None,
|
94 |
+
attention_bias: bool = False,
|
95 |
+
sample_size: int | None = None,
|
96 |
+
num_vector_embeds: int | None = None,
|
97 |
+
patch_size: int | None = None,
|
98 |
+
activation_fn: str = "geglu",
|
99 |
+
num_embeds_ada_norm: int | None = None,
|
100 |
+
use_linear_projection: bool = False,
|
101 |
+
only_cross_attention: bool = False,
|
102 |
+
double_self_attention: bool = False,
|
103 |
+
upcast_attention: bool = False,
|
104 |
+
norm_type: str = "layer_norm",
|
105 |
+
norm_elementwise_affine: bool = True,
|
106 |
+
attention_type: str = "default",
|
107 |
+
cross_attn_temporal_cond: bool = False,
|
108 |
+
ip_adapter_cross_attn: bool = False,
|
109 |
+
need_t2i_facein: bool = False,
|
110 |
+
need_t2i_ip_adapter_face: bool = False,
|
111 |
+
image_scale: float = 1.0,
|
112 |
+
):
|
113 |
+
super().__init__(
|
114 |
+
num_attention_heads,
|
115 |
+
attention_head_dim,
|
116 |
+
in_channels,
|
117 |
+
out_channels,
|
118 |
+
num_layers,
|
119 |
+
dropout,
|
120 |
+
norm_num_groups,
|
121 |
+
cross_attention_dim,
|
122 |
+
attention_bias,
|
123 |
+
sample_size,
|
124 |
+
num_vector_embeds,
|
125 |
+
patch_size,
|
126 |
+
activation_fn,
|
127 |
+
num_embeds_ada_norm,
|
128 |
+
use_linear_projection,
|
129 |
+
only_cross_attention,
|
130 |
+
double_self_attention,
|
131 |
+
upcast_attention,
|
132 |
+
norm_type,
|
133 |
+
norm_elementwise_affine,
|
134 |
+
attention_type,
|
135 |
+
)
|
136 |
+
inner_dim = num_attention_heads * attention_head_dim
|
137 |
+
self.transformer_blocks = nn.ModuleList(
|
138 |
+
[
|
139 |
+
BasicTransformerBlock(
|
140 |
+
inner_dim,
|
141 |
+
num_attention_heads,
|
142 |
+
attention_head_dim,
|
143 |
+
dropout=dropout,
|
144 |
+
cross_attention_dim=cross_attention_dim,
|
145 |
+
activation_fn=activation_fn,
|
146 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
147 |
+
attention_bias=attention_bias,
|
148 |
+
only_cross_attention=only_cross_attention,
|
149 |
+
double_self_attention=double_self_attention,
|
150 |
+
upcast_attention=upcast_attention,
|
151 |
+
norm_type=norm_type,
|
152 |
+
norm_elementwise_affine=norm_elementwise_affine,
|
153 |
+
attention_type=attention_type,
|
154 |
+
cross_attn_temporal_cond=cross_attn_temporal_cond,
|
155 |
+
ip_adapter_cross_attn=ip_adapter_cross_attn,
|
156 |
+
need_t2i_facein=need_t2i_facein,
|
157 |
+
need_t2i_ip_adapter_face=need_t2i_ip_adapter_face,
|
158 |
+
image_scale=image_scale,
|
159 |
+
)
|
160 |
+
for d in range(num_layers)
|
161 |
+
]
|
162 |
+
)
|
163 |
+
self.num_layers = num_layers
|
164 |
+
self.cross_attn_temporal_cond = cross_attn_temporal_cond
|
165 |
+
self.ip_adapter_cross_attn = ip_adapter_cross_attn
|
166 |
+
|
167 |
+
self.need_t2i_facein = need_t2i_facein
|
168 |
+
self.need_t2i_ip_adapter_face = need_t2i_ip_adapter_face
|
169 |
+
self.image_scale = image_scale
|
170 |
+
self.print_idx = 0
|
171 |
+
|
172 |
+
def forward(
|
173 |
+
self,
|
174 |
+
hidden_states: torch.Tensor,
|
175 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
176 |
+
timestep: Optional[torch.LongTensor] = None,
|
177 |
+
added_cond_kwargs: Dict[str, torch.Tensor] = None,
|
178 |
+
class_labels: Optional[torch.LongTensor] = None,
|
179 |
+
cross_attention_kwargs: Dict[str, Any] = None,
|
180 |
+
attention_mask: Optional[torch.Tensor] = None,
|
181 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
182 |
+
self_attn_block_embs: Optional[List[torch.Tensor]] = None,
|
183 |
+
self_attn_block_embs_mode: Literal["read", "write"] = "write",
|
184 |
+
return_dict: bool = True,
|
185 |
+
):
|
186 |
+
"""
|
187 |
+
The [`Transformer2DModel`] forward method.
|
188 |
+
|
189 |
+
Args:
|
190 |
+
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
191 |
+
Input `hidden_states`.
|
192 |
+
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
|
193 |
+
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
194 |
+
self-attention.
|
195 |
+
timestep ( `torch.LongTensor`, *optional*):
|
196 |
+
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
|
197 |
+
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
198 |
+
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
|
199 |
+
`AdaLayerZeroNorm`.
|
200 |
+
cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
|
201 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
202 |
+
`self.processor` in
|
203 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
204 |
+
attention_mask ( `torch.Tensor`, *optional*):
|
205 |
+
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
206 |
+
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
207 |
+
negative values to the attention scores corresponding to "discard" tokens.
|
208 |
+
encoder_attention_mask ( `torch.Tensor`, *optional*):
|
209 |
+
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
|
210 |
+
|
211 |
+
* Mask `(batch, sequence_length)` True = keep, False = discard.
|
212 |
+
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
|
213 |
+
|
214 |
+
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
|
215 |
+
above. This bias will be added to the cross-attention scores.
|
216 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
217 |
+
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
218 |
+
tuple.
|
219 |
+
|
220 |
+
Returns:
|
221 |
+
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
222 |
+
`tuple` where the first element is the sample tensor.
|
223 |
+
"""
|
224 |
+
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
|
225 |
+
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
|
226 |
+
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
|
227 |
+
# expects mask of shape:
|
228 |
+
# [batch, key_tokens]
|
229 |
+
# adds singleton query_tokens dimension:
|
230 |
+
# [batch, 1, key_tokens]
|
231 |
+
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
232 |
+
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
233 |
+
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
234 |
+
if attention_mask is not None and attention_mask.ndim == 2:
|
235 |
+
# assume that mask is expressed as:
|
236 |
+
# (1 = keep, 0 = discard)
|
237 |
+
# convert mask into a bias that can be added to attention scores:
|
238 |
+
# (keep = +0, discard = -10000.0)
|
239 |
+
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
240 |
+
attention_mask = attention_mask.unsqueeze(1)
|
241 |
+
|
242 |
+
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
243 |
+
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
244 |
+
encoder_attention_mask = (
|
245 |
+
1 - encoder_attention_mask.to(hidden_states.dtype)
|
246 |
+
) * -10000.0
|
247 |
+
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
248 |
+
|
249 |
+
# Retrieve lora scale.
|
250 |
+
lora_scale = (
|
251 |
+
cross_attention_kwargs.get("scale", 1.0)
|
252 |
+
if cross_attention_kwargs is not None
|
253 |
+
else 1.0
|
254 |
+
)
|
255 |
+
|
256 |
+
# 1. Input
|
257 |
+
if self.is_input_continuous:
|
258 |
+
batch, _, height, width = hidden_states.shape
|
259 |
+
residual = hidden_states
|
260 |
+
|
261 |
+
hidden_states = self.norm(hidden_states)
|
262 |
+
if not self.use_linear_projection:
|
263 |
+
hidden_states = (
|
264 |
+
self.proj_in(hidden_states, scale=lora_scale)
|
265 |
+
if not USE_PEFT_BACKEND
|
266 |
+
else self.proj_in(hidden_states)
|
267 |
+
)
|
268 |
+
inner_dim = hidden_states.shape[1]
|
269 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
|
270 |
+
batch, height * width, inner_dim
|
271 |
+
)
|
272 |
+
else:
|
273 |
+
inner_dim = hidden_states.shape[1]
|
274 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
|
275 |
+
batch, height * width, inner_dim
|
276 |
+
)
|
277 |
+
hidden_states = (
|
278 |
+
self.proj_in(hidden_states, scale=lora_scale)
|
279 |
+
if not USE_PEFT_BACKEND
|
280 |
+
else self.proj_in(hidden_states)
|
281 |
+
)
|
282 |
+
|
283 |
+
elif self.is_input_vectorized:
|
284 |
+
hidden_states = self.latent_image_embedding(hidden_states)
|
285 |
+
elif self.is_input_patches:
|
286 |
+
height, width = (
|
287 |
+
hidden_states.shape[-2] // self.patch_size,
|
288 |
+
hidden_states.shape[-1] // self.patch_size,
|
289 |
+
)
|
290 |
+
hidden_states = self.pos_embed(hidden_states)
|
291 |
+
|
292 |
+
if self.adaln_single is not None:
|
293 |
+
if self.use_additional_conditions and added_cond_kwargs is None:
|
294 |
+
raise ValueError(
|
295 |
+
"`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
|
296 |
+
)
|
297 |
+
batch_size = hidden_states.shape[0]
|
298 |
+
timestep, embedded_timestep = self.adaln_single(
|
299 |
+
timestep,
|
300 |
+
added_cond_kwargs,
|
301 |
+
batch_size=batch_size,
|
302 |
+
hidden_dtype=hidden_states.dtype,
|
303 |
+
)
|
304 |
+
|
305 |
+
# 2. Blocks
|
306 |
+
if self.caption_projection is not None:
|
307 |
+
batch_size = hidden_states.shape[0]
|
308 |
+
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
|
309 |
+
encoder_hidden_states = encoder_hidden_states.view(
|
310 |
+
batch_size, -1, hidden_states.shape[-1]
|
311 |
+
)
|
312 |
+
|
313 |
+
for block in self.transformer_blocks:
|
314 |
+
if self.training and self.gradient_checkpointing:
|
315 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
316 |
+
block,
|
317 |
+
hidden_states,
|
318 |
+
attention_mask,
|
319 |
+
encoder_hidden_states,
|
320 |
+
encoder_attention_mask,
|
321 |
+
timestep,
|
322 |
+
cross_attention_kwargs,
|
323 |
+
class_labels,
|
324 |
+
self_attn_block_embs,
|
325 |
+
self_attn_block_embs_mode,
|
326 |
+
use_reentrant=False,
|
327 |
+
)
|
328 |
+
else:
|
329 |
+
hidden_states = block(
|
330 |
+
hidden_states,
|
331 |
+
attention_mask=attention_mask,
|
332 |
+
encoder_hidden_states=encoder_hidden_states,
|
333 |
+
encoder_attention_mask=encoder_attention_mask,
|
334 |
+
timestep=timestep,
|
335 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
336 |
+
class_labels=class_labels,
|
337 |
+
self_attn_block_embs=self_attn_block_embs,
|
338 |
+
self_attn_block_embs_mode=self_attn_block_embs_mode,
|
339 |
+
)
|
340 |
+
# 将 转换 self_attn_emb的尺寸
|
341 |
+
if (
|
342 |
+
self_attn_block_embs is not None
|
343 |
+
and self_attn_block_embs_mode.lower() == "write"
|
344 |
+
):
|
345 |
+
self_attn_idx = block.spatial_self_attn_idx
|
346 |
+
if self.print_idx == 0:
|
347 |
+
logger.debug(
|
348 |
+
f"self_attn_block_embs, num={len(self_attn_block_embs)}, before, shape={self_attn_block_embs[self_attn_idx].shape}, height={height}, width={width}"
|
349 |
+
)
|
350 |
+
self_attn_block_embs[self_attn_idx] = rearrange(
|
351 |
+
self_attn_block_embs[self_attn_idx],
|
352 |
+
"bt (h w) c->bt c h w",
|
353 |
+
h=height,
|
354 |
+
w=width,
|
355 |
+
)
|
356 |
+
if self.print_idx == 0:
|
357 |
+
logger.debug(
|
358 |
+
f"self_attn_block_embs, num={len(self_attn_block_embs)}, after ,shape={self_attn_block_embs[self_attn_idx].shape}, height={height}, width={width}"
|
359 |
+
)
|
360 |
+
|
361 |
+
if self.proj_out is None:
|
362 |
+
return hidden_states
|
363 |
+
|
364 |
+
# 3. Output
|
365 |
+
if self.is_input_continuous:
|
366 |
+
if not self.use_linear_projection:
|
367 |
+
hidden_states = (
|
368 |
+
hidden_states.reshape(batch, height, width, inner_dim)
|
369 |
+
.permute(0, 3, 1, 2)
|
370 |
+
.contiguous()
|
371 |
+
)
|
372 |
+
hidden_states = (
|
373 |
+
self.proj_out(hidden_states, scale=lora_scale)
|
374 |
+
if not USE_PEFT_BACKEND
|
375 |
+
else self.proj_out(hidden_states)
|
376 |
+
)
|
377 |
+
else:
|
378 |
+
hidden_states = (
|
379 |
+
self.proj_out(hidden_states, scale=lora_scale)
|
380 |
+
if not USE_PEFT_BACKEND
|
381 |
+
else self.proj_out(hidden_states)
|
382 |
+
)
|
383 |
+
hidden_states = (
|
384 |
+
hidden_states.reshape(batch, height, width, inner_dim)
|
385 |
+
.permute(0, 3, 1, 2)
|
386 |
+
.contiguous()
|
387 |
+
)
|
388 |
+
|
389 |
+
output = hidden_states + residual
|
390 |
+
elif self.is_input_vectorized:
|
391 |
+
hidden_states = self.norm_out(hidden_states)
|
392 |
+
logits = self.out(hidden_states)
|
393 |
+
# (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
|
394 |
+
logits = logits.permute(0, 2, 1)
|
395 |
+
|
396 |
+
# log(p(x_0))
|
397 |
+
output = F.log_softmax(logits.double(), dim=1).float()
|
398 |
+
|
399 |
+
if self.is_input_patches:
|
400 |
+
if self.config.norm_type != "ada_norm_single":
|
401 |
+
conditioning = self.transformer_blocks[0].norm1.emb(
|
402 |
+
timestep, class_labels, hidden_dtype=hidden_states.dtype
|
403 |
+
)
|
404 |
+
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
|
405 |
+
hidden_states = (
|
406 |
+
self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
|
407 |
+
)
|
408 |
+
hidden_states = self.proj_out_2(hidden_states)
|
409 |
+
elif self.config.norm_type == "ada_norm_single":
|
410 |
+
shift, scale = (
|
411 |
+
self.scale_shift_table[None] + embedded_timestep[:, None]
|
412 |
+
).chunk(2, dim=1)
|
413 |
+
hidden_states = self.norm_out(hidden_states)
|
414 |
+
# Modulation
|
415 |
+
hidden_states = hidden_states * (1 + scale) + shift
|
416 |
+
hidden_states = self.proj_out(hidden_states)
|
417 |
+
hidden_states = hidden_states.squeeze(1)
|
418 |
+
|
419 |
+
# unpatchify
|
420 |
+
if self.adaln_single is None:
|
421 |
+
height = width = int(hidden_states.shape[1] ** 0.5)
|
422 |
+
hidden_states = hidden_states.reshape(
|
423 |
+
shape=(
|
424 |
+
-1,
|
425 |
+
height,
|
426 |
+
width,
|
427 |
+
self.patch_size,
|
428 |
+
self.patch_size,
|
429 |
+
self.out_channels,
|
430 |
+
)
|
431 |
+
)
|
432 |
+
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
|
433 |
+
output = hidden_states.reshape(
|
434 |
+
shape=(
|
435 |
+
-1,
|
436 |
+
self.out_channels,
|
437 |
+
height * self.patch_size,
|
438 |
+
width * self.patch_size,
|
439 |
+
)
|
440 |
+
)
|
441 |
+
self.print_idx += 1
|
442 |
+
if not return_dict:
|
443 |
+
return (output,)
|
444 |
+
|
445 |
+
return Transformer2DModelOutput(sample=output)
|
musev/models/unet_2d_blocks.py
ADDED
@@ -0,0 +1,1537 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Any, Dict, Literal, Optional, Tuple, Union, List
|
15 |
+
|
16 |
+
import numpy as np
|
17 |
+
import torch
|
18 |
+
import torch.nn.functional as F
|
19 |
+
from torch import nn
|
20 |
+
|
21 |
+
from diffusers.utils import is_torch_version, logging
|
22 |
+
from diffusers.utils.torch_utils import apply_freeu
|
23 |
+
from diffusers.models.activations import get_activation
|
24 |
+
from diffusers.models.attention_processor import (
|
25 |
+
Attention,
|
26 |
+
AttnAddedKVProcessor,
|
27 |
+
AttnAddedKVProcessor2_0,
|
28 |
+
)
|
29 |
+
from diffusers.models.dual_transformer_2d import DualTransformer2DModel
|
30 |
+
from diffusers.models.normalization import AdaGroupNorm
|
31 |
+
from diffusers.models.resnet import (
|
32 |
+
Downsample2D,
|
33 |
+
FirDownsample2D,
|
34 |
+
FirUpsample2D,
|
35 |
+
KDownsample2D,
|
36 |
+
KUpsample2D,
|
37 |
+
ResnetBlock2D,
|
38 |
+
Upsample2D,
|
39 |
+
)
|
40 |
+
from diffusers.models.unet_2d_blocks import (
|
41 |
+
AttnDownBlock2D,
|
42 |
+
AttnDownEncoderBlock2D,
|
43 |
+
AttnSkipDownBlock2D,
|
44 |
+
AttnSkipUpBlock2D,
|
45 |
+
AttnUpBlock2D,
|
46 |
+
AttnUpDecoderBlock2D,
|
47 |
+
DownEncoderBlock2D,
|
48 |
+
KCrossAttnDownBlock2D,
|
49 |
+
KCrossAttnUpBlock2D,
|
50 |
+
KDownBlock2D,
|
51 |
+
KUpBlock2D,
|
52 |
+
ResnetDownsampleBlock2D,
|
53 |
+
ResnetUpsampleBlock2D,
|
54 |
+
SimpleCrossAttnDownBlock2D,
|
55 |
+
SimpleCrossAttnUpBlock2D,
|
56 |
+
SkipDownBlock2D,
|
57 |
+
SkipUpBlock2D,
|
58 |
+
UpDecoderBlock2D,
|
59 |
+
)
|
60 |
+
|
61 |
+
from .transformer_2d import Transformer2DModel
|
62 |
+
|
63 |
+
|
64 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
65 |
+
|
66 |
+
|
67 |
+
def get_down_block(
|
68 |
+
down_block_type: str,
|
69 |
+
num_layers: int,
|
70 |
+
in_channels: int,
|
71 |
+
out_channels: int,
|
72 |
+
temb_channels: int,
|
73 |
+
add_downsample: bool,
|
74 |
+
resnet_eps: float,
|
75 |
+
resnet_act_fn: str,
|
76 |
+
transformer_layers_per_block: int = 1,
|
77 |
+
num_attention_heads: Optional[int] = None,
|
78 |
+
resnet_groups: Optional[int] = None,
|
79 |
+
cross_attention_dim: Optional[int] = None,
|
80 |
+
downsample_padding: Optional[int] = None,
|
81 |
+
dual_cross_attention: bool = False,
|
82 |
+
use_linear_projection: bool = False,
|
83 |
+
only_cross_attention: bool = False,
|
84 |
+
upcast_attention: bool = False,
|
85 |
+
resnet_time_scale_shift: str = "default",
|
86 |
+
attention_type: str = "default",
|
87 |
+
resnet_skip_time_act: bool = False,
|
88 |
+
resnet_out_scale_factor: float = 1.0,
|
89 |
+
cross_attention_norm: Optional[str] = None,
|
90 |
+
attention_head_dim: Optional[int] = None,
|
91 |
+
downsample_type: Optional[str] = None,
|
92 |
+
dropout: float = 0.0,
|
93 |
+
):
|
94 |
+
# If attn head dim is not defined, we default it to the number of heads
|
95 |
+
if attention_head_dim is None:
|
96 |
+
logger.warn(
|
97 |
+
f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
|
98 |
+
)
|
99 |
+
attention_head_dim = num_attention_heads
|
100 |
+
|
101 |
+
down_block_type = (
|
102 |
+
down_block_type[7:]
|
103 |
+
if down_block_type.startswith("UNetRes")
|
104 |
+
else down_block_type
|
105 |
+
)
|
106 |
+
if down_block_type == "DownBlock2D":
|
107 |
+
return DownBlock2D(
|
108 |
+
num_layers=num_layers,
|
109 |
+
in_channels=in_channels,
|
110 |
+
out_channels=out_channels,
|
111 |
+
temb_channels=temb_channels,
|
112 |
+
dropout=dropout,
|
113 |
+
add_downsample=add_downsample,
|
114 |
+
resnet_eps=resnet_eps,
|
115 |
+
resnet_act_fn=resnet_act_fn,
|
116 |
+
resnet_groups=resnet_groups,
|
117 |
+
downsample_padding=downsample_padding,
|
118 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
119 |
+
)
|
120 |
+
elif down_block_type == "ResnetDownsampleBlock2D":
|
121 |
+
return ResnetDownsampleBlock2D(
|
122 |
+
num_layers=num_layers,
|
123 |
+
in_channels=in_channels,
|
124 |
+
out_channels=out_channels,
|
125 |
+
temb_channels=temb_channels,
|
126 |
+
dropout=dropout,
|
127 |
+
add_downsample=add_downsample,
|
128 |
+
resnet_eps=resnet_eps,
|
129 |
+
resnet_act_fn=resnet_act_fn,
|
130 |
+
resnet_groups=resnet_groups,
|
131 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
132 |
+
skip_time_act=resnet_skip_time_act,
|
133 |
+
output_scale_factor=resnet_out_scale_factor,
|
134 |
+
)
|
135 |
+
elif down_block_type == "AttnDownBlock2D":
|
136 |
+
if add_downsample is False:
|
137 |
+
downsample_type = None
|
138 |
+
else:
|
139 |
+
downsample_type = downsample_type or "conv" # default to 'conv'
|
140 |
+
return AttnDownBlock2D(
|
141 |
+
num_layers=num_layers,
|
142 |
+
in_channels=in_channels,
|
143 |
+
out_channels=out_channels,
|
144 |
+
temb_channels=temb_channels,
|
145 |
+
dropout=dropout,
|
146 |
+
resnet_eps=resnet_eps,
|
147 |
+
resnet_act_fn=resnet_act_fn,
|
148 |
+
resnet_groups=resnet_groups,
|
149 |
+
downsample_padding=downsample_padding,
|
150 |
+
attention_head_dim=attention_head_dim,
|
151 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
152 |
+
downsample_type=downsample_type,
|
153 |
+
)
|
154 |
+
elif down_block_type == "CrossAttnDownBlock2D":
|
155 |
+
if cross_attention_dim is None:
|
156 |
+
raise ValueError(
|
157 |
+
"cross_attention_dim must be specified for CrossAttnDownBlock2D"
|
158 |
+
)
|
159 |
+
return CrossAttnDownBlock2D(
|
160 |
+
num_layers=num_layers,
|
161 |
+
transformer_layers_per_block=transformer_layers_per_block,
|
162 |
+
in_channels=in_channels,
|
163 |
+
out_channels=out_channels,
|
164 |
+
temb_channels=temb_channels,
|
165 |
+
dropout=dropout,
|
166 |
+
add_downsample=add_downsample,
|
167 |
+
resnet_eps=resnet_eps,
|
168 |
+
resnet_act_fn=resnet_act_fn,
|
169 |
+
resnet_groups=resnet_groups,
|
170 |
+
downsample_padding=downsample_padding,
|
171 |
+
cross_attention_dim=cross_attention_dim,
|
172 |
+
num_attention_heads=num_attention_heads,
|
173 |
+
dual_cross_attention=dual_cross_attention,
|
174 |
+
use_linear_projection=use_linear_projection,
|
175 |
+
only_cross_attention=only_cross_attention,
|
176 |
+
upcast_attention=upcast_attention,
|
177 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
178 |
+
attention_type=attention_type,
|
179 |
+
)
|
180 |
+
elif down_block_type == "SimpleCrossAttnDownBlock2D":
|
181 |
+
if cross_attention_dim is None:
|
182 |
+
raise ValueError(
|
183 |
+
"cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D"
|
184 |
+
)
|
185 |
+
return SimpleCrossAttnDownBlock2D(
|
186 |
+
num_layers=num_layers,
|
187 |
+
in_channels=in_channels,
|
188 |
+
out_channels=out_channels,
|
189 |
+
temb_channels=temb_channels,
|
190 |
+
dropout=dropout,
|
191 |
+
add_downsample=add_downsample,
|
192 |
+
resnet_eps=resnet_eps,
|
193 |
+
resnet_act_fn=resnet_act_fn,
|
194 |
+
resnet_groups=resnet_groups,
|
195 |
+
cross_attention_dim=cross_attention_dim,
|
196 |
+
attention_head_dim=attention_head_dim,
|
197 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
198 |
+
skip_time_act=resnet_skip_time_act,
|
199 |
+
output_scale_factor=resnet_out_scale_factor,
|
200 |
+
only_cross_attention=only_cross_attention,
|
201 |
+
cross_attention_norm=cross_attention_norm,
|
202 |
+
)
|
203 |
+
elif down_block_type == "SkipDownBlock2D":
|
204 |
+
return SkipDownBlock2D(
|
205 |
+
num_layers=num_layers,
|
206 |
+
in_channels=in_channels,
|
207 |
+
out_channels=out_channels,
|
208 |
+
temb_channels=temb_channels,
|
209 |
+
dropout=dropout,
|
210 |
+
add_downsample=add_downsample,
|
211 |
+
resnet_eps=resnet_eps,
|
212 |
+
resnet_act_fn=resnet_act_fn,
|
213 |
+
downsample_padding=downsample_padding,
|
214 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
215 |
+
)
|
216 |
+
elif down_block_type == "AttnSkipDownBlock2D":
|
217 |
+
return AttnSkipDownBlock2D(
|
218 |
+
num_layers=num_layers,
|
219 |
+
in_channels=in_channels,
|
220 |
+
out_channels=out_channels,
|
221 |
+
temb_channels=temb_channels,
|
222 |
+
dropout=dropout,
|
223 |
+
add_downsample=add_downsample,
|
224 |
+
resnet_eps=resnet_eps,
|
225 |
+
resnet_act_fn=resnet_act_fn,
|
226 |
+
attention_head_dim=attention_head_dim,
|
227 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
228 |
+
)
|
229 |
+
elif down_block_type == "DownEncoderBlock2D":
|
230 |
+
return DownEncoderBlock2D(
|
231 |
+
num_layers=num_layers,
|
232 |
+
in_channels=in_channels,
|
233 |
+
out_channels=out_channels,
|
234 |
+
dropout=dropout,
|
235 |
+
add_downsample=add_downsample,
|
236 |
+
resnet_eps=resnet_eps,
|
237 |
+
resnet_act_fn=resnet_act_fn,
|
238 |
+
resnet_groups=resnet_groups,
|
239 |
+
downsample_padding=downsample_padding,
|
240 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
241 |
+
)
|
242 |
+
elif down_block_type == "AttnDownEncoderBlock2D":
|
243 |
+
return AttnDownEncoderBlock2D(
|
244 |
+
num_layers=num_layers,
|
245 |
+
in_channels=in_channels,
|
246 |
+
out_channels=out_channels,
|
247 |
+
dropout=dropout,
|
248 |
+
add_downsample=add_downsample,
|
249 |
+
resnet_eps=resnet_eps,
|
250 |
+
resnet_act_fn=resnet_act_fn,
|
251 |
+
resnet_groups=resnet_groups,
|
252 |
+
downsample_padding=downsample_padding,
|
253 |
+
attention_head_dim=attention_head_dim,
|
254 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
255 |
+
)
|
256 |
+
elif down_block_type == "KDownBlock2D":
|
257 |
+
return KDownBlock2D(
|
258 |
+
num_layers=num_layers,
|
259 |
+
in_channels=in_channels,
|
260 |
+
out_channels=out_channels,
|
261 |
+
temb_channels=temb_channels,
|
262 |
+
dropout=dropout,
|
263 |
+
add_downsample=add_downsample,
|
264 |
+
resnet_eps=resnet_eps,
|
265 |
+
resnet_act_fn=resnet_act_fn,
|
266 |
+
)
|
267 |
+
elif down_block_type == "KCrossAttnDownBlock2D":
|
268 |
+
return KCrossAttnDownBlock2D(
|
269 |
+
num_layers=num_layers,
|
270 |
+
in_channels=in_channels,
|
271 |
+
out_channels=out_channels,
|
272 |
+
temb_channels=temb_channels,
|
273 |
+
dropout=dropout,
|
274 |
+
add_downsample=add_downsample,
|
275 |
+
resnet_eps=resnet_eps,
|
276 |
+
resnet_act_fn=resnet_act_fn,
|
277 |
+
cross_attention_dim=cross_attention_dim,
|
278 |
+
attention_head_dim=attention_head_dim,
|
279 |
+
add_self_attention=True if not add_downsample else False,
|
280 |
+
)
|
281 |
+
raise ValueError(f"{down_block_type} does not exist.")
|
282 |
+
|
283 |
+
|
284 |
+
def get_up_block(
|
285 |
+
up_block_type: str,
|
286 |
+
num_layers: int,
|
287 |
+
in_channels: int,
|
288 |
+
out_channels: int,
|
289 |
+
prev_output_channel: int,
|
290 |
+
temb_channels: int,
|
291 |
+
add_upsample: bool,
|
292 |
+
resnet_eps: float,
|
293 |
+
resnet_act_fn: str,
|
294 |
+
resolution_idx: Optional[int] = None,
|
295 |
+
transformer_layers_per_block: int = 1,
|
296 |
+
num_attention_heads: Optional[int] = None,
|
297 |
+
resnet_groups: Optional[int] = None,
|
298 |
+
cross_attention_dim: Optional[int] = None,
|
299 |
+
dual_cross_attention: bool = False,
|
300 |
+
use_linear_projection: bool = False,
|
301 |
+
only_cross_attention: bool = False,
|
302 |
+
upcast_attention: bool = False,
|
303 |
+
resnet_time_scale_shift: str = "default",
|
304 |
+
attention_type: str = "default",
|
305 |
+
resnet_skip_time_act: bool = False,
|
306 |
+
resnet_out_scale_factor: float = 1.0,
|
307 |
+
cross_attention_norm: Optional[str] = None,
|
308 |
+
attention_head_dim: Optional[int] = None,
|
309 |
+
upsample_type: Optional[str] = None,
|
310 |
+
dropout: float = 0.0,
|
311 |
+
) -> nn.Module:
|
312 |
+
# If attn head dim is not defined, we default it to the number of heads
|
313 |
+
if attention_head_dim is None:
|
314 |
+
logger.warn(
|
315 |
+
f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
|
316 |
+
)
|
317 |
+
attention_head_dim = num_attention_heads
|
318 |
+
|
319 |
+
up_block_type = (
|
320 |
+
up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
|
321 |
+
)
|
322 |
+
if up_block_type == "UpBlock2D":
|
323 |
+
return UpBlock2D(
|
324 |
+
num_layers=num_layers,
|
325 |
+
in_channels=in_channels,
|
326 |
+
out_channels=out_channels,
|
327 |
+
prev_output_channel=prev_output_channel,
|
328 |
+
temb_channels=temb_channels,
|
329 |
+
resolution_idx=resolution_idx,
|
330 |
+
dropout=dropout,
|
331 |
+
add_upsample=add_upsample,
|
332 |
+
resnet_eps=resnet_eps,
|
333 |
+
resnet_act_fn=resnet_act_fn,
|
334 |
+
resnet_groups=resnet_groups,
|
335 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
336 |
+
)
|
337 |
+
elif up_block_type == "ResnetUpsampleBlock2D":
|
338 |
+
return ResnetUpsampleBlock2D(
|
339 |
+
num_layers=num_layers,
|
340 |
+
in_channels=in_channels,
|
341 |
+
out_channels=out_channels,
|
342 |
+
prev_output_channel=prev_output_channel,
|
343 |
+
temb_channels=temb_channels,
|
344 |
+
resolution_idx=resolution_idx,
|
345 |
+
dropout=dropout,
|
346 |
+
add_upsample=add_upsample,
|
347 |
+
resnet_eps=resnet_eps,
|
348 |
+
resnet_act_fn=resnet_act_fn,
|
349 |
+
resnet_groups=resnet_groups,
|
350 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
351 |
+
skip_time_act=resnet_skip_time_act,
|
352 |
+
output_scale_factor=resnet_out_scale_factor,
|
353 |
+
)
|
354 |
+
elif up_block_type == "CrossAttnUpBlock2D":
|
355 |
+
if cross_attention_dim is None:
|
356 |
+
raise ValueError(
|
357 |
+
"cross_attention_dim must be specified for CrossAttnUpBlock2D"
|
358 |
+
)
|
359 |
+
return CrossAttnUpBlock2D(
|
360 |
+
num_layers=num_layers,
|
361 |
+
transformer_layers_per_block=transformer_layers_per_block,
|
362 |
+
in_channels=in_channels,
|
363 |
+
out_channels=out_channels,
|
364 |
+
prev_output_channel=prev_output_channel,
|
365 |
+
temb_channels=temb_channels,
|
366 |
+
resolution_idx=resolution_idx,
|
367 |
+
dropout=dropout,
|
368 |
+
add_upsample=add_upsample,
|
369 |
+
resnet_eps=resnet_eps,
|
370 |
+
resnet_act_fn=resnet_act_fn,
|
371 |
+
resnet_groups=resnet_groups,
|
372 |
+
cross_attention_dim=cross_attention_dim,
|
373 |
+
num_attention_heads=num_attention_heads,
|
374 |
+
dual_cross_attention=dual_cross_attention,
|
375 |
+
use_linear_projection=use_linear_projection,
|
376 |
+
only_cross_attention=only_cross_attention,
|
377 |
+
upcast_attention=upcast_attention,
|
378 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
379 |
+
attention_type=attention_type,
|
380 |
+
)
|
381 |
+
elif up_block_type == "SimpleCrossAttnUpBlock2D":
|
382 |
+
if cross_attention_dim is None:
|
383 |
+
raise ValueError(
|
384 |
+
"cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D"
|
385 |
+
)
|
386 |
+
return SimpleCrossAttnUpBlock2D(
|
387 |
+
num_layers=num_layers,
|
388 |
+
in_channels=in_channels,
|
389 |
+
out_channels=out_channels,
|
390 |
+
prev_output_channel=prev_output_channel,
|
391 |
+
temb_channels=temb_channels,
|
392 |
+
resolution_idx=resolution_idx,
|
393 |
+
dropout=dropout,
|
394 |
+
add_upsample=add_upsample,
|
395 |
+
resnet_eps=resnet_eps,
|
396 |
+
resnet_act_fn=resnet_act_fn,
|
397 |
+
resnet_groups=resnet_groups,
|
398 |
+
cross_attention_dim=cross_attention_dim,
|
399 |
+
attention_head_dim=attention_head_dim,
|
400 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
401 |
+
skip_time_act=resnet_skip_time_act,
|
402 |
+
output_scale_factor=resnet_out_scale_factor,
|
403 |
+
only_cross_attention=only_cross_attention,
|
404 |
+
cross_attention_norm=cross_attention_norm,
|
405 |
+
)
|
406 |
+
elif up_block_type == "AttnUpBlock2D":
|
407 |
+
if add_upsample is False:
|
408 |
+
upsample_type = None
|
409 |
+
else:
|
410 |
+
upsample_type = upsample_type or "conv" # default to 'conv'
|
411 |
+
|
412 |
+
return AttnUpBlock2D(
|
413 |
+
num_layers=num_layers,
|
414 |
+
in_channels=in_channels,
|
415 |
+
out_channels=out_channels,
|
416 |
+
prev_output_channel=prev_output_channel,
|
417 |
+
temb_channels=temb_channels,
|
418 |
+
resolution_idx=resolution_idx,
|
419 |
+
dropout=dropout,
|
420 |
+
resnet_eps=resnet_eps,
|
421 |
+
resnet_act_fn=resnet_act_fn,
|
422 |
+
resnet_groups=resnet_groups,
|
423 |
+
attention_head_dim=attention_head_dim,
|
424 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
425 |
+
upsample_type=upsample_type,
|
426 |
+
)
|
427 |
+
elif up_block_type == "SkipUpBlock2D":
|
428 |
+
return SkipUpBlock2D(
|
429 |
+
num_layers=num_layers,
|
430 |
+
in_channels=in_channels,
|
431 |
+
out_channels=out_channels,
|
432 |
+
prev_output_channel=prev_output_channel,
|
433 |
+
temb_channels=temb_channels,
|
434 |
+
resolution_idx=resolution_idx,
|
435 |
+
dropout=dropout,
|
436 |
+
add_upsample=add_upsample,
|
437 |
+
resnet_eps=resnet_eps,
|
438 |
+
resnet_act_fn=resnet_act_fn,
|
439 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
440 |
+
)
|
441 |
+
elif up_block_type == "AttnSkipUpBlock2D":
|
442 |
+
return AttnSkipUpBlock2D(
|
443 |
+
num_layers=num_layers,
|
444 |
+
in_channels=in_channels,
|
445 |
+
out_channels=out_channels,
|
446 |
+
prev_output_channel=prev_output_channel,
|
447 |
+
temb_channels=temb_channels,
|
448 |
+
resolution_idx=resolution_idx,
|
449 |
+
dropout=dropout,
|
450 |
+
add_upsample=add_upsample,
|
451 |
+
resnet_eps=resnet_eps,
|
452 |
+
resnet_act_fn=resnet_act_fn,
|
453 |
+
attention_head_dim=attention_head_dim,
|
454 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
455 |
+
)
|
456 |
+
elif up_block_type == "UpDecoderBlock2D":
|
457 |
+
return UpDecoderBlock2D(
|
458 |
+
num_layers=num_layers,
|
459 |
+
in_channels=in_channels,
|
460 |
+
out_channels=out_channels,
|
461 |
+
resolution_idx=resolution_idx,
|
462 |
+
dropout=dropout,
|
463 |
+
add_upsample=add_upsample,
|
464 |
+
resnet_eps=resnet_eps,
|
465 |
+
resnet_act_fn=resnet_act_fn,
|
466 |
+
resnet_groups=resnet_groups,
|
467 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
468 |
+
temb_channels=temb_channels,
|
469 |
+
)
|
470 |
+
elif up_block_type == "AttnUpDecoderBlock2D":
|
471 |
+
return AttnUpDecoderBlock2D(
|
472 |
+
num_layers=num_layers,
|
473 |
+
in_channels=in_channels,
|
474 |
+
out_channels=out_channels,
|
475 |
+
resolution_idx=resolution_idx,
|
476 |
+
dropout=dropout,
|
477 |
+
add_upsample=add_upsample,
|
478 |
+
resnet_eps=resnet_eps,
|
479 |
+
resnet_act_fn=resnet_act_fn,
|
480 |
+
resnet_groups=resnet_groups,
|
481 |
+
attention_head_dim=attention_head_dim,
|
482 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
483 |
+
temb_channels=temb_channels,
|
484 |
+
)
|
485 |
+
elif up_block_type == "KUpBlock2D":
|
486 |
+
return KUpBlock2D(
|
487 |
+
num_layers=num_layers,
|
488 |
+
in_channels=in_channels,
|
489 |
+
out_channels=out_channels,
|
490 |
+
temb_channels=temb_channels,
|
491 |
+
resolution_idx=resolution_idx,
|
492 |
+
dropout=dropout,
|
493 |
+
add_upsample=add_upsample,
|
494 |
+
resnet_eps=resnet_eps,
|
495 |
+
resnet_act_fn=resnet_act_fn,
|
496 |
+
)
|
497 |
+
elif up_block_type == "KCrossAttnUpBlock2D":
|
498 |
+
return KCrossAttnUpBlock2D(
|
499 |
+
num_layers=num_layers,
|
500 |
+
in_channels=in_channels,
|
501 |
+
out_channels=out_channels,
|
502 |
+
temb_channels=temb_channels,
|
503 |
+
resolution_idx=resolution_idx,
|
504 |
+
dropout=dropout,
|
505 |
+
add_upsample=add_upsample,
|
506 |
+
resnet_eps=resnet_eps,
|
507 |
+
resnet_act_fn=resnet_act_fn,
|
508 |
+
cross_attention_dim=cross_attention_dim,
|
509 |
+
attention_head_dim=attention_head_dim,
|
510 |
+
)
|
511 |
+
|
512 |
+
raise ValueError(f"{up_block_type} does not exist.")
|
513 |
+
|
514 |
+
|
515 |
+
class UNetMidBlock2D(nn.Module):
|
516 |
+
"""
|
517 |
+
A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks.
|
518 |
+
|
519 |
+
Args:
|
520 |
+
in_channels (`int`): The number of input channels.
|
521 |
+
temb_channels (`int`): The number of temporal embedding channels.
|
522 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
|
523 |
+
num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
|
524 |
+
resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
|
525 |
+
resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
|
526 |
+
The type of normalization to apply to the time embeddings. This can help to improve the performance of the
|
527 |
+
model on tasks with long-range temporal dependencies.
|
528 |
+
resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
|
529 |
+
resnet_groups (`int`, *optional*, defaults to 32):
|
530 |
+
The number of groups to use in the group normalization layers of the resnet blocks.
|
531 |
+
attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
|
532 |
+
resnet_pre_norm (`bool`, *optional*, defaults to `True`):
|
533 |
+
Whether to use pre-normalization for the resnet blocks.
|
534 |
+
add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
|
535 |
+
attention_head_dim (`int`, *optional*, defaults to 1):
|
536 |
+
Dimension of a single attention head. The number of attention heads is determined based on this value and
|
537 |
+
the number of input channels.
|
538 |
+
output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
|
539 |
+
|
540 |
+
Returns:
|
541 |
+
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
|
542 |
+
in_channels, height, width)`.
|
543 |
+
|
544 |
+
"""
|
545 |
+
|
546 |
+
def __init__(
|
547 |
+
self,
|
548 |
+
in_channels: int,
|
549 |
+
temb_channels: int,
|
550 |
+
dropout: float = 0.0,
|
551 |
+
num_layers: int = 1,
|
552 |
+
resnet_eps: float = 1e-6,
|
553 |
+
resnet_time_scale_shift: str = "default", # default, spatial
|
554 |
+
resnet_act_fn: str = "swish",
|
555 |
+
resnet_groups: int = 32,
|
556 |
+
attn_groups: Optional[int] = None,
|
557 |
+
resnet_pre_norm: bool = True,
|
558 |
+
add_attention: bool = True,
|
559 |
+
attention_head_dim: int = 1,
|
560 |
+
output_scale_factor: float = 1.0,
|
561 |
+
):
|
562 |
+
super().__init__()
|
563 |
+
resnet_groups = (
|
564 |
+
resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
565 |
+
)
|
566 |
+
self.add_attention = add_attention
|
567 |
+
|
568 |
+
if attn_groups is None:
|
569 |
+
attn_groups = (
|
570 |
+
resnet_groups if resnet_time_scale_shift == "default" else None
|
571 |
+
)
|
572 |
+
|
573 |
+
# there is always at least one resnet
|
574 |
+
resnets = [
|
575 |
+
ResnetBlock2D(
|
576 |
+
in_channels=in_channels,
|
577 |
+
out_channels=in_channels,
|
578 |
+
temb_channels=temb_channels,
|
579 |
+
eps=resnet_eps,
|
580 |
+
groups=resnet_groups,
|
581 |
+
dropout=dropout,
|
582 |
+
time_embedding_norm=resnet_time_scale_shift,
|
583 |
+
non_linearity=resnet_act_fn,
|
584 |
+
output_scale_factor=output_scale_factor,
|
585 |
+
pre_norm=resnet_pre_norm,
|
586 |
+
)
|
587 |
+
]
|
588 |
+
attentions = []
|
589 |
+
|
590 |
+
if attention_head_dim is None:
|
591 |
+
logger.warn(
|
592 |
+
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
|
593 |
+
)
|
594 |
+
attention_head_dim = in_channels
|
595 |
+
|
596 |
+
for _ in range(num_layers):
|
597 |
+
if self.add_attention:
|
598 |
+
attentions.append(
|
599 |
+
Attention(
|
600 |
+
in_channels,
|
601 |
+
heads=in_channels // attention_head_dim,
|
602 |
+
dim_head=attention_head_dim,
|
603 |
+
rescale_output_factor=output_scale_factor,
|
604 |
+
eps=resnet_eps,
|
605 |
+
norm_num_groups=attn_groups,
|
606 |
+
spatial_norm_dim=temb_channels
|
607 |
+
if resnet_time_scale_shift == "spatial"
|
608 |
+
else None,
|
609 |
+
residual_connection=True,
|
610 |
+
bias=True,
|
611 |
+
upcast_softmax=True,
|
612 |
+
_from_deprecated_attn_block=True,
|
613 |
+
)
|
614 |
+
)
|
615 |
+
else:
|
616 |
+
attentions.append(None)
|
617 |
+
|
618 |
+
resnets.append(
|
619 |
+
ResnetBlock2D(
|
620 |
+
in_channels=in_channels,
|
621 |
+
out_channels=in_channels,
|
622 |
+
temb_channels=temb_channels,
|
623 |
+
eps=resnet_eps,
|
624 |
+
groups=resnet_groups,
|
625 |
+
dropout=dropout,
|
626 |
+
time_embedding_norm=resnet_time_scale_shift,
|
627 |
+
non_linearity=resnet_act_fn,
|
628 |
+
output_scale_factor=output_scale_factor,
|
629 |
+
pre_norm=resnet_pre_norm,
|
630 |
+
)
|
631 |
+
)
|
632 |
+
|
633 |
+
self.attentions = nn.ModuleList(attentions)
|
634 |
+
self.resnets = nn.ModuleList(resnets)
|
635 |
+
|
636 |
+
def forward(
|
637 |
+
self,
|
638 |
+
hidden_states: torch.FloatTensor,
|
639 |
+
temb: Optional[torch.FloatTensor] = None,
|
640 |
+
self_attn_block_embs: Optional[List[torch.Tensor]] = None,
|
641 |
+
self_attn_block_embs_mode: Literal["read", "write"] = "write",
|
642 |
+
) -> torch.FloatTensor:
|
643 |
+
hidden_states = self.resnets[0](hidden_states, temb)
|
644 |
+
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
645 |
+
if attn is not None:
|
646 |
+
hidden_states = attn(
|
647 |
+
hidden_states,
|
648 |
+
temb=temb,
|
649 |
+
self_attn_block_embs=self_attn_block_embs,
|
650 |
+
self_attn_block_embs_mode=self_attn_block_embs_mode,
|
651 |
+
)
|
652 |
+
hidden_states = resnet(hidden_states, temb)
|
653 |
+
|
654 |
+
return hidden_states
|
655 |
+
|
656 |
+
|
657 |
+
class UNetMidBlock2DCrossAttn(nn.Module):
|
658 |
+
def __init__(
|
659 |
+
self,
|
660 |
+
in_channels: int,
|
661 |
+
temb_channels: int,
|
662 |
+
dropout: float = 0.0,
|
663 |
+
num_layers: int = 1,
|
664 |
+
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
665 |
+
resnet_eps: float = 1e-6,
|
666 |
+
resnet_time_scale_shift: str = "default",
|
667 |
+
resnet_act_fn: str = "swish",
|
668 |
+
resnet_groups: int = 32,
|
669 |
+
resnet_pre_norm: bool = True,
|
670 |
+
num_attention_heads: int = 1,
|
671 |
+
output_scale_factor: float = 1.0,
|
672 |
+
cross_attention_dim: int = 1280,
|
673 |
+
dual_cross_attention: bool = False,
|
674 |
+
use_linear_projection: bool = False,
|
675 |
+
upcast_attention: bool = False,
|
676 |
+
attention_type: str = "default",
|
677 |
+
):
|
678 |
+
super().__init__()
|
679 |
+
|
680 |
+
self.has_cross_attention = True
|
681 |
+
self.num_attention_heads = num_attention_heads
|
682 |
+
resnet_groups = (
|
683 |
+
resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
684 |
+
)
|
685 |
+
|
686 |
+
# support for variable transformer layers per block
|
687 |
+
if isinstance(transformer_layers_per_block, int):
|
688 |
+
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
|
689 |
+
|
690 |
+
# there is always at least one resnet
|
691 |
+
resnets = [
|
692 |
+
ResnetBlock2D(
|
693 |
+
in_channels=in_channels,
|
694 |
+
out_channels=in_channels,
|
695 |
+
temb_channels=temb_channels,
|
696 |
+
eps=resnet_eps,
|
697 |
+
groups=resnet_groups,
|
698 |
+
dropout=dropout,
|
699 |
+
time_embedding_norm=resnet_time_scale_shift,
|
700 |
+
non_linearity=resnet_act_fn,
|
701 |
+
output_scale_factor=output_scale_factor,
|
702 |
+
pre_norm=resnet_pre_norm,
|
703 |
+
)
|
704 |
+
]
|
705 |
+
attentions = []
|
706 |
+
|
707 |
+
for i in range(num_layers):
|
708 |
+
if not dual_cross_attention:
|
709 |
+
attentions.append(
|
710 |
+
Transformer2DModel(
|
711 |
+
num_attention_heads,
|
712 |
+
in_channels // num_attention_heads,
|
713 |
+
in_channels=in_channels,
|
714 |
+
num_layers=transformer_layers_per_block[i],
|
715 |
+
cross_attention_dim=cross_attention_dim,
|
716 |
+
norm_num_groups=resnet_groups,
|
717 |
+
use_linear_projection=use_linear_projection,
|
718 |
+
upcast_attention=upcast_attention,
|
719 |
+
attention_type=attention_type,
|
720 |
+
)
|
721 |
+
)
|
722 |
+
else:
|
723 |
+
attentions.append(
|
724 |
+
DualTransformer2DModel(
|
725 |
+
num_attention_heads,
|
726 |
+
in_channels // num_attention_heads,
|
727 |
+
in_channels=in_channels,
|
728 |
+
num_layers=1,
|
729 |
+
cross_attention_dim=cross_attention_dim,
|
730 |
+
norm_num_groups=resnet_groups,
|
731 |
+
)
|
732 |
+
)
|
733 |
+
resnets.append(
|
734 |
+
ResnetBlock2D(
|
735 |
+
in_channels=in_channels,
|
736 |
+
out_channels=in_channels,
|
737 |
+
temb_channels=temb_channels,
|
738 |
+
eps=resnet_eps,
|
739 |
+
groups=resnet_groups,
|
740 |
+
dropout=dropout,
|
741 |
+
time_embedding_norm=resnet_time_scale_shift,
|
742 |
+
non_linearity=resnet_act_fn,
|
743 |
+
output_scale_factor=output_scale_factor,
|
744 |
+
pre_norm=resnet_pre_norm,
|
745 |
+
)
|
746 |
+
)
|
747 |
+
|
748 |
+
self.attentions = nn.ModuleList(attentions)
|
749 |
+
self.resnets = nn.ModuleList(resnets)
|
750 |
+
|
751 |
+
self.gradient_checkpointing = False
|
752 |
+
|
753 |
+
def forward(
|
754 |
+
self,
|
755 |
+
hidden_states: torch.FloatTensor,
|
756 |
+
temb: Optional[torch.FloatTensor] = None,
|
757 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
758 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
759 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
760 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
761 |
+
self_attn_block_embs: Optional[List[torch.Tensor]] = None,
|
762 |
+
self_attn_block_embs_mode: Literal["read", "write"] = "write",
|
763 |
+
) -> torch.FloatTensor:
|
764 |
+
lora_scale = (
|
765 |
+
cross_attention_kwargs.get("scale", 1.0)
|
766 |
+
if cross_attention_kwargs is not None
|
767 |
+
else 1.0
|
768 |
+
)
|
769 |
+
hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
|
770 |
+
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
771 |
+
if self.training and self.gradient_checkpointing:
|
772 |
+
|
773 |
+
def create_custom_forward(module, return_dict=None):
|
774 |
+
def custom_forward(*inputs):
|
775 |
+
if return_dict is not None:
|
776 |
+
return module(*inputs, return_dict=return_dict)
|
777 |
+
else:
|
778 |
+
return module(*inputs)
|
779 |
+
|
780 |
+
return custom_forward
|
781 |
+
|
782 |
+
ckpt_kwargs: Dict[str, Any] = (
|
783 |
+
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
784 |
+
)
|
785 |
+
hidden_states = attn(
|
786 |
+
hidden_states,
|
787 |
+
encoder_hidden_states=encoder_hidden_states,
|
788 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
789 |
+
attention_mask=attention_mask,
|
790 |
+
encoder_attention_mask=encoder_attention_mask,
|
791 |
+
return_dict=False,
|
792 |
+
self_attn_block_embs=self_attn_block_embs,
|
793 |
+
self_attn_block_embs_mode=self_attn_block_embs_mode,
|
794 |
+
)[0]
|
795 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
796 |
+
create_custom_forward(resnet),
|
797 |
+
hidden_states,
|
798 |
+
temb,
|
799 |
+
**ckpt_kwargs,
|
800 |
+
)
|
801 |
+
else:
|
802 |
+
hidden_states = attn(
|
803 |
+
hidden_states,
|
804 |
+
encoder_hidden_states=encoder_hidden_states,
|
805 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
806 |
+
attention_mask=attention_mask,
|
807 |
+
encoder_attention_mask=encoder_attention_mask,
|
808 |
+
return_dict=False,
|
809 |
+
self_attn_block_embs=self_attn_block_embs,
|
810 |
+
)[0]
|
811 |
+
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
812 |
+
|
813 |
+
return hidden_states
|
814 |
+
|
815 |
+
|
816 |
+
class UNetMidBlock2DSimpleCrossAttn(nn.Module):
|
817 |
+
def __init__(
|
818 |
+
self,
|
819 |
+
in_channels: int,
|
820 |
+
temb_channels: int,
|
821 |
+
dropout: float = 0.0,
|
822 |
+
num_layers: int = 1,
|
823 |
+
resnet_eps: float = 1e-6,
|
824 |
+
resnet_time_scale_shift: str = "default",
|
825 |
+
resnet_act_fn: str = "swish",
|
826 |
+
resnet_groups: int = 32,
|
827 |
+
resnet_pre_norm: bool = True,
|
828 |
+
attention_head_dim: int = 1,
|
829 |
+
output_scale_factor: float = 1.0,
|
830 |
+
cross_attention_dim: int = 1280,
|
831 |
+
skip_time_act: bool = False,
|
832 |
+
only_cross_attention: bool = False,
|
833 |
+
cross_attention_norm: Optional[str] = None,
|
834 |
+
):
|
835 |
+
super().__init__()
|
836 |
+
|
837 |
+
self.has_cross_attention = True
|
838 |
+
|
839 |
+
self.attention_head_dim = attention_head_dim
|
840 |
+
resnet_groups = (
|
841 |
+
resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
842 |
+
)
|
843 |
+
|
844 |
+
self.num_heads = in_channels // self.attention_head_dim
|
845 |
+
|
846 |
+
# there is always at least one resnet
|
847 |
+
resnets = [
|
848 |
+
ResnetBlock2D(
|
849 |
+
in_channels=in_channels,
|
850 |
+
out_channels=in_channels,
|
851 |
+
temb_channels=temb_channels,
|
852 |
+
eps=resnet_eps,
|
853 |
+
groups=resnet_groups,
|
854 |
+
dropout=dropout,
|
855 |
+
time_embedding_norm=resnet_time_scale_shift,
|
856 |
+
non_linearity=resnet_act_fn,
|
857 |
+
output_scale_factor=output_scale_factor,
|
858 |
+
pre_norm=resnet_pre_norm,
|
859 |
+
skip_time_act=skip_time_act,
|
860 |
+
)
|
861 |
+
]
|
862 |
+
attentions = []
|
863 |
+
|
864 |
+
for _ in range(num_layers):
|
865 |
+
processor = (
|
866 |
+
AttnAddedKVProcessor2_0()
|
867 |
+
if hasattr(F, "scaled_dot_product_attention")
|
868 |
+
else AttnAddedKVProcessor()
|
869 |
+
)
|
870 |
+
|
871 |
+
attentions.append(
|
872 |
+
Attention(
|
873 |
+
query_dim=in_channels,
|
874 |
+
cross_attention_dim=in_channels,
|
875 |
+
heads=self.num_heads,
|
876 |
+
dim_head=self.attention_head_dim,
|
877 |
+
added_kv_proj_dim=cross_attention_dim,
|
878 |
+
norm_num_groups=resnet_groups,
|
879 |
+
bias=True,
|
880 |
+
upcast_softmax=True,
|
881 |
+
only_cross_attention=only_cross_attention,
|
882 |
+
cross_attention_norm=cross_attention_norm,
|
883 |
+
processor=processor,
|
884 |
+
)
|
885 |
+
)
|
886 |
+
resnets.append(
|
887 |
+
ResnetBlock2D(
|
888 |
+
in_channels=in_channels,
|
889 |
+
out_channels=in_channels,
|
890 |
+
temb_channels=temb_channels,
|
891 |
+
eps=resnet_eps,
|
892 |
+
groups=resnet_groups,
|
893 |
+
dropout=dropout,
|
894 |
+
time_embedding_norm=resnet_time_scale_shift,
|
895 |
+
non_linearity=resnet_act_fn,
|
896 |
+
output_scale_factor=output_scale_factor,
|
897 |
+
pre_norm=resnet_pre_norm,
|
898 |
+
skip_time_act=skip_time_act,
|
899 |
+
)
|
900 |
+
)
|
901 |
+
|
902 |
+
self.attentions = nn.ModuleList(attentions)
|
903 |
+
self.resnets = nn.ModuleList(resnets)
|
904 |
+
|
905 |
+
def forward(
|
906 |
+
self,
|
907 |
+
hidden_states: torch.FloatTensor,
|
908 |
+
temb: Optional[torch.FloatTensor] = None,
|
909 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
910 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
911 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
912 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
913 |
+
self_attn_block_embs: Optional[List[torch.Tensor]] = None,
|
914 |
+
self_attn_block_embs_mode: Literal["read", "write"] = "write",
|
915 |
+
) -> torch.FloatTensor:
|
916 |
+
cross_attention_kwargs = (
|
917 |
+
cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
918 |
+
)
|
919 |
+
lora_scale = cross_attention_kwargs.get("scale", 1.0)
|
920 |
+
|
921 |
+
if attention_mask is None:
|
922 |
+
# if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
|
923 |
+
mask = None if encoder_hidden_states is None else encoder_attention_mask
|
924 |
+
else:
|
925 |
+
# when attention_mask is defined: we don't even check for encoder_attention_mask.
|
926 |
+
# this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
|
927 |
+
# TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
|
928 |
+
# then we can simplify this whole if/else block to:
|
929 |
+
# mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
|
930 |
+
mask = attention_mask
|
931 |
+
|
932 |
+
hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
|
933 |
+
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
934 |
+
# attn
|
935 |
+
hidden_states = attn(
|
936 |
+
hidden_states,
|
937 |
+
encoder_hidden_states=encoder_hidden_states,
|
938 |
+
attention_mask=mask,
|
939 |
+
**cross_attention_kwargs,
|
940 |
+
self_attn_block_embs=self_attn_block_embs,
|
941 |
+
self_attn_block_embs_mode=self_attn_block_embs_mode,
|
942 |
+
)
|
943 |
+
|
944 |
+
# resnet
|
945 |
+
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
946 |
+
|
947 |
+
return hidden_states
|
948 |
+
|
949 |
+
|
950 |
+
class CrossAttnDownBlock2D(nn.Module):
|
951 |
+
print_idx = 0
|
952 |
+
|
953 |
+
def __init__(
|
954 |
+
self,
|
955 |
+
in_channels: int,
|
956 |
+
out_channels: int,
|
957 |
+
temb_channels: int,
|
958 |
+
dropout: float = 0.0,
|
959 |
+
num_layers: int = 1,
|
960 |
+
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
961 |
+
resnet_eps: float = 1e-6,
|
962 |
+
resnet_time_scale_shift: str = "default",
|
963 |
+
resnet_act_fn: str = "swish",
|
964 |
+
resnet_groups: int = 32,
|
965 |
+
resnet_pre_norm: bool = True,
|
966 |
+
num_attention_heads: int = 1,
|
967 |
+
cross_attention_dim: int = 1280,
|
968 |
+
output_scale_factor: float = 1.0,
|
969 |
+
downsample_padding: int = 1,
|
970 |
+
add_downsample: bool = True,
|
971 |
+
dual_cross_attention: bool = False,
|
972 |
+
use_linear_projection: bool = False,
|
973 |
+
only_cross_attention: bool = False,
|
974 |
+
upcast_attention: bool = False,
|
975 |
+
attention_type: str = "default",
|
976 |
+
):
|
977 |
+
super().__init__()
|
978 |
+
resnets = []
|
979 |
+
attentions = []
|
980 |
+
|
981 |
+
self.has_cross_attention = True
|
982 |
+
self.num_attention_heads = num_attention_heads
|
983 |
+
if isinstance(transformer_layers_per_block, int):
|
984 |
+
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
|
985 |
+
|
986 |
+
for i in range(num_layers):
|
987 |
+
in_channels = in_channels if i == 0 else out_channels
|
988 |
+
resnets.append(
|
989 |
+
ResnetBlock2D(
|
990 |
+
in_channels=in_channels,
|
991 |
+
out_channels=out_channels,
|
992 |
+
temb_channels=temb_channels,
|
993 |
+
eps=resnet_eps,
|
994 |
+
groups=resnet_groups,
|
995 |
+
dropout=dropout,
|
996 |
+
time_embedding_norm=resnet_time_scale_shift,
|
997 |
+
non_linearity=resnet_act_fn,
|
998 |
+
output_scale_factor=output_scale_factor,
|
999 |
+
pre_norm=resnet_pre_norm,
|
1000 |
+
)
|
1001 |
+
)
|
1002 |
+
if not dual_cross_attention:
|
1003 |
+
attentions.append(
|
1004 |
+
Transformer2DModel(
|
1005 |
+
num_attention_heads,
|
1006 |
+
out_channels // num_attention_heads,
|
1007 |
+
in_channels=out_channels,
|
1008 |
+
num_layers=transformer_layers_per_block[i],
|
1009 |
+
cross_attention_dim=cross_attention_dim,
|
1010 |
+
norm_num_groups=resnet_groups,
|
1011 |
+
use_linear_projection=use_linear_projection,
|
1012 |
+
only_cross_attention=only_cross_attention,
|
1013 |
+
upcast_attention=upcast_attention,
|
1014 |
+
attention_type=attention_type,
|
1015 |
+
)
|
1016 |
+
)
|
1017 |
+
else:
|
1018 |
+
attentions.append(
|
1019 |
+
DualTransformer2DModel(
|
1020 |
+
num_attention_heads,
|
1021 |
+
out_channels // num_attention_heads,
|
1022 |
+
in_channels=out_channels,
|
1023 |
+
num_layers=1,
|
1024 |
+
cross_attention_dim=cross_attention_dim,
|
1025 |
+
norm_num_groups=resnet_groups,
|
1026 |
+
)
|
1027 |
+
)
|
1028 |
+
self.attentions = nn.ModuleList(attentions)
|
1029 |
+
self.resnets = nn.ModuleList(resnets)
|
1030 |
+
|
1031 |
+
if add_downsample:
|
1032 |
+
self.downsamplers = nn.ModuleList(
|
1033 |
+
[
|
1034 |
+
Downsample2D(
|
1035 |
+
out_channels,
|
1036 |
+
use_conv=True,
|
1037 |
+
out_channels=out_channels,
|
1038 |
+
padding=downsample_padding,
|
1039 |
+
name="op",
|
1040 |
+
)
|
1041 |
+
]
|
1042 |
+
)
|
1043 |
+
else:
|
1044 |
+
self.downsamplers = None
|
1045 |
+
|
1046 |
+
self.gradient_checkpointing = False
|
1047 |
+
|
1048 |
+
def forward(
|
1049 |
+
self,
|
1050 |
+
hidden_states: torch.FloatTensor,
|
1051 |
+
temb: Optional[torch.FloatTensor] = None,
|
1052 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
1053 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1054 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
1055 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
1056 |
+
additional_residuals: Optional[torch.FloatTensor] = None,
|
1057 |
+
self_attn_block_embs: Optional[List[torch.Tensor]] = None,
|
1058 |
+
self_attn_block_embs_mode: Literal["read", "write"] = "write",
|
1059 |
+
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
1060 |
+
output_states = ()
|
1061 |
+
|
1062 |
+
lora_scale = (
|
1063 |
+
cross_attention_kwargs.get("scale", 1.0)
|
1064 |
+
if cross_attention_kwargs is not None
|
1065 |
+
else 1.0
|
1066 |
+
)
|
1067 |
+
|
1068 |
+
blocks = list(zip(self.resnets, self.attentions))
|
1069 |
+
|
1070 |
+
for i, (resnet, attn) in enumerate(blocks):
|
1071 |
+
if self.training and self.gradient_checkpointing:
|
1072 |
+
|
1073 |
+
def create_custom_forward(module, return_dict=None):
|
1074 |
+
def custom_forward(*inputs):
|
1075 |
+
if return_dict is not None:
|
1076 |
+
return module(*inputs, return_dict=return_dict)
|
1077 |
+
else:
|
1078 |
+
return module(*inputs)
|
1079 |
+
|
1080 |
+
return custom_forward
|
1081 |
+
|
1082 |
+
ckpt_kwargs: Dict[str, Any] = (
|
1083 |
+
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
1084 |
+
)
|
1085 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1086 |
+
create_custom_forward(resnet),
|
1087 |
+
hidden_states,
|
1088 |
+
temb,
|
1089 |
+
**ckpt_kwargs,
|
1090 |
+
)
|
1091 |
+
if self.print_idx == 0:
|
1092 |
+
logger.debug(f"unet3d after resnet {hidden_states.mean()}")
|
1093 |
+
|
1094 |
+
hidden_states = attn(
|
1095 |
+
hidden_states,
|
1096 |
+
encoder_hidden_states=encoder_hidden_states,
|
1097 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1098 |
+
attention_mask=attention_mask,
|
1099 |
+
encoder_attention_mask=encoder_attention_mask,
|
1100 |
+
return_dict=False,
|
1101 |
+
self_attn_block_embs=self_attn_block_embs,
|
1102 |
+
self_attn_block_embs_mode=self_attn_block_embs_mode,
|
1103 |
+
)[0]
|
1104 |
+
else:
|
1105 |
+
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
1106 |
+
if self.print_idx == 0:
|
1107 |
+
logger.debug(f"unet3d after resnet {hidden_states.mean()}")
|
1108 |
+
hidden_states = attn(
|
1109 |
+
hidden_states,
|
1110 |
+
encoder_hidden_states=encoder_hidden_states,
|
1111 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1112 |
+
attention_mask=attention_mask,
|
1113 |
+
encoder_attention_mask=encoder_attention_mask,
|
1114 |
+
return_dict=False,
|
1115 |
+
self_attn_block_embs=self_attn_block_embs,
|
1116 |
+
self_attn_block_embs_mode=self_attn_block_embs_mode,
|
1117 |
+
)[0]
|
1118 |
+
|
1119 |
+
# apply additional residuals to the output of the last pair of resnet and attention blocks
|
1120 |
+
if i == len(blocks) - 1 and additional_residuals is not None:
|
1121 |
+
hidden_states = hidden_states + additional_residuals
|
1122 |
+
|
1123 |
+
output_states = output_states + (hidden_states,)
|
1124 |
+
|
1125 |
+
if self.downsamplers is not None:
|
1126 |
+
for downsampler in self.downsamplers:
|
1127 |
+
hidden_states = downsampler(hidden_states, scale=lora_scale)
|
1128 |
+
|
1129 |
+
output_states = output_states + (hidden_states,)
|
1130 |
+
|
1131 |
+
self.print_idx += 1
|
1132 |
+
return hidden_states, output_states
|
1133 |
+
|
1134 |
+
|
1135 |
+
class DownBlock2D(nn.Module):
|
1136 |
+
def __init__(
|
1137 |
+
self,
|
1138 |
+
in_channels: int,
|
1139 |
+
out_channels: int,
|
1140 |
+
temb_channels: int,
|
1141 |
+
dropout: float = 0.0,
|
1142 |
+
num_layers: int = 1,
|
1143 |
+
resnet_eps: float = 1e-6,
|
1144 |
+
resnet_time_scale_shift: str = "default",
|
1145 |
+
resnet_act_fn: str = "swish",
|
1146 |
+
resnet_groups: int = 32,
|
1147 |
+
resnet_pre_norm: bool = True,
|
1148 |
+
output_scale_factor: float = 1.0,
|
1149 |
+
add_downsample: bool = True,
|
1150 |
+
downsample_padding: int = 1,
|
1151 |
+
):
|
1152 |
+
super().__init__()
|
1153 |
+
resnets = []
|
1154 |
+
|
1155 |
+
for i in range(num_layers):
|
1156 |
+
in_channels = in_channels if i == 0 else out_channels
|
1157 |
+
resnets.append(
|
1158 |
+
ResnetBlock2D(
|
1159 |
+
in_channels=in_channels,
|
1160 |
+
out_channels=out_channels,
|
1161 |
+
temb_channels=temb_channels,
|
1162 |
+
eps=resnet_eps,
|
1163 |
+
groups=resnet_groups,
|
1164 |
+
dropout=dropout,
|
1165 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1166 |
+
non_linearity=resnet_act_fn,
|
1167 |
+
output_scale_factor=output_scale_factor,
|
1168 |
+
pre_norm=resnet_pre_norm,
|
1169 |
+
)
|
1170 |
+
)
|
1171 |
+
|
1172 |
+
self.resnets = nn.ModuleList(resnets)
|
1173 |
+
|
1174 |
+
if add_downsample:
|
1175 |
+
self.downsamplers = nn.ModuleList(
|
1176 |
+
[
|
1177 |
+
Downsample2D(
|
1178 |
+
out_channels,
|
1179 |
+
use_conv=True,
|
1180 |
+
out_channels=out_channels,
|
1181 |
+
padding=downsample_padding,
|
1182 |
+
name="op",
|
1183 |
+
)
|
1184 |
+
]
|
1185 |
+
)
|
1186 |
+
else:
|
1187 |
+
self.downsamplers = None
|
1188 |
+
|
1189 |
+
self.gradient_checkpointing = False
|
1190 |
+
|
1191 |
+
def forward(
|
1192 |
+
self,
|
1193 |
+
hidden_states: torch.FloatTensor,
|
1194 |
+
temb: Optional[torch.FloatTensor] = None,
|
1195 |
+
scale: float = 1.0,
|
1196 |
+
self_attn_block_embs: Optional[List[torch.Tensor]] = None,
|
1197 |
+
self_attn_block_embs_mode: Literal["read", "write"] = "write",
|
1198 |
+
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
1199 |
+
output_states = ()
|
1200 |
+
|
1201 |
+
for resnet in self.resnets:
|
1202 |
+
if self.training and self.gradient_checkpointing:
|
1203 |
+
|
1204 |
+
def create_custom_forward(module):
|
1205 |
+
def custom_forward(*inputs):
|
1206 |
+
return module(*inputs)
|
1207 |
+
|
1208 |
+
return custom_forward
|
1209 |
+
|
1210 |
+
if is_torch_version(">=", "1.11.0"):
|
1211 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1212 |
+
create_custom_forward(resnet),
|
1213 |
+
hidden_states,
|
1214 |
+
temb,
|
1215 |
+
use_reentrant=False,
|
1216 |
+
)
|
1217 |
+
else:
|
1218 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1219 |
+
create_custom_forward(resnet), hidden_states, temb
|
1220 |
+
)
|
1221 |
+
else:
|
1222 |
+
hidden_states = resnet(hidden_states, temb, scale=scale)
|
1223 |
+
|
1224 |
+
output_states = output_states + (hidden_states,)
|
1225 |
+
|
1226 |
+
if self.downsamplers is not None:
|
1227 |
+
for downsampler in self.downsamplers:
|
1228 |
+
hidden_states = downsampler(hidden_states, scale=scale)
|
1229 |
+
|
1230 |
+
output_states = output_states + (hidden_states,)
|
1231 |
+
|
1232 |
+
return hidden_states, output_states
|
1233 |
+
|
1234 |
+
|
1235 |
+
class CrossAttnUpBlock2D(nn.Module):
|
1236 |
+
def __init__(
|
1237 |
+
self,
|
1238 |
+
in_channels: int,
|
1239 |
+
out_channels: int,
|
1240 |
+
prev_output_channel: int,
|
1241 |
+
temb_channels: int,
|
1242 |
+
resolution_idx: Optional[int] = None,
|
1243 |
+
dropout: float = 0.0,
|
1244 |
+
num_layers: int = 1,
|
1245 |
+
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
1246 |
+
resnet_eps: float = 1e-6,
|
1247 |
+
resnet_time_scale_shift: str = "default",
|
1248 |
+
resnet_act_fn: str = "swish",
|
1249 |
+
resnet_groups: int = 32,
|
1250 |
+
resnet_pre_norm: bool = True,
|
1251 |
+
num_attention_heads: int = 1,
|
1252 |
+
cross_attention_dim: int = 1280,
|
1253 |
+
output_scale_factor: float = 1.0,
|
1254 |
+
add_upsample: bool = True,
|
1255 |
+
dual_cross_attention: bool = False,
|
1256 |
+
use_linear_projection: bool = False,
|
1257 |
+
only_cross_attention: bool = False,
|
1258 |
+
upcast_attention: bool = False,
|
1259 |
+
attention_type: str = "default",
|
1260 |
+
):
|
1261 |
+
super().__init__()
|
1262 |
+
resnets = []
|
1263 |
+
attentions = []
|
1264 |
+
|
1265 |
+
self.has_cross_attention = True
|
1266 |
+
self.num_attention_heads = num_attention_heads
|
1267 |
+
|
1268 |
+
if isinstance(transformer_layers_per_block, int):
|
1269 |
+
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
|
1270 |
+
|
1271 |
+
for i in range(num_layers):
|
1272 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
1273 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
1274 |
+
|
1275 |
+
resnets.append(
|
1276 |
+
ResnetBlock2D(
|
1277 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
1278 |
+
out_channels=out_channels,
|
1279 |
+
temb_channels=temb_channels,
|
1280 |
+
eps=resnet_eps,
|
1281 |
+
groups=resnet_groups,
|
1282 |
+
dropout=dropout,
|
1283 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1284 |
+
non_linearity=resnet_act_fn,
|
1285 |
+
output_scale_factor=output_scale_factor,
|
1286 |
+
pre_norm=resnet_pre_norm,
|
1287 |
+
)
|
1288 |
+
)
|
1289 |
+
if not dual_cross_attention:
|
1290 |
+
attentions.append(
|
1291 |
+
Transformer2DModel(
|
1292 |
+
num_attention_heads,
|
1293 |
+
out_channels // num_attention_heads,
|
1294 |
+
in_channels=out_channels,
|
1295 |
+
num_layers=transformer_layers_per_block[i],
|
1296 |
+
cross_attention_dim=cross_attention_dim,
|
1297 |
+
norm_num_groups=resnet_groups,
|
1298 |
+
use_linear_projection=use_linear_projection,
|
1299 |
+
only_cross_attention=only_cross_attention,
|
1300 |
+
upcast_attention=upcast_attention,
|
1301 |
+
attention_type=attention_type,
|
1302 |
+
)
|
1303 |
+
)
|
1304 |
+
else:
|
1305 |
+
attentions.append(
|
1306 |
+
DualTransformer2DModel(
|
1307 |
+
num_attention_heads,
|
1308 |
+
out_channels // num_attention_heads,
|
1309 |
+
in_channels=out_channels,
|
1310 |
+
num_layers=1,
|
1311 |
+
cross_attention_dim=cross_attention_dim,
|
1312 |
+
norm_num_groups=resnet_groups,
|
1313 |
+
)
|
1314 |
+
)
|
1315 |
+
self.attentions = nn.ModuleList(attentions)
|
1316 |
+
self.resnets = nn.ModuleList(resnets)
|
1317 |
+
|
1318 |
+
if add_upsample:
|
1319 |
+
self.upsamplers = nn.ModuleList(
|
1320 |
+
[Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
|
1321 |
+
)
|
1322 |
+
else:
|
1323 |
+
self.upsamplers = None
|
1324 |
+
|
1325 |
+
self.gradient_checkpointing = False
|
1326 |
+
self.resolution_idx = resolution_idx
|
1327 |
+
|
1328 |
+
def forward(
|
1329 |
+
self,
|
1330 |
+
hidden_states: torch.FloatTensor,
|
1331 |
+
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
1332 |
+
temb: Optional[torch.FloatTensor] = None,
|
1333 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
1334 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
1335 |
+
upsample_size: Optional[int] = None,
|
1336 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1337 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
1338 |
+
self_attn_block_embs: Optional[List[torch.Tensor]] = None,
|
1339 |
+
self_attn_block_embs_mode: Literal["read", "write"] = "write",
|
1340 |
+
) -> torch.FloatTensor:
|
1341 |
+
lora_scale = (
|
1342 |
+
cross_attention_kwargs.get("scale", 1.0)
|
1343 |
+
if cross_attention_kwargs is not None
|
1344 |
+
else 1.0
|
1345 |
+
)
|
1346 |
+
is_freeu_enabled = (
|
1347 |
+
getattr(self, "s1", None)
|
1348 |
+
and getattr(self, "s2", None)
|
1349 |
+
and getattr(self, "b1", None)
|
1350 |
+
and getattr(self, "b2", None)
|
1351 |
+
)
|
1352 |
+
|
1353 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
1354 |
+
# pop res hidden states
|
1355 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1356 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1357 |
+
|
1358 |
+
# FreeU: Only operate on the first two stages
|
1359 |
+
if is_freeu_enabled:
|
1360 |
+
hidden_states, res_hidden_states = apply_freeu(
|
1361 |
+
self.resolution_idx,
|
1362 |
+
hidden_states,
|
1363 |
+
res_hidden_states,
|
1364 |
+
s1=self.s1,
|
1365 |
+
s2=self.s2,
|
1366 |
+
b1=self.b1,
|
1367 |
+
b2=self.b2,
|
1368 |
+
)
|
1369 |
+
|
1370 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1371 |
+
|
1372 |
+
if self.training and self.gradient_checkpointing:
|
1373 |
+
|
1374 |
+
def create_custom_forward(module, return_dict=None):
|
1375 |
+
def custom_forward(*inputs):
|
1376 |
+
if return_dict is not None:
|
1377 |
+
return module(*inputs, return_dict=return_dict)
|
1378 |
+
else:
|
1379 |
+
return module(*inputs)
|
1380 |
+
|
1381 |
+
return custom_forward
|
1382 |
+
|
1383 |
+
ckpt_kwargs: Dict[str, Any] = (
|
1384 |
+
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
1385 |
+
)
|
1386 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1387 |
+
create_custom_forward(resnet),
|
1388 |
+
hidden_states,
|
1389 |
+
temb,
|
1390 |
+
**ckpt_kwargs,
|
1391 |
+
)
|
1392 |
+
hidden_states = attn(
|
1393 |
+
hidden_states,
|
1394 |
+
encoder_hidden_states=encoder_hidden_states,
|
1395 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1396 |
+
attention_mask=attention_mask,
|
1397 |
+
encoder_attention_mask=encoder_attention_mask,
|
1398 |
+
return_dict=False,
|
1399 |
+
self_attn_block_embs=self_attn_block_embs,
|
1400 |
+
self_attn_block_embs_mode=self_attn_block_embs_mode,
|
1401 |
+
)[0]
|
1402 |
+
else:
|
1403 |
+
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
1404 |
+
hidden_states = attn(
|
1405 |
+
hidden_states,
|
1406 |
+
encoder_hidden_states=encoder_hidden_states,
|
1407 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1408 |
+
attention_mask=attention_mask,
|
1409 |
+
encoder_attention_mask=encoder_attention_mask,
|
1410 |
+
return_dict=False,
|
1411 |
+
self_attn_block_embs=self_attn_block_embs,
|
1412 |
+
)[0]
|
1413 |
+
|
1414 |
+
if self.upsamplers is not None:
|
1415 |
+
for upsampler in self.upsamplers:
|
1416 |
+
hidden_states = upsampler(
|
1417 |
+
hidden_states, upsample_size, scale=lora_scale
|
1418 |
+
)
|
1419 |
+
|
1420 |
+
return hidden_states
|
1421 |
+
|
1422 |
+
|
1423 |
+
class UpBlock2D(nn.Module):
|
1424 |
+
def __init__(
|
1425 |
+
self,
|
1426 |
+
in_channels: int,
|
1427 |
+
prev_output_channel: int,
|
1428 |
+
out_channels: int,
|
1429 |
+
temb_channels: int,
|
1430 |
+
resolution_idx: Optional[int] = None,
|
1431 |
+
dropout: float = 0.0,
|
1432 |
+
num_layers: int = 1,
|
1433 |
+
resnet_eps: float = 1e-6,
|
1434 |
+
resnet_time_scale_shift: str = "default",
|
1435 |
+
resnet_act_fn: str = "swish",
|
1436 |
+
resnet_groups: int = 32,
|
1437 |
+
resnet_pre_norm: bool = True,
|
1438 |
+
output_scale_factor: float = 1.0,
|
1439 |
+
add_upsample: bool = True,
|
1440 |
+
):
|
1441 |
+
super().__init__()
|
1442 |
+
resnets = []
|
1443 |
+
|
1444 |
+
for i in range(num_layers):
|
1445 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
1446 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
1447 |
+
|
1448 |
+
resnets.append(
|
1449 |
+
ResnetBlock2D(
|
1450 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
1451 |
+
out_channels=out_channels,
|
1452 |
+
temb_channels=temb_channels,
|
1453 |
+
eps=resnet_eps,
|
1454 |
+
groups=resnet_groups,
|
1455 |
+
dropout=dropout,
|
1456 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1457 |
+
non_linearity=resnet_act_fn,
|
1458 |
+
output_scale_factor=output_scale_factor,
|
1459 |
+
pre_norm=resnet_pre_norm,
|
1460 |
+
)
|
1461 |
+
)
|
1462 |
+
|
1463 |
+
self.resnets = nn.ModuleList(resnets)
|
1464 |
+
|
1465 |
+
if add_upsample:
|
1466 |
+
self.upsamplers = nn.ModuleList(
|
1467 |
+
[Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
|
1468 |
+
)
|
1469 |
+
else:
|
1470 |
+
self.upsamplers = None
|
1471 |
+
|
1472 |
+
self.gradient_checkpointing = False
|
1473 |
+
self.resolution_idx = resolution_idx
|
1474 |
+
|
1475 |
+
def forward(
|
1476 |
+
self,
|
1477 |
+
hidden_states: torch.FloatTensor,
|
1478 |
+
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
1479 |
+
temb: Optional[torch.FloatTensor] = None,
|
1480 |
+
upsample_size: Optional[int] = None,
|
1481 |
+
scale: float = 1.0,
|
1482 |
+
self_attn_block_embs: Optional[List[torch.Tensor]] = None,
|
1483 |
+
self_attn_block_embs_mode: Literal["read", "write"] = "write",
|
1484 |
+
) -> torch.FloatTensor:
|
1485 |
+
is_freeu_enabled = (
|
1486 |
+
getattr(self, "s1", None)
|
1487 |
+
and getattr(self, "s2", None)
|
1488 |
+
and getattr(self, "b1", None)
|
1489 |
+
and getattr(self, "b2", None)
|
1490 |
+
)
|
1491 |
+
|
1492 |
+
for resnet in self.resnets:
|
1493 |
+
# pop res hidden states
|
1494 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1495 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1496 |
+
|
1497 |
+
# FreeU: Only operate on the first two stages
|
1498 |
+
if is_freeu_enabled:
|
1499 |
+
hidden_states, res_hidden_states = apply_freeu(
|
1500 |
+
self.resolution_idx,
|
1501 |
+
hidden_states,
|
1502 |
+
res_hidden_states,
|
1503 |
+
s1=self.s1,
|
1504 |
+
s2=self.s2,
|
1505 |
+
b1=self.b1,
|
1506 |
+
b2=self.b2,
|
1507 |
+
)
|
1508 |
+
|
1509 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1510 |
+
|
1511 |
+
if self.training and self.gradient_checkpointing:
|
1512 |
+
|
1513 |
+
def create_custom_forward(module):
|
1514 |
+
def custom_forward(*inputs):
|
1515 |
+
return module(*inputs)
|
1516 |
+
|
1517 |
+
return custom_forward
|
1518 |
+
|
1519 |
+
if is_torch_version(">=", "1.11.0"):
|
1520 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1521 |
+
create_custom_forward(resnet),
|
1522 |
+
hidden_states,
|
1523 |
+
temb,
|
1524 |
+
use_reentrant=False,
|
1525 |
+
)
|
1526 |
+
else:
|
1527 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1528 |
+
create_custom_forward(resnet), hidden_states, temb
|
1529 |
+
)
|
1530 |
+
else:
|
1531 |
+
hidden_states = resnet(hidden_states, temb, scale=scale)
|
1532 |
+
|
1533 |
+
if self.upsamplers is not None:
|
1534 |
+
for upsampler in self.upsamplers:
|
1535 |
+
hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
|
1536 |
+
|
1537 |
+
return hidden_states
|
musev/models/unet_3d_blocks.py
ADDED
@@ -0,0 +1,1413 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# Adapted from https://github.com/huggingface/diffusers/blob/v0.16.1/src/diffusers/models/unet_3d_blocks.py
|
16 |
+
|
17 |
+
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
18 |
+
import logging
|
19 |
+
|
20 |
+
import torch
|
21 |
+
from torch import nn
|
22 |
+
|
23 |
+
from diffusers.utils import is_torch_version
|
24 |
+
from diffusers.models.transformer_2d import (
|
25 |
+
Transformer2DModel as DiffusersTransformer2DModel,
|
26 |
+
)
|
27 |
+
from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
|
28 |
+
from ..data.data_util import batch_adain_conditioned_tensor
|
29 |
+
|
30 |
+
from .resnet import TemporalConvLayer
|
31 |
+
from .temporal_transformer import TransformerTemporalModel
|
32 |
+
from .transformer_2d import Transformer2DModel
|
33 |
+
from .attention_processor import ReferEmbFuseAttention
|
34 |
+
|
35 |
+
|
36 |
+
logger = logging.getLogger(__name__)
|
37 |
+
|
38 |
+
# 注:
|
39 |
+
# (1) 原代码的`use_linear_projection`默认值均为True,与2D-SD模型不符,load时报错。因此均改为False
|
40 |
+
# (2) 原代码调用`Transformer2DModel`的输入参数顺序为n_channels // attn_num_head_channels, attn_num_head_channels,
|
41 |
+
# 与2D-SD模型不符。因此把顺序交换
|
42 |
+
# (3) 增加了temporal attention用的frame embedding输入
|
43 |
+
|
44 |
+
# note:
|
45 |
+
# 1. The default value of `use_linear_projection` in the original code is True, which is inconsistent with the 2D-SD model and causes an error when loading. Therefore, it is changed to False.
|
46 |
+
# 2. The original code calls `Transformer2DModel` with the input parameter order of n_channels // attn_num_head_channels, attn_num_head_channels, which is inconsistent with the 2D-SD model. Therefore, the order is reversed.
|
47 |
+
# 3. Added the frame embedding input used by the temporal attention
|
48 |
+
|
49 |
+
|
50 |
+
def get_down_block(
|
51 |
+
down_block_type,
|
52 |
+
num_layers,
|
53 |
+
in_channels,
|
54 |
+
out_channels,
|
55 |
+
temb_channels,
|
56 |
+
femb_channels,
|
57 |
+
add_downsample,
|
58 |
+
resnet_eps,
|
59 |
+
resnet_act_fn,
|
60 |
+
attn_num_head_channels,
|
61 |
+
resnet_groups=None,
|
62 |
+
cross_attention_dim=None,
|
63 |
+
downsample_padding=None,
|
64 |
+
dual_cross_attention=False,
|
65 |
+
use_linear_projection=False,
|
66 |
+
only_cross_attention=False,
|
67 |
+
upcast_attention=False,
|
68 |
+
resnet_time_scale_shift="default",
|
69 |
+
temporal_transformer: Union[nn.Module, None] = TransformerTemporalModel,
|
70 |
+
temporal_conv_block: Union[nn.Module, None] = TemporalConvLayer,
|
71 |
+
need_spatial_position_emb: bool = False,
|
72 |
+
need_t2i_ip_adapter: bool = False,
|
73 |
+
ip_adapter_cross_attn: bool = False,
|
74 |
+
need_t2i_facein: bool = False,
|
75 |
+
need_t2i_ip_adapter_face: bool = False,
|
76 |
+
need_adain_temporal_cond: bool = False,
|
77 |
+
resnet_2d_skip_time_act: bool = False,
|
78 |
+
need_refer_emb: bool = False,
|
79 |
+
):
|
80 |
+
if (isinstance(down_block_type, str) and down_block_type == "DownBlock3D") or (
|
81 |
+
isinstance(down_block_type, nn.Module)
|
82 |
+
and down_block_type.__name__ == "DownBlock3D"
|
83 |
+
):
|
84 |
+
return DownBlock3D(
|
85 |
+
num_layers=num_layers,
|
86 |
+
in_channels=in_channels,
|
87 |
+
out_channels=out_channels,
|
88 |
+
temb_channels=temb_channels,
|
89 |
+
femb_channels=femb_channels,
|
90 |
+
add_downsample=add_downsample,
|
91 |
+
resnet_eps=resnet_eps,
|
92 |
+
resnet_act_fn=resnet_act_fn,
|
93 |
+
resnet_groups=resnet_groups,
|
94 |
+
downsample_padding=downsample_padding,
|
95 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
96 |
+
temporal_conv_block=temporal_conv_block,
|
97 |
+
need_adain_temporal_cond=need_adain_temporal_cond,
|
98 |
+
resnet_2d_skip_time_act=resnet_2d_skip_time_act,
|
99 |
+
need_refer_emb=need_refer_emb,
|
100 |
+
attn_num_head_channels=attn_num_head_channels,
|
101 |
+
)
|
102 |
+
elif (
|
103 |
+
isinstance(down_block_type, str) and down_block_type == "CrossAttnDownBlock3D"
|
104 |
+
) or (
|
105 |
+
isinstance(down_block_type, nn.Module)
|
106 |
+
and down_block_type.__name__ == "CrossAttnDownBlock3D"
|
107 |
+
):
|
108 |
+
if cross_attention_dim is None:
|
109 |
+
raise ValueError(
|
110 |
+
"cross_attention_dim must be specified for CrossAttnDownBlock3D"
|
111 |
+
)
|
112 |
+
return CrossAttnDownBlock3D(
|
113 |
+
num_layers=num_layers,
|
114 |
+
in_channels=in_channels,
|
115 |
+
out_channels=out_channels,
|
116 |
+
temb_channels=temb_channels,
|
117 |
+
femb_channels=femb_channels,
|
118 |
+
add_downsample=add_downsample,
|
119 |
+
resnet_eps=resnet_eps,
|
120 |
+
resnet_act_fn=resnet_act_fn,
|
121 |
+
resnet_groups=resnet_groups,
|
122 |
+
downsample_padding=downsample_padding,
|
123 |
+
cross_attention_dim=cross_attention_dim,
|
124 |
+
attn_num_head_channels=attn_num_head_channels,
|
125 |
+
dual_cross_attention=dual_cross_attention,
|
126 |
+
use_linear_projection=use_linear_projection,
|
127 |
+
only_cross_attention=only_cross_attention,
|
128 |
+
upcast_attention=upcast_attention,
|
129 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
130 |
+
temporal_conv_block=temporal_conv_block,
|
131 |
+
temporal_transformer=temporal_transformer,
|
132 |
+
need_spatial_position_emb=need_spatial_position_emb,
|
133 |
+
need_t2i_ip_adapter=need_t2i_ip_adapter,
|
134 |
+
ip_adapter_cross_attn=ip_adapter_cross_attn,
|
135 |
+
need_t2i_facein=need_t2i_facein,
|
136 |
+
need_t2i_ip_adapter_face=need_t2i_ip_adapter_face,
|
137 |
+
need_adain_temporal_cond=need_adain_temporal_cond,
|
138 |
+
resnet_2d_skip_time_act=resnet_2d_skip_time_act,
|
139 |
+
need_refer_emb=need_refer_emb,
|
140 |
+
)
|
141 |
+
raise ValueError(f"{down_block_type} does not exist.")
|
142 |
+
|
143 |
+
|
144 |
+
def get_up_block(
|
145 |
+
up_block_type,
|
146 |
+
num_layers,
|
147 |
+
in_channels,
|
148 |
+
out_channels,
|
149 |
+
prev_output_channel,
|
150 |
+
temb_channels,
|
151 |
+
femb_channels,
|
152 |
+
add_upsample,
|
153 |
+
resnet_eps,
|
154 |
+
resnet_act_fn,
|
155 |
+
attn_num_head_channels,
|
156 |
+
resnet_groups=None,
|
157 |
+
cross_attention_dim=None,
|
158 |
+
dual_cross_attention=False,
|
159 |
+
use_linear_projection=False,
|
160 |
+
only_cross_attention=False,
|
161 |
+
upcast_attention=False,
|
162 |
+
resnet_time_scale_shift="default",
|
163 |
+
temporal_conv_block: Union[nn.Module, None] = TemporalConvLayer,
|
164 |
+
temporal_transformer: Union[nn.Module, None] = TransformerTemporalModel,
|
165 |
+
need_spatial_position_emb: bool = False,
|
166 |
+
need_t2i_ip_adapter: bool = False,
|
167 |
+
ip_adapter_cross_attn: bool = False,
|
168 |
+
need_t2i_facein: bool = False,
|
169 |
+
need_t2i_ip_adapter_face: bool = False,
|
170 |
+
need_adain_temporal_cond: bool = False,
|
171 |
+
resnet_2d_skip_time_act: bool = False,
|
172 |
+
):
|
173 |
+
if (isinstance(up_block_type, str) and up_block_type == "UpBlock3D") or (
|
174 |
+
isinstance(up_block_type, nn.Module) and up_block_type.__name__ == "UpBlock3D"
|
175 |
+
):
|
176 |
+
return UpBlock3D(
|
177 |
+
num_layers=num_layers,
|
178 |
+
in_channels=in_channels,
|
179 |
+
out_channels=out_channels,
|
180 |
+
prev_output_channel=prev_output_channel,
|
181 |
+
temb_channels=temb_channels,
|
182 |
+
femb_channels=femb_channels,
|
183 |
+
add_upsample=add_upsample,
|
184 |
+
resnet_eps=resnet_eps,
|
185 |
+
resnet_act_fn=resnet_act_fn,
|
186 |
+
resnet_groups=resnet_groups,
|
187 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
188 |
+
temporal_conv_block=temporal_conv_block,
|
189 |
+
need_adain_temporal_cond=need_adain_temporal_cond,
|
190 |
+
resnet_2d_skip_time_act=resnet_2d_skip_time_act,
|
191 |
+
)
|
192 |
+
elif (isinstance(up_block_type, str) and up_block_type == "CrossAttnUpBlock3D") or (
|
193 |
+
isinstance(up_block_type, nn.Module)
|
194 |
+
and up_block_type.__name__ == "CrossAttnUpBlock3D"
|
195 |
+
):
|
196 |
+
if cross_attention_dim is None:
|
197 |
+
raise ValueError(
|
198 |
+
"cross_attention_dim must be specified for CrossAttnUpBlock3D"
|
199 |
+
)
|
200 |
+
return CrossAttnUpBlock3D(
|
201 |
+
num_layers=num_layers,
|
202 |
+
in_channels=in_channels,
|
203 |
+
out_channels=out_channels,
|
204 |
+
prev_output_channel=prev_output_channel,
|
205 |
+
temb_channels=temb_channels,
|
206 |
+
femb_channels=femb_channels,
|
207 |
+
add_upsample=add_upsample,
|
208 |
+
resnet_eps=resnet_eps,
|
209 |
+
resnet_act_fn=resnet_act_fn,
|
210 |
+
resnet_groups=resnet_groups,
|
211 |
+
cross_attention_dim=cross_attention_dim,
|
212 |
+
attn_num_head_channels=attn_num_head_channels,
|
213 |
+
dual_cross_attention=dual_cross_attention,
|
214 |
+
use_linear_projection=use_linear_projection,
|
215 |
+
only_cross_attention=only_cross_attention,
|
216 |
+
upcast_attention=upcast_attention,
|
217 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
218 |
+
temporal_conv_block=temporal_conv_block,
|
219 |
+
temporal_transformer=temporal_transformer,
|
220 |
+
need_spatial_position_emb=need_spatial_position_emb,
|
221 |
+
need_t2i_ip_adapter=need_t2i_ip_adapter,
|
222 |
+
ip_adapter_cross_attn=ip_adapter_cross_attn,
|
223 |
+
need_t2i_facein=need_t2i_facein,
|
224 |
+
need_t2i_ip_adapter_face=need_t2i_ip_adapter_face,
|
225 |
+
need_adain_temporal_cond=need_adain_temporal_cond,
|
226 |
+
resnet_2d_skip_time_act=resnet_2d_skip_time_act,
|
227 |
+
)
|
228 |
+
raise ValueError(f"{up_block_type} does not exist.")
|
229 |
+
|
230 |
+
|
231 |
+
class UNetMidBlock3DCrossAttn(nn.Module):
|
232 |
+
print_idx = 0
|
233 |
+
|
234 |
+
def __init__(
|
235 |
+
self,
|
236 |
+
in_channels: int,
|
237 |
+
temb_channels: int,
|
238 |
+
femb_channels: int,
|
239 |
+
dropout: float = 0.0,
|
240 |
+
num_layers: int = 1,
|
241 |
+
resnet_eps: float = 1e-6,
|
242 |
+
resnet_time_scale_shift: str = "default",
|
243 |
+
resnet_act_fn: str = "swish",
|
244 |
+
resnet_groups: int = 32,
|
245 |
+
resnet_pre_norm: bool = True,
|
246 |
+
attn_num_head_channels=1,
|
247 |
+
output_scale_factor=1.0,
|
248 |
+
cross_attention_dim=1280,
|
249 |
+
dual_cross_attention=False,
|
250 |
+
use_linear_projection=False,
|
251 |
+
upcast_attention=False,
|
252 |
+
temporal_conv_block: Union[nn.Module, None] = TemporalConvLayer,
|
253 |
+
temporal_transformer: Union[nn.Module, None] = TransformerTemporalModel,
|
254 |
+
need_spatial_position_emb: bool = False,
|
255 |
+
need_t2i_ip_adapter: bool = False,
|
256 |
+
ip_adapter_cross_attn: bool = False,
|
257 |
+
need_t2i_facein: bool = False,
|
258 |
+
need_t2i_ip_adapter_face: bool = False,
|
259 |
+
need_adain_temporal_cond: bool = False,
|
260 |
+
resnet_2d_skip_time_act: bool = False,
|
261 |
+
):
|
262 |
+
super().__init__()
|
263 |
+
|
264 |
+
self.has_cross_attention = True
|
265 |
+
self.attn_num_head_channels = attn_num_head_channels
|
266 |
+
resnet_groups = (
|
267 |
+
resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
268 |
+
)
|
269 |
+
|
270 |
+
# there is always at least one resnet
|
271 |
+
resnets = [
|
272 |
+
ResnetBlock2D(
|
273 |
+
in_channels=in_channels,
|
274 |
+
out_channels=in_channels,
|
275 |
+
temb_channels=temb_channels,
|
276 |
+
eps=resnet_eps,
|
277 |
+
groups=resnet_groups,
|
278 |
+
dropout=dropout,
|
279 |
+
time_embedding_norm=resnet_time_scale_shift,
|
280 |
+
non_linearity=resnet_act_fn,
|
281 |
+
output_scale_factor=output_scale_factor,
|
282 |
+
pre_norm=resnet_pre_norm,
|
283 |
+
skip_time_act=resnet_2d_skip_time_act,
|
284 |
+
)
|
285 |
+
]
|
286 |
+
if temporal_conv_block is not None:
|
287 |
+
temp_convs = [
|
288 |
+
temporal_conv_block(
|
289 |
+
in_channels,
|
290 |
+
in_channels,
|
291 |
+
dropout=0.1,
|
292 |
+
femb_channels=femb_channels,
|
293 |
+
)
|
294 |
+
]
|
295 |
+
else:
|
296 |
+
temp_convs = [None]
|
297 |
+
attentions = []
|
298 |
+
temp_attentions = []
|
299 |
+
|
300 |
+
for _ in range(num_layers):
|
301 |
+
attentions.append(
|
302 |
+
Transformer2DModel(
|
303 |
+
attn_num_head_channels,
|
304 |
+
in_channels // attn_num_head_channels,
|
305 |
+
in_channels=in_channels,
|
306 |
+
num_layers=1,
|
307 |
+
cross_attention_dim=cross_attention_dim,
|
308 |
+
norm_num_groups=resnet_groups,
|
309 |
+
use_linear_projection=use_linear_projection,
|
310 |
+
upcast_attention=upcast_attention,
|
311 |
+
cross_attn_temporal_cond=need_t2i_ip_adapter,
|
312 |
+
ip_adapter_cross_attn=ip_adapter_cross_attn,
|
313 |
+
need_t2i_facein=need_t2i_facein,
|
314 |
+
need_t2i_ip_adapter_face=need_t2i_ip_adapter_face,
|
315 |
+
)
|
316 |
+
)
|
317 |
+
if temporal_transformer is not None:
|
318 |
+
temp_attention = temporal_transformer(
|
319 |
+
attn_num_head_channels,
|
320 |
+
in_channels // attn_num_head_channels,
|
321 |
+
in_channels=in_channels,
|
322 |
+
num_layers=1,
|
323 |
+
femb_channels=femb_channels,
|
324 |
+
cross_attention_dim=cross_attention_dim,
|
325 |
+
norm_num_groups=resnet_groups,
|
326 |
+
need_spatial_position_emb=need_spatial_position_emb,
|
327 |
+
)
|
328 |
+
else:
|
329 |
+
temp_attention = None
|
330 |
+
temp_attentions.append(temp_attention)
|
331 |
+
resnets.append(
|
332 |
+
ResnetBlock2D(
|
333 |
+
in_channels=in_channels,
|
334 |
+
out_channels=in_channels,
|
335 |
+
temb_channels=temb_channels,
|
336 |
+
eps=resnet_eps,
|
337 |
+
groups=resnet_groups,
|
338 |
+
dropout=dropout,
|
339 |
+
time_embedding_norm=resnet_time_scale_shift,
|
340 |
+
non_linearity=resnet_act_fn,
|
341 |
+
output_scale_factor=output_scale_factor,
|
342 |
+
pre_norm=resnet_pre_norm,
|
343 |
+
skip_time_act=resnet_2d_skip_time_act,
|
344 |
+
)
|
345 |
+
)
|
346 |
+
if temporal_conv_block is not None:
|
347 |
+
temp_convs.append(
|
348 |
+
temporal_conv_block(
|
349 |
+
in_channels,
|
350 |
+
in_channels,
|
351 |
+
dropout=0.1,
|
352 |
+
femb_channels=femb_channels,
|
353 |
+
)
|
354 |
+
)
|
355 |
+
else:
|
356 |
+
temp_convs.append(None)
|
357 |
+
|
358 |
+
self.resnets = nn.ModuleList(resnets)
|
359 |
+
self.temp_convs = nn.ModuleList(temp_convs)
|
360 |
+
self.attentions = nn.ModuleList(attentions)
|
361 |
+
self.temp_attentions = nn.ModuleList(temp_attentions)
|
362 |
+
self.need_adain_temporal_cond = need_adain_temporal_cond
|
363 |
+
|
364 |
+
def forward(
|
365 |
+
self,
|
366 |
+
hidden_states,
|
367 |
+
temb=None,
|
368 |
+
femb=None,
|
369 |
+
encoder_hidden_states=None,
|
370 |
+
attention_mask=None,
|
371 |
+
num_frames=1,
|
372 |
+
cross_attention_kwargs=None,
|
373 |
+
sample_index: torch.LongTensor = None,
|
374 |
+
vision_conditon_frames_sample_index: torch.LongTensor = None,
|
375 |
+
spatial_position_emb: torch.Tensor = None,
|
376 |
+
refer_self_attn_emb: List[torch.Tensor] = None,
|
377 |
+
refer_self_attn_emb_mode: Literal["read", "write"] = "read",
|
378 |
+
):
|
379 |
+
hidden_states = self.resnets[0](hidden_states, temb)
|
380 |
+
if self.temp_convs[0] is not None:
|
381 |
+
hidden_states = self.temp_convs[0](
|
382 |
+
hidden_states,
|
383 |
+
femb=femb,
|
384 |
+
num_frames=num_frames,
|
385 |
+
sample_index=sample_index,
|
386 |
+
vision_conditon_frames_sample_index=vision_conditon_frames_sample_index,
|
387 |
+
)
|
388 |
+
for attn, temp_attn, resnet, temp_conv in zip(
|
389 |
+
self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:]
|
390 |
+
):
|
391 |
+
hidden_states = attn(
|
392 |
+
hidden_states,
|
393 |
+
encoder_hidden_states=encoder_hidden_states,
|
394 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
395 |
+
self_attn_block_embs=refer_self_attn_emb,
|
396 |
+
self_attn_block_embs_mode=refer_self_attn_emb_mode,
|
397 |
+
).sample
|
398 |
+
if temp_attn is not None:
|
399 |
+
hidden_states = temp_attn(
|
400 |
+
hidden_states,
|
401 |
+
femb=femb,
|
402 |
+
num_frames=num_frames,
|
403 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
404 |
+
encoder_hidden_states=encoder_hidden_states,
|
405 |
+
sample_index=sample_index,
|
406 |
+
vision_conditon_frames_sample_index=vision_conditon_frames_sample_index,
|
407 |
+
spatial_position_emb=spatial_position_emb,
|
408 |
+
).sample
|
409 |
+
hidden_states = resnet(hidden_states, temb)
|
410 |
+
if temp_conv is not None:
|
411 |
+
hidden_states = temp_conv(
|
412 |
+
hidden_states,
|
413 |
+
femb=femb,
|
414 |
+
num_frames=num_frames,
|
415 |
+
sample_index=sample_index,
|
416 |
+
vision_conditon_frames_sample_index=vision_conditon_frames_sample_index,
|
417 |
+
)
|
418 |
+
if (
|
419 |
+
self.need_adain_temporal_cond
|
420 |
+
and num_frames > 1
|
421 |
+
and sample_index is not None
|
422 |
+
):
|
423 |
+
if self.print_idx == 0:
|
424 |
+
logger.debug(f"adain to vision_condition")
|
425 |
+
hidden_states = batch_adain_conditioned_tensor(
|
426 |
+
hidden_states,
|
427 |
+
num_frames=num_frames,
|
428 |
+
need_style_fidelity=False,
|
429 |
+
src_index=sample_index,
|
430 |
+
dst_index=vision_conditon_frames_sample_index,
|
431 |
+
)
|
432 |
+
self.print_idx += 1
|
433 |
+
return hidden_states
|
434 |
+
|
435 |
+
|
436 |
+
class CrossAttnDownBlock3D(nn.Module):
|
437 |
+
print_idx = 0
|
438 |
+
|
439 |
+
def __init__(
|
440 |
+
self,
|
441 |
+
in_channels: int,
|
442 |
+
out_channels: int,
|
443 |
+
temb_channels: int,
|
444 |
+
femb_channels: int,
|
445 |
+
dropout: float = 0.0,
|
446 |
+
num_layers: int = 1,
|
447 |
+
resnet_eps: float = 1e-6,
|
448 |
+
resnet_time_scale_shift: str = "default",
|
449 |
+
resnet_act_fn: str = "swish",
|
450 |
+
resnet_groups: int = 32,
|
451 |
+
resnet_pre_norm: bool = True,
|
452 |
+
attn_num_head_channels=1,
|
453 |
+
cross_attention_dim=1280,
|
454 |
+
output_scale_factor=1.0,
|
455 |
+
downsample_padding=1,
|
456 |
+
add_downsample=True,
|
457 |
+
dual_cross_attention=False,
|
458 |
+
use_linear_projection=False,
|
459 |
+
only_cross_attention=False,
|
460 |
+
upcast_attention=False,
|
461 |
+
temporal_conv_block: Union[nn.Module, None] = TemporalConvLayer,
|
462 |
+
temporal_transformer: Union[nn.Module, None] = TransformerTemporalModel,
|
463 |
+
need_spatial_position_emb: bool = False,
|
464 |
+
need_t2i_ip_adapter: bool = False,
|
465 |
+
ip_adapter_cross_attn: bool = False,
|
466 |
+
need_t2i_facein: bool = False,
|
467 |
+
need_t2i_ip_adapter_face: bool = False,
|
468 |
+
need_adain_temporal_cond: bool = False,
|
469 |
+
resnet_2d_skip_time_act: bool = False,
|
470 |
+
need_refer_emb: bool = False,
|
471 |
+
):
|
472 |
+
super().__init__()
|
473 |
+
resnets = []
|
474 |
+
attentions = []
|
475 |
+
temp_attentions = []
|
476 |
+
temp_convs = []
|
477 |
+
|
478 |
+
self.has_cross_attention = True
|
479 |
+
self.attn_num_head_channels = attn_num_head_channels
|
480 |
+
self.need_refer_emb = need_refer_emb
|
481 |
+
if need_refer_emb:
|
482 |
+
refer_emb_attns = []
|
483 |
+
for i in range(num_layers):
|
484 |
+
in_channels = in_channels if i == 0 else out_channels
|
485 |
+
resnets.append(
|
486 |
+
ResnetBlock2D(
|
487 |
+
in_channels=in_channels,
|
488 |
+
out_channels=out_channels,
|
489 |
+
temb_channels=temb_channels,
|
490 |
+
eps=resnet_eps,
|
491 |
+
groups=resnet_groups,
|
492 |
+
dropout=dropout,
|
493 |
+
time_embedding_norm=resnet_time_scale_shift,
|
494 |
+
non_linearity=resnet_act_fn,
|
495 |
+
output_scale_factor=output_scale_factor,
|
496 |
+
pre_norm=resnet_pre_norm,
|
497 |
+
skip_time_act=resnet_2d_skip_time_act,
|
498 |
+
)
|
499 |
+
)
|
500 |
+
if temporal_conv_block is not None:
|
501 |
+
temp_convs.append(
|
502 |
+
temporal_conv_block(
|
503 |
+
out_channels,
|
504 |
+
out_channels,
|
505 |
+
dropout=0.1,
|
506 |
+
femb_channels=femb_channels,
|
507 |
+
)
|
508 |
+
)
|
509 |
+
else:
|
510 |
+
temp_convs.append(None)
|
511 |
+
attentions.append(
|
512 |
+
Transformer2DModel(
|
513 |
+
attn_num_head_channels,
|
514 |
+
out_channels // attn_num_head_channels,
|
515 |
+
in_channels=out_channels,
|
516 |
+
num_layers=1,
|
517 |
+
cross_attention_dim=cross_attention_dim,
|
518 |
+
norm_num_groups=resnet_groups,
|
519 |
+
use_linear_projection=use_linear_projection,
|
520 |
+
only_cross_attention=only_cross_attention,
|
521 |
+
upcast_attention=upcast_attention,
|
522 |
+
cross_attn_temporal_cond=need_t2i_ip_adapter,
|
523 |
+
ip_adapter_cross_attn=ip_adapter_cross_attn,
|
524 |
+
need_t2i_facein=need_t2i_facein,
|
525 |
+
need_t2i_ip_adapter_face=need_t2i_ip_adapter_face,
|
526 |
+
)
|
527 |
+
)
|
528 |
+
if temporal_transformer is not None:
|
529 |
+
temp_attention = temporal_transformer(
|
530 |
+
attn_num_head_channels,
|
531 |
+
out_channels // attn_num_head_channels,
|
532 |
+
in_channels=out_channels,
|
533 |
+
num_layers=1,
|
534 |
+
femb_channels=femb_channels,
|
535 |
+
cross_attention_dim=cross_attention_dim,
|
536 |
+
norm_num_groups=resnet_groups,
|
537 |
+
need_spatial_position_emb=need_spatial_position_emb,
|
538 |
+
)
|
539 |
+
else:
|
540 |
+
temp_attention = None
|
541 |
+
temp_attentions.append(temp_attention)
|
542 |
+
|
543 |
+
if need_refer_emb:
|
544 |
+
refer_emb_attns.append(
|
545 |
+
ReferEmbFuseAttention(
|
546 |
+
query_dim=out_channels,
|
547 |
+
heads=attn_num_head_channels,
|
548 |
+
dim_head=out_channels // attn_num_head_channels,
|
549 |
+
dropout=0,
|
550 |
+
bias=False,
|
551 |
+
cross_attention_dim=None,
|
552 |
+
upcast_attention=False,
|
553 |
+
)
|
554 |
+
)
|
555 |
+
|
556 |
+
self.resnets = nn.ModuleList(resnets)
|
557 |
+
self.temp_convs = nn.ModuleList(temp_convs)
|
558 |
+
self.attentions = nn.ModuleList(attentions)
|
559 |
+
self.temp_attentions = nn.ModuleList(temp_attentions)
|
560 |
+
|
561 |
+
if add_downsample:
|
562 |
+
self.downsamplers = nn.ModuleList(
|
563 |
+
[
|
564 |
+
Downsample2D(
|
565 |
+
out_channels,
|
566 |
+
use_conv=True,
|
567 |
+
out_channels=out_channels,
|
568 |
+
padding=downsample_padding,
|
569 |
+
name="op",
|
570 |
+
)
|
571 |
+
]
|
572 |
+
)
|
573 |
+
if need_refer_emb:
|
574 |
+
refer_emb_attns.append(
|
575 |
+
ReferEmbFuseAttention(
|
576 |
+
query_dim=out_channels,
|
577 |
+
heads=attn_num_head_channels,
|
578 |
+
dim_head=out_channels // attn_num_head_channels,
|
579 |
+
dropout=0,
|
580 |
+
bias=False,
|
581 |
+
cross_attention_dim=None,
|
582 |
+
upcast_attention=False,
|
583 |
+
)
|
584 |
+
)
|
585 |
+
else:
|
586 |
+
self.downsamplers = None
|
587 |
+
|
588 |
+
self.gradient_checkpointing = False
|
589 |
+
self.need_adain_temporal_cond = need_adain_temporal_cond
|
590 |
+
if need_refer_emb:
|
591 |
+
self.refer_emb_attns = nn.ModuleList(refer_emb_attns)
|
592 |
+
logger.debug(f"cross attn downblock 3d need_refer_emb, {self.need_refer_emb}")
|
593 |
+
|
594 |
+
def forward(
|
595 |
+
self,
|
596 |
+
hidden_states: torch.FloatTensor,
|
597 |
+
temb: Optional[torch.FloatTensor] = None,
|
598 |
+
femb: Optional[torch.FloatTensor] = None,
|
599 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
600 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
601 |
+
num_frames: int = 1,
|
602 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
603 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
604 |
+
sample_index: torch.LongTensor = None,
|
605 |
+
vision_conditon_frames_sample_index: torch.LongTensor = None,
|
606 |
+
spatial_position_emb: torch.Tensor = None,
|
607 |
+
refer_embs: Optional[List[torch.Tensor]] = None,
|
608 |
+
refer_self_attn_emb: List[torch.Tensor] = None,
|
609 |
+
refer_self_attn_emb_mode: Literal["read", "write"] = "read",
|
610 |
+
):
|
611 |
+
# TODO(Patrick, William) - attention mask is not used
|
612 |
+
output_states = ()
|
613 |
+
for i_downblock, (resnet, temp_conv, attn, temp_attn) in enumerate(
|
614 |
+
zip(self.resnets, self.temp_convs, self.attentions, self.temp_attentions)
|
615 |
+
):
|
616 |
+
# print("crossattndownblock3d, attn,", type(attn), cross_attention_kwargs)
|
617 |
+
if self.training and self.gradient_checkpointing:
|
618 |
+
if self.print_idx == 0:
|
619 |
+
logger.debug(
|
620 |
+
f"self.training and self.gradient_checkpointing={self.training and self.gradient_checkpointing}"
|
621 |
+
)
|
622 |
+
|
623 |
+
def create_custom_forward(module, return_dict=None):
|
624 |
+
def custom_forward(*inputs):
|
625 |
+
if return_dict is not None:
|
626 |
+
return module(*inputs, return_dict=return_dict)
|
627 |
+
else:
|
628 |
+
return module(*inputs)
|
629 |
+
|
630 |
+
return custom_forward
|
631 |
+
|
632 |
+
ckpt_kwargs: Dict[str, Any] = (
|
633 |
+
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
634 |
+
)
|
635 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
636 |
+
create_custom_forward(resnet),
|
637 |
+
hidden_states,
|
638 |
+
temb,
|
639 |
+
**ckpt_kwargs,
|
640 |
+
)
|
641 |
+
if self.print_idx == 0:
|
642 |
+
logger.debug(f"unet3d after resnet {hidden_states.mean()}")
|
643 |
+
if temp_conv is not None:
|
644 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
645 |
+
create_custom_forward(temp_conv),
|
646 |
+
hidden_states,
|
647 |
+
num_frames,
|
648 |
+
sample_index,
|
649 |
+
vision_conditon_frames_sample_index,
|
650 |
+
femb,
|
651 |
+
**ckpt_kwargs,
|
652 |
+
)
|
653 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
654 |
+
create_custom_forward(attn, return_dict=False),
|
655 |
+
hidden_states,
|
656 |
+
encoder_hidden_states,
|
657 |
+
None, # timestep
|
658 |
+
None, # added_cond_kwargs
|
659 |
+
None, # class_labels
|
660 |
+
cross_attention_kwargs,
|
661 |
+
attention_mask,
|
662 |
+
encoder_attention_mask,
|
663 |
+
refer_self_attn_emb,
|
664 |
+
refer_self_attn_emb_mode,
|
665 |
+
**ckpt_kwargs,
|
666 |
+
)[0]
|
667 |
+
if temp_attn is not None:
|
668 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
669 |
+
create_custom_forward(temp_attn, return_dict=False),
|
670 |
+
hidden_states,
|
671 |
+
femb,
|
672 |
+
# None, # encoder_hidden_states,
|
673 |
+
encoder_hidden_states,
|
674 |
+
None, # timestep
|
675 |
+
None, # class_labels
|
676 |
+
num_frames,
|
677 |
+
cross_attention_kwargs,
|
678 |
+
sample_index,
|
679 |
+
vision_conditon_frames_sample_index,
|
680 |
+
spatial_position_emb,
|
681 |
+
**ckpt_kwargs,
|
682 |
+
)[0]
|
683 |
+
else:
|
684 |
+
hidden_states = resnet(hidden_states, temb)
|
685 |
+
if self.print_idx == 0:
|
686 |
+
logger.debug(f"unet3d after resnet {hidden_states.mean()}")
|
687 |
+
if temp_conv is not None:
|
688 |
+
hidden_states = temp_conv(
|
689 |
+
hidden_states,
|
690 |
+
femb=femb,
|
691 |
+
num_frames=num_frames,
|
692 |
+
sample_index=sample_index,
|
693 |
+
vision_conditon_frames_sample_index=vision_conditon_frames_sample_index,
|
694 |
+
)
|
695 |
+
hidden_states = attn(
|
696 |
+
hidden_states,
|
697 |
+
encoder_hidden_states=encoder_hidden_states,
|
698 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
699 |
+
self_attn_block_embs=refer_self_attn_emb,
|
700 |
+
self_attn_block_embs_mode=refer_self_attn_emb_mode,
|
701 |
+
).sample
|
702 |
+
if temp_attn is not None:
|
703 |
+
hidden_states = temp_attn(
|
704 |
+
hidden_states,
|
705 |
+
femb=femb,
|
706 |
+
num_frames=num_frames,
|
707 |
+
encoder_hidden_states=encoder_hidden_states,
|
708 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
709 |
+
sample_index=sample_index,
|
710 |
+
vision_conditon_frames_sample_index=vision_conditon_frames_sample_index,
|
711 |
+
spatial_position_emb=spatial_position_emb,
|
712 |
+
).sample
|
713 |
+
if (
|
714 |
+
self.need_adain_temporal_cond
|
715 |
+
and num_frames > 1
|
716 |
+
and sample_index is not None
|
717 |
+
):
|
718 |
+
if self.print_idx == 0:
|
719 |
+
logger.debug(f"adain to vision_condition")
|
720 |
+
hidden_states = batch_adain_conditioned_tensor(
|
721 |
+
hidden_states,
|
722 |
+
num_frames=num_frames,
|
723 |
+
need_style_fidelity=False,
|
724 |
+
src_index=sample_index,
|
725 |
+
dst_index=vision_conditon_frames_sample_index,
|
726 |
+
)
|
727 |
+
# 使用 attn 的方式 来融合 down_block_refer_emb
|
728 |
+
if self.print_idx == 0:
|
729 |
+
logger.debug(
|
730 |
+
f"downblock, {i_downblock}, self.need_refer_emb={self.need_refer_emb}"
|
731 |
+
)
|
732 |
+
if self.need_refer_emb and refer_embs is not None:
|
733 |
+
if self.print_idx == 0:
|
734 |
+
logger.debug(
|
735 |
+
f"{i_downblock}, self.refer_emb_attns {refer_embs[i_downblock].shape}"
|
736 |
+
)
|
737 |
+
hidden_states = self.refer_emb_attns[i_downblock](
|
738 |
+
hidden_states, refer_embs[i_downblock], num_frames=num_frames
|
739 |
+
)
|
740 |
+
else:
|
741 |
+
if self.print_idx == 0:
|
742 |
+
logger.debug(f"crossattndownblock refer_emb_attns, no this step")
|
743 |
+
output_states += (hidden_states,)
|
744 |
+
|
745 |
+
if self.downsamplers is not None:
|
746 |
+
for downsampler in self.downsamplers:
|
747 |
+
hidden_states = downsampler(hidden_states)
|
748 |
+
if (
|
749 |
+
self.need_adain_temporal_cond
|
750 |
+
and num_frames > 1
|
751 |
+
and sample_index is not None
|
752 |
+
):
|
753 |
+
if self.print_idx == 0:
|
754 |
+
logger.debug(f"adain to vision_condition")
|
755 |
+
hidden_states = batch_adain_conditioned_tensor(
|
756 |
+
hidden_states,
|
757 |
+
num_frames=num_frames,
|
758 |
+
need_style_fidelity=False,
|
759 |
+
src_index=sample_index,
|
760 |
+
dst_index=vision_conditon_frames_sample_index,
|
761 |
+
)
|
762 |
+
# 使用 attn 的方式 来融合 down_block_refer_emb
|
763 |
+
# TODO: adain和 refer_emb的顺序
|
764 |
+
# TODO:adain 首帧特征还是refer_emb的
|
765 |
+
if self.need_refer_emb and refer_embs is not None:
|
766 |
+
i_downblock += 1
|
767 |
+
hidden_states = self.refer_emb_attns[i_downblock](
|
768 |
+
hidden_states, refer_embs[i_downblock], num_frames=num_frames
|
769 |
+
)
|
770 |
+
output_states += (hidden_states,)
|
771 |
+
self.print_idx += 1
|
772 |
+
return hidden_states, output_states
|
773 |
+
|
774 |
+
|
775 |
+
class DownBlock3D(nn.Module):
|
776 |
+
print_idx = 0
|
777 |
+
|
778 |
+
def __init__(
|
779 |
+
self,
|
780 |
+
in_channels: int,
|
781 |
+
out_channels: int,
|
782 |
+
temb_channels: int,
|
783 |
+
femb_channels: int,
|
784 |
+
dropout: float = 0.0,
|
785 |
+
num_layers: int = 1,
|
786 |
+
resnet_eps: float = 1e-6,
|
787 |
+
resnet_time_scale_shift: str = "default",
|
788 |
+
resnet_act_fn: str = "swish",
|
789 |
+
resnet_groups: int = 32,
|
790 |
+
resnet_pre_norm: bool = True,
|
791 |
+
output_scale_factor=1.0,
|
792 |
+
add_downsample=True,
|
793 |
+
downsample_padding=1,
|
794 |
+
temporal_conv_block: Union[nn.Module, None] = TemporalConvLayer,
|
795 |
+
need_adain_temporal_cond: bool = False,
|
796 |
+
resnet_2d_skip_time_act: bool = False,
|
797 |
+
need_refer_emb: bool = False,
|
798 |
+
attn_num_head_channels: int = 1,
|
799 |
+
):
|
800 |
+
super().__init__()
|
801 |
+
resnets = []
|
802 |
+
temp_convs = []
|
803 |
+
self.need_refer_emb = need_refer_emb
|
804 |
+
if need_refer_emb:
|
805 |
+
refer_emb_attns = []
|
806 |
+
self.attn_num_head_channels = attn_num_head_channels
|
807 |
+
|
808 |
+
for i in range(num_layers):
|
809 |
+
in_channels = in_channels if i == 0 else out_channels
|
810 |
+
resnets.append(
|
811 |
+
ResnetBlock2D(
|
812 |
+
in_channels=in_channels,
|
813 |
+
out_channels=out_channels,
|
814 |
+
temb_channels=temb_channels,
|
815 |
+
eps=resnet_eps,
|
816 |
+
groups=resnet_groups,
|
817 |
+
dropout=dropout,
|
818 |
+
time_embedding_norm=resnet_time_scale_shift,
|
819 |
+
non_linearity=resnet_act_fn,
|
820 |
+
output_scale_factor=output_scale_factor,
|
821 |
+
pre_norm=resnet_pre_norm,
|
822 |
+
skip_time_act=resnet_2d_skip_time_act,
|
823 |
+
)
|
824 |
+
)
|
825 |
+
if temporal_conv_block is not None:
|
826 |
+
temp_convs.append(
|
827 |
+
temporal_conv_block(
|
828 |
+
out_channels,
|
829 |
+
out_channels,
|
830 |
+
dropout=0.1,
|
831 |
+
femb_channels=femb_channels,
|
832 |
+
)
|
833 |
+
)
|
834 |
+
else:
|
835 |
+
temp_convs.append(None)
|
836 |
+
if need_refer_emb:
|
837 |
+
refer_emb_attns.append(
|
838 |
+
ReferEmbFuseAttention(
|
839 |
+
query_dim=out_channels,
|
840 |
+
heads=attn_num_head_channels,
|
841 |
+
dim_head=out_channels // attn_num_head_channels,
|
842 |
+
dropout=0,
|
843 |
+
bias=False,
|
844 |
+
cross_attention_dim=None,
|
845 |
+
upcast_attention=False,
|
846 |
+
)
|
847 |
+
)
|
848 |
+
|
849 |
+
self.resnets = nn.ModuleList(resnets)
|
850 |
+
self.temp_convs = nn.ModuleList(temp_convs)
|
851 |
+
|
852 |
+
if add_downsample:
|
853 |
+
self.downsamplers = nn.ModuleList(
|
854 |
+
[
|
855 |
+
Downsample2D(
|
856 |
+
out_channels,
|
857 |
+
use_conv=True,
|
858 |
+
out_channels=out_channels,
|
859 |
+
padding=downsample_padding,
|
860 |
+
name="op",
|
861 |
+
)
|
862 |
+
]
|
863 |
+
)
|
864 |
+
if need_refer_emb:
|
865 |
+
refer_emb_attns.append(
|
866 |
+
ReferEmbFuseAttention(
|
867 |
+
query_dim=out_channels,
|
868 |
+
heads=attn_num_head_channels,
|
869 |
+
dim_head=out_channels // attn_num_head_channels,
|
870 |
+
dropout=0,
|
871 |
+
bias=False,
|
872 |
+
cross_attention_dim=None,
|
873 |
+
upcast_attention=False,
|
874 |
+
)
|
875 |
+
)
|
876 |
+
else:
|
877 |
+
self.downsamplers = None
|
878 |
+
|
879 |
+
self.gradient_checkpointing = False
|
880 |
+
self.need_adain_temporal_cond = need_adain_temporal_cond
|
881 |
+
if need_refer_emb:
|
882 |
+
self.refer_emb_attns = nn.ModuleList(refer_emb_attns)
|
883 |
+
|
884 |
+
def forward(
|
885 |
+
self,
|
886 |
+
hidden_states,
|
887 |
+
temb=None,
|
888 |
+
num_frames=1,
|
889 |
+
sample_index: torch.LongTensor = None,
|
890 |
+
vision_conditon_frames_sample_index: torch.LongTensor = None,
|
891 |
+
spatial_position_emb: torch.Tensor = None,
|
892 |
+
femb=None,
|
893 |
+
refer_embs: Optional[Tuple[torch.Tensor]] = None,
|
894 |
+
refer_self_attn_emb: List[torch.Tensor] = None,
|
895 |
+
refer_self_attn_emb_mode: Literal["read", "write"] = "read",
|
896 |
+
):
|
897 |
+
output_states = ()
|
898 |
+
|
899 |
+
for i_downblock, (resnet, temp_conv) in enumerate(
|
900 |
+
zip(self.resnets, self.temp_convs)
|
901 |
+
):
|
902 |
+
if self.training and self.gradient_checkpointing:
|
903 |
+
|
904 |
+
def create_custom_forward(module):
|
905 |
+
def custom_forward(*inputs):
|
906 |
+
return module(*inputs)
|
907 |
+
|
908 |
+
return custom_forward
|
909 |
+
|
910 |
+
ckpt_kwargs: Dict[str, Any] = (
|
911 |
+
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
912 |
+
)
|
913 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
914 |
+
create_custom_forward(resnet),
|
915 |
+
hidden_states,
|
916 |
+
temb,
|
917 |
+
**ckpt_kwargs,
|
918 |
+
)
|
919 |
+
if temp_conv is not None:
|
920 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
921 |
+
create_custom_forward(temp_conv),
|
922 |
+
hidden_states,
|
923 |
+
num_frames,
|
924 |
+
sample_index,
|
925 |
+
vision_conditon_frames_sample_index,
|
926 |
+
femb,
|
927 |
+
**ckpt_kwargs,
|
928 |
+
)
|
929 |
+
else:
|
930 |
+
hidden_states = resnet(hidden_states, temb)
|
931 |
+
if temp_conv is not None:
|
932 |
+
hidden_states = temp_conv(
|
933 |
+
hidden_states,
|
934 |
+
femb=femb,
|
935 |
+
num_frames=num_frames,
|
936 |
+
sample_index=sample_index,
|
937 |
+
vision_conditon_frames_sample_index=vision_conditon_frames_sample_index,
|
938 |
+
)
|
939 |
+
if (
|
940 |
+
self.need_adain_temporal_cond
|
941 |
+
and num_frames > 1
|
942 |
+
and sample_index is not None
|
943 |
+
):
|
944 |
+
if self.print_idx == 0:
|
945 |
+
logger.debug(f"adain to vision_condition")
|
946 |
+
hidden_states = batch_adain_conditioned_tensor(
|
947 |
+
hidden_states,
|
948 |
+
num_frames=num_frames,
|
949 |
+
need_style_fidelity=False,
|
950 |
+
src_index=sample_index,
|
951 |
+
dst_index=vision_conditon_frames_sample_index,
|
952 |
+
)
|
953 |
+
if self.need_refer_emb and refer_embs is not None:
|
954 |
+
hidden_states = self.refer_emb_attns[i_downblock](
|
955 |
+
hidden_states, refer_embs[i_downblock], num_frames=num_frames
|
956 |
+
)
|
957 |
+
output_states += (hidden_states,)
|
958 |
+
|
959 |
+
if self.downsamplers is not None:
|
960 |
+
for downsampler in self.downsamplers:
|
961 |
+
hidden_states = downsampler(hidden_states)
|
962 |
+
if (
|
963 |
+
self.need_adain_temporal_cond
|
964 |
+
and num_frames > 1
|
965 |
+
and sample_index is not None
|
966 |
+
):
|
967 |
+
if self.print_idx == 0:
|
968 |
+
logger.debug(f"adain to vision_condition")
|
969 |
+
hidden_states = batch_adain_conditioned_tensor(
|
970 |
+
hidden_states,
|
971 |
+
num_frames=num_frames,
|
972 |
+
need_style_fidelity=False,
|
973 |
+
src_index=sample_index,
|
974 |
+
dst_index=vision_conditon_frames_sample_index,
|
975 |
+
)
|
976 |
+
if self.need_refer_emb and refer_embs is not None:
|
977 |
+
i_downblock += 1
|
978 |
+
hidden_states = self.refer_emb_attns[i_downblock](
|
979 |
+
hidden_states, refer_embs[i_downblock], num_frames=num_frames
|
980 |
+
)
|
981 |
+
output_states += (hidden_states,)
|
982 |
+
self.print_idx += 1
|
983 |
+
return hidden_states, output_states
|
984 |
+
|
985 |
+
|
986 |
+
class CrossAttnUpBlock3D(nn.Module):
|
987 |
+
print_idx = 0
|
988 |
+
|
989 |
+
def __init__(
|
990 |
+
self,
|
991 |
+
in_channels: int,
|
992 |
+
out_channels: int,
|
993 |
+
prev_output_channel: int,
|
994 |
+
temb_channels: int,
|
995 |
+
femb_channels: int,
|
996 |
+
dropout: float = 0.0,
|
997 |
+
num_layers: int = 1,
|
998 |
+
resnet_eps: float = 1e-6,
|
999 |
+
resnet_time_scale_shift: str = "default",
|
1000 |
+
resnet_act_fn: str = "swish",
|
1001 |
+
resnet_groups: int = 32,
|
1002 |
+
resnet_pre_norm: bool = True,
|
1003 |
+
attn_num_head_channels=1,
|
1004 |
+
cross_attention_dim=1280,
|
1005 |
+
output_scale_factor=1.0,
|
1006 |
+
add_upsample=True,
|
1007 |
+
dual_cross_attention=False,
|
1008 |
+
use_linear_projection=False,
|
1009 |
+
only_cross_attention=False,
|
1010 |
+
upcast_attention=False,
|
1011 |
+
temporal_conv_block: Union[nn.Module, None] = TemporalConvLayer,
|
1012 |
+
temporal_transformer: Union[nn.Module, None] = TransformerTemporalModel,
|
1013 |
+
need_spatial_position_emb: bool = False,
|
1014 |
+
need_t2i_ip_adapter: bool = False,
|
1015 |
+
ip_adapter_cross_attn: bool = False,
|
1016 |
+
need_t2i_facein: bool = False,
|
1017 |
+
need_t2i_ip_adapter_face: bool = False,
|
1018 |
+
need_adain_temporal_cond: bool = False,
|
1019 |
+
resnet_2d_skip_time_act: bool = False,
|
1020 |
+
):
|
1021 |
+
super().__init__()
|
1022 |
+
resnets = []
|
1023 |
+
temp_convs = []
|
1024 |
+
attentions = []
|
1025 |
+
temp_attentions = []
|
1026 |
+
|
1027 |
+
self.has_cross_attention = True
|
1028 |
+
self.attn_num_head_channels = attn_num_head_channels
|
1029 |
+
|
1030 |
+
for i in range(num_layers):
|
1031 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
1032 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
1033 |
+
|
1034 |
+
resnets.append(
|
1035 |
+
ResnetBlock2D(
|
1036 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
1037 |
+
out_channels=out_channels,
|
1038 |
+
temb_channels=temb_channels,
|
1039 |
+
eps=resnet_eps,
|
1040 |
+
groups=resnet_groups,
|
1041 |
+
dropout=dropout,
|
1042 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1043 |
+
non_linearity=resnet_act_fn,
|
1044 |
+
output_scale_factor=output_scale_factor,
|
1045 |
+
pre_norm=resnet_pre_norm,
|
1046 |
+
skip_time_act=resnet_2d_skip_time_act,
|
1047 |
+
)
|
1048 |
+
)
|
1049 |
+
if temporal_conv_block is not None:
|
1050 |
+
temp_convs.append(
|
1051 |
+
temporal_conv_block(
|
1052 |
+
out_channels,
|
1053 |
+
out_channels,
|
1054 |
+
dropout=0.1,
|
1055 |
+
femb_channels=femb_channels,
|
1056 |
+
)
|
1057 |
+
)
|
1058 |
+
else:
|
1059 |
+
temp_convs.append(None)
|
1060 |
+
attentions.append(
|
1061 |
+
Transformer2DModel(
|
1062 |
+
attn_num_head_channels,
|
1063 |
+
out_channels // attn_num_head_channels,
|
1064 |
+
in_channels=out_channels,
|
1065 |
+
num_layers=1,
|
1066 |
+
cross_attention_dim=cross_attention_dim,
|
1067 |
+
norm_num_groups=resnet_groups,
|
1068 |
+
use_linear_projection=use_linear_projection,
|
1069 |
+
only_cross_attention=only_cross_attention,
|
1070 |
+
upcast_attention=upcast_attention,
|
1071 |
+
cross_attn_temporal_cond=need_t2i_ip_adapter,
|
1072 |
+
ip_adapter_cross_attn=ip_adapter_cross_attn,
|
1073 |
+
need_t2i_facein=need_t2i_facein,
|
1074 |
+
need_t2i_ip_adapter_face=need_t2i_ip_adapter_face,
|
1075 |
+
)
|
1076 |
+
)
|
1077 |
+
if temporal_transformer is not None:
|
1078 |
+
temp_attention = temporal_transformer(
|
1079 |
+
attn_num_head_channels,
|
1080 |
+
out_channels // attn_num_head_channels,
|
1081 |
+
in_channels=out_channels,
|
1082 |
+
num_layers=1,
|
1083 |
+
femb_channels=femb_channels,
|
1084 |
+
cross_attention_dim=cross_attention_dim,
|
1085 |
+
norm_num_groups=resnet_groups,
|
1086 |
+
need_spatial_position_emb=need_spatial_position_emb,
|
1087 |
+
)
|
1088 |
+
else:
|
1089 |
+
temp_attention = None
|
1090 |
+
temp_attentions.append(temp_attention)
|
1091 |
+
self.resnets = nn.ModuleList(resnets)
|
1092 |
+
self.temp_convs = nn.ModuleList(temp_convs)
|
1093 |
+
self.attentions = nn.ModuleList(attentions)
|
1094 |
+
self.temp_attentions = nn.ModuleList(temp_attentions)
|
1095 |
+
|
1096 |
+
if add_upsample:
|
1097 |
+
self.upsamplers = nn.ModuleList(
|
1098 |
+
[Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
|
1099 |
+
)
|
1100 |
+
else:
|
1101 |
+
self.upsamplers = None
|
1102 |
+
|
1103 |
+
self.gradient_checkpointing = False
|
1104 |
+
self.need_adain_temporal_cond = need_adain_temporal_cond
|
1105 |
+
|
1106 |
+
def forward(
|
1107 |
+
self,
|
1108 |
+
hidden_states: torch.FloatTensor,
|
1109 |
+
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
1110 |
+
temb: Optional[torch.FloatTensor] = None,
|
1111 |
+
femb: Optional[torch.FloatTensor] = None,
|
1112 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
1113 |
+
num_frames: int = 1,
|
1114 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
1115 |
+
upsample_size: Optional[int] = None,
|
1116 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1117 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
1118 |
+
sample_index: torch.LongTensor = None,
|
1119 |
+
vision_conditon_frames_sample_index: torch.LongTensor = None,
|
1120 |
+
spatial_position_emb: torch.Tensor = None,
|
1121 |
+
refer_self_attn_emb: List[torch.Tensor] = None,
|
1122 |
+
refer_self_attn_emb_mode: Literal["read", "write"] = "read",
|
1123 |
+
):
|
1124 |
+
for resnet, temp_conv, attn, temp_attn in zip(
|
1125 |
+
self.resnets, self.temp_convs, self.attentions, self.temp_attentions
|
1126 |
+
):
|
1127 |
+
# pop res hidden states
|
1128 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1129 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1130 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1131 |
+
if self.training and self.gradient_checkpointing:
|
1132 |
+
|
1133 |
+
def create_custom_forward(module, return_dict=None):
|
1134 |
+
def custom_forward(*inputs):
|
1135 |
+
if return_dict is not None:
|
1136 |
+
return module(*inputs, return_dict=return_dict)
|
1137 |
+
else:
|
1138 |
+
return module(*inputs)
|
1139 |
+
|
1140 |
+
return custom_forward
|
1141 |
+
|
1142 |
+
ckpt_kwargs: Dict[str, Any] = (
|
1143 |
+
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
1144 |
+
)
|
1145 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1146 |
+
create_custom_forward(resnet),
|
1147 |
+
hidden_states,
|
1148 |
+
temb,
|
1149 |
+
**ckpt_kwargs,
|
1150 |
+
)
|
1151 |
+
if temp_conv is not None:
|
1152 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1153 |
+
create_custom_forward(temp_conv),
|
1154 |
+
hidden_states,
|
1155 |
+
num_frames,
|
1156 |
+
sample_index,
|
1157 |
+
vision_conditon_frames_sample_index,
|
1158 |
+
femb,
|
1159 |
+
**ckpt_kwargs,
|
1160 |
+
)
|
1161 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1162 |
+
create_custom_forward(attn, return_dict=False),
|
1163 |
+
hidden_states,
|
1164 |
+
encoder_hidden_states,
|
1165 |
+
None, # timestep
|
1166 |
+
None, # added_cond_kwargs
|
1167 |
+
None, # class_labels
|
1168 |
+
cross_attention_kwargs,
|
1169 |
+
attention_mask,
|
1170 |
+
encoder_attention_mask,
|
1171 |
+
refer_self_attn_emb,
|
1172 |
+
refer_self_attn_emb_mode,
|
1173 |
+
**ckpt_kwargs,
|
1174 |
+
)[0]
|
1175 |
+
if temp_attn is not None:
|
1176 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1177 |
+
create_custom_forward(temp_attn, return_dict=False),
|
1178 |
+
hidden_states,
|
1179 |
+
femb,
|
1180 |
+
# None, # encoder_hidden_states,
|
1181 |
+
encoder_hidden_states,
|
1182 |
+
None, # timestep
|
1183 |
+
None, # class_labels
|
1184 |
+
num_frames,
|
1185 |
+
cross_attention_kwargs,
|
1186 |
+
sample_index,
|
1187 |
+
vision_conditon_frames_sample_index,
|
1188 |
+
spatial_position_emb,
|
1189 |
+
**ckpt_kwargs,
|
1190 |
+
)[0]
|
1191 |
+
else:
|
1192 |
+
hidden_states = resnet(hidden_states, temb)
|
1193 |
+
if temp_conv is not None:
|
1194 |
+
hidden_states = temp_conv(
|
1195 |
+
hidden_states,
|
1196 |
+
num_frames=num_frames,
|
1197 |
+
femb=femb,
|
1198 |
+
sample_index=sample_index,
|
1199 |
+
vision_conditon_frames_sample_index=vision_conditon_frames_sample_index,
|
1200 |
+
)
|
1201 |
+
hidden_states = attn(
|
1202 |
+
hidden_states,
|
1203 |
+
encoder_hidden_states=encoder_hidden_states,
|
1204 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1205 |
+
self_attn_block_embs=refer_self_attn_emb,
|
1206 |
+
self_attn_block_embs_mode=refer_self_attn_emb_mode,
|
1207 |
+
).sample
|
1208 |
+
if temp_attn is not None:
|
1209 |
+
hidden_states = temp_attn(
|
1210 |
+
hidden_states,
|
1211 |
+
femb=femb,
|
1212 |
+
num_frames=num_frames,
|
1213 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1214 |
+
encoder_hidden_states=encoder_hidden_states,
|
1215 |
+
sample_index=sample_index,
|
1216 |
+
vision_conditon_frames_sample_index=vision_conditon_frames_sample_index,
|
1217 |
+
spatial_position_emb=spatial_position_emb,
|
1218 |
+
).sample
|
1219 |
+
if (
|
1220 |
+
self.need_adain_temporal_cond
|
1221 |
+
and num_frames > 1
|
1222 |
+
and sample_index is not None
|
1223 |
+
):
|
1224 |
+
if self.print_idx == 0:
|
1225 |
+
logger.debug(f"adain to vision_condition")
|
1226 |
+
hidden_states = batch_adain_conditioned_tensor(
|
1227 |
+
hidden_states,
|
1228 |
+
num_frames=num_frames,
|
1229 |
+
need_style_fidelity=False,
|
1230 |
+
src_index=sample_index,
|
1231 |
+
dst_index=vision_conditon_frames_sample_index,
|
1232 |
+
)
|
1233 |
+
if self.upsamplers is not None:
|
1234 |
+
for upsampler in self.upsamplers:
|
1235 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
1236 |
+
if (
|
1237 |
+
self.need_adain_temporal_cond
|
1238 |
+
and num_frames > 1
|
1239 |
+
and sample_index is not None
|
1240 |
+
):
|
1241 |
+
if self.print_idx == 0:
|
1242 |
+
logger.debug(f"adain to vision_condition")
|
1243 |
+
hidden_states = batch_adain_conditioned_tensor(
|
1244 |
+
hidden_states,
|
1245 |
+
num_frames=num_frames,
|
1246 |
+
need_style_fidelity=False,
|
1247 |
+
src_index=sample_index,
|
1248 |
+
dst_index=vision_conditon_frames_sample_index,
|
1249 |
+
)
|
1250 |
+
self.print_idx += 1
|
1251 |
+
return hidden_states
|
1252 |
+
|
1253 |
+
|
1254 |
+
class UpBlock3D(nn.Module):
|
1255 |
+
print_idx = 0
|
1256 |
+
|
1257 |
+
def __init__(
|
1258 |
+
self,
|
1259 |
+
in_channels: int,
|
1260 |
+
prev_output_channel: int,
|
1261 |
+
out_channels: int,
|
1262 |
+
temb_channels: int,
|
1263 |
+
femb_channels: int,
|
1264 |
+
dropout: float = 0.0,
|
1265 |
+
num_layers: int = 1,
|
1266 |
+
resnet_eps: float = 1e-6,
|
1267 |
+
resnet_time_scale_shift: str = "default",
|
1268 |
+
resnet_act_fn: str = "swish",
|
1269 |
+
resnet_groups: int = 32,
|
1270 |
+
resnet_pre_norm: bool = True,
|
1271 |
+
output_scale_factor=1.0,
|
1272 |
+
add_upsample=True,
|
1273 |
+
temporal_conv_block: Union[nn.Module, None] = TemporalConvLayer,
|
1274 |
+
need_adain_temporal_cond: bool = False,
|
1275 |
+
resnet_2d_skip_time_act: bool = False,
|
1276 |
+
):
|
1277 |
+
super().__init__()
|
1278 |
+
resnets = []
|
1279 |
+
temp_convs = []
|
1280 |
+
|
1281 |
+
for i in range(num_layers):
|
1282 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
1283 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
1284 |
+
|
1285 |
+
resnets.append(
|
1286 |
+
ResnetBlock2D(
|
1287 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
1288 |
+
out_channels=out_channels,
|
1289 |
+
temb_channels=temb_channels,
|
1290 |
+
eps=resnet_eps,
|
1291 |
+
groups=resnet_groups,
|
1292 |
+
dropout=dropout,
|
1293 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1294 |
+
non_linearity=resnet_act_fn,
|
1295 |
+
output_scale_factor=output_scale_factor,
|
1296 |
+
pre_norm=resnet_pre_norm,
|
1297 |
+
skip_time_act=resnet_2d_skip_time_act,
|
1298 |
+
)
|
1299 |
+
)
|
1300 |
+
if temporal_conv_block is not None:
|
1301 |
+
temp_convs.append(
|
1302 |
+
temporal_conv_block(
|
1303 |
+
out_channels,
|
1304 |
+
out_channels,
|
1305 |
+
dropout=0.1,
|
1306 |
+
femb_channels=femb_channels,
|
1307 |
+
)
|
1308 |
+
)
|
1309 |
+
else:
|
1310 |
+
temp_convs.append(None)
|
1311 |
+
self.resnets = nn.ModuleList(resnets)
|
1312 |
+
self.temp_convs = nn.ModuleList(temp_convs)
|
1313 |
+
|
1314 |
+
if add_upsample:
|
1315 |
+
self.upsamplers = nn.ModuleList(
|
1316 |
+
[Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
|
1317 |
+
)
|
1318 |
+
else:
|
1319 |
+
self.upsamplers = None
|
1320 |
+
|
1321 |
+
self.gradient_checkpointing = False
|
1322 |
+
self.need_adain_temporal_cond = need_adain_temporal_cond
|
1323 |
+
|
1324 |
+
def forward(
|
1325 |
+
self,
|
1326 |
+
hidden_states,
|
1327 |
+
res_hidden_states_tuple,
|
1328 |
+
temb=None,
|
1329 |
+
upsample_size=None,
|
1330 |
+
num_frames=1,
|
1331 |
+
sample_index: torch.LongTensor = None,
|
1332 |
+
vision_conditon_frames_sample_index: torch.LongTensor = None,
|
1333 |
+
spatial_position_emb: torch.Tensor = None,
|
1334 |
+
femb=None,
|
1335 |
+
refer_self_attn_emb: List[torch.Tensor] = None,
|
1336 |
+
refer_self_attn_emb_mode: Literal["read", "write"] = "read",
|
1337 |
+
):
|
1338 |
+
for resnet, temp_conv in zip(self.resnets, self.temp_convs):
|
1339 |
+
# pop res hidden states
|
1340 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1341 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1342 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1343 |
+
|
1344 |
+
if self.training and self.gradient_checkpointing:
|
1345 |
+
|
1346 |
+
def create_custom_forward(module):
|
1347 |
+
def custom_forward(*inputs):
|
1348 |
+
return module(*inputs)
|
1349 |
+
|
1350 |
+
return custom_forward
|
1351 |
+
|
1352 |
+
ckpt_kwargs: Dict[str, Any] = (
|
1353 |
+
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
1354 |
+
)
|
1355 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1356 |
+
create_custom_forward(resnet),
|
1357 |
+
hidden_states,
|
1358 |
+
temb,
|
1359 |
+
**ckpt_kwargs,
|
1360 |
+
)
|
1361 |
+
if temp_conv is not None:
|
1362 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1363 |
+
create_custom_forward(temp_conv),
|
1364 |
+
hidden_states,
|
1365 |
+
num_frames,
|
1366 |
+
sample_index,
|
1367 |
+
vision_conditon_frames_sample_index,
|
1368 |
+
femb,
|
1369 |
+
**ckpt_kwargs,
|
1370 |
+
)
|
1371 |
+
else:
|
1372 |
+
hidden_states = resnet(hidden_states, temb)
|
1373 |
+
if temp_conv is not None:
|
1374 |
+
hidden_states = temp_conv(
|
1375 |
+
hidden_states,
|
1376 |
+
num_frames=num_frames,
|
1377 |
+
femb=femb,
|
1378 |
+
sample_index=sample_index,
|
1379 |
+
vision_conditon_frames_sample_index=vision_conditon_frames_sample_index,
|
1380 |
+
)
|
1381 |
+
if (
|
1382 |
+
self.need_adain_temporal_cond
|
1383 |
+
and num_frames > 1
|
1384 |
+
and sample_index is not None
|
1385 |
+
):
|
1386 |
+
if self.print_idx == 0:
|
1387 |
+
logger.debug(f"adain to vision_condition")
|
1388 |
+
hidden_states = batch_adain_conditioned_tensor(
|
1389 |
+
hidden_states,
|
1390 |
+
num_frames=num_frames,
|
1391 |
+
need_style_fidelity=False,
|
1392 |
+
src_index=sample_index,
|
1393 |
+
dst_index=vision_conditon_frames_sample_index,
|
1394 |
+
)
|
1395 |
+
if self.upsamplers is not None:
|
1396 |
+
for upsampler in self.upsamplers:
|
1397 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
1398 |
+
if (
|
1399 |
+
self.need_adain_temporal_cond
|
1400 |
+
and num_frames > 1
|
1401 |
+
and sample_index is not None
|
1402 |
+
):
|
1403 |
+
if self.print_idx == 0:
|
1404 |
+
logger.debug(f"adain to vision_condition")
|
1405 |
+
hidden_states = batch_adain_conditioned_tensor(
|
1406 |
+
hidden_states,
|
1407 |
+
num_frames=num_frames,
|
1408 |
+
need_style_fidelity=False,
|
1409 |
+
src_index=sample_index,
|
1410 |
+
dst_index=vision_conditon_frames_sample_index,
|
1411 |
+
)
|
1412 |
+
self.print_idx += 1
|
1413 |
+
return hidden_states
|
musev/models/unet_3d_condition.py
ADDED
@@ -0,0 +1,1740 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
|
2 |
+
# Copyright 2023 The ModelScope Team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
# Adapted from https://github.com/huggingface/diffusers/blob/v0.16.1/src/diffusers/models/unet_3d_condition.py
|
17 |
+
|
18 |
+
# 1. 增加了from_pretrained,将模型从2D blocks改为3D blocks
|
19 |
+
# 1. add from_pretrained, change model from 2D blocks to 3D blocks
|
20 |
+
|
21 |
+
from copy import deepcopy
|
22 |
+
from dataclasses import dataclass
|
23 |
+
import inspect
|
24 |
+
from pprint import pprint, pformat
|
25 |
+
from typing import Any, Dict, List, Optional, Tuple, Union, Literal
|
26 |
+
import os
|
27 |
+
import logging
|
28 |
+
|
29 |
+
import torch
|
30 |
+
import torch.nn as nn
|
31 |
+
import torch.utils.checkpoint
|
32 |
+
from einops import rearrange, repeat
|
33 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
34 |
+
from diffusers.loaders import UNet2DConditionLoadersMixin
|
35 |
+
from diffusers.utils import BaseOutput
|
36 |
+
|
37 |
+
# from diffusers.utils import logging
|
38 |
+
from diffusers.models.embeddings import (
|
39 |
+
TimestepEmbedding,
|
40 |
+
Timesteps,
|
41 |
+
)
|
42 |
+
from diffusers.models.modeling_utils import ModelMixin, load_state_dict
|
43 |
+
from diffusers import __version__
|
44 |
+
from diffusers.utils import (
|
45 |
+
CONFIG_NAME,
|
46 |
+
DIFFUSERS_CACHE,
|
47 |
+
FLAX_WEIGHTS_NAME,
|
48 |
+
HF_HUB_OFFLINE,
|
49 |
+
SAFETENSORS_WEIGHTS_NAME,
|
50 |
+
WEIGHTS_NAME,
|
51 |
+
_add_variant,
|
52 |
+
_get_model_file,
|
53 |
+
is_accelerate_available,
|
54 |
+
is_torch_version,
|
55 |
+
)
|
56 |
+
from diffusers.utils.import_utils import _safetensors_available
|
57 |
+
from diffusers.models.unet_3d_condition import (
|
58 |
+
UNet3DConditionOutput,
|
59 |
+
UNet3DConditionModel as DiffusersUNet3DConditionModel,
|
60 |
+
)
|
61 |
+
from diffusers.models.attention_processor import (
|
62 |
+
Attention,
|
63 |
+
AttentionProcessor,
|
64 |
+
AttnProcessor,
|
65 |
+
AttnProcessor2_0,
|
66 |
+
XFormersAttnProcessor,
|
67 |
+
)
|
68 |
+
|
69 |
+
from ..models import Model_Register
|
70 |
+
|
71 |
+
from .resnet import TemporalConvLayer
|
72 |
+
from .temporal_transformer import (
|
73 |
+
TransformerTemporalModel,
|
74 |
+
)
|
75 |
+
from .embeddings import get_2d_sincos_pos_embed, resize_spatial_position_emb
|
76 |
+
from .unet_3d_blocks import (
|
77 |
+
CrossAttnDownBlock3D,
|
78 |
+
CrossAttnUpBlock3D,
|
79 |
+
DownBlock3D,
|
80 |
+
UNetMidBlock3DCrossAttn,
|
81 |
+
UpBlock3D,
|
82 |
+
get_down_block,
|
83 |
+
get_up_block,
|
84 |
+
)
|
85 |
+
from ..data.data_util import (
|
86 |
+
adaptive_instance_normalization,
|
87 |
+
align_repeat_tensor_single_dim,
|
88 |
+
batch_adain_conditioned_tensor,
|
89 |
+
batch_concat_two_tensor_with_index,
|
90 |
+
concat_two_tensor,
|
91 |
+
concat_two_tensor_with_index,
|
92 |
+
)
|
93 |
+
from .attention_processor import BaseIPAttnProcessor
|
94 |
+
from .attention_processor import ReferEmbFuseAttention
|
95 |
+
from .transformer_2d import Transformer2DModel
|
96 |
+
from .attention import BasicTransformerBlock
|
97 |
+
|
98 |
+
|
99 |
+
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
100 |
+
|
101 |
+
# if is_torch_version(">=", "1.9.0"):
|
102 |
+
# _LOW_CPU_MEM_USAGE_DEFAULT = True
|
103 |
+
# else:
|
104 |
+
# _LOW_CPU_MEM_USAGE_DEFAULT = False
|
105 |
+
_LOW_CPU_MEM_USAGE_DEFAULT = False
|
106 |
+
|
107 |
+
if is_accelerate_available():
|
108 |
+
import accelerate
|
109 |
+
from accelerate.utils import set_module_tensor_to_device
|
110 |
+
from accelerate.utils.versions import is_torch_version
|
111 |
+
|
112 |
+
|
113 |
+
import safetensors
|
114 |
+
|
115 |
+
|
116 |
+
def hack_t2i_sd_layer_attn_with_ip(
|
117 |
+
unet: nn.Module,
|
118 |
+
self_attn_class: BaseIPAttnProcessor = None,
|
119 |
+
cross_attn_class: BaseIPAttnProcessor = None,
|
120 |
+
):
|
121 |
+
attn_procs = {}
|
122 |
+
for name in unet.attn_processors.keys():
|
123 |
+
if "temp_attentions" in name or "transformer_in" in name:
|
124 |
+
continue
|
125 |
+
if name.endswith("attn1.processor") and self_attn_class is not None:
|
126 |
+
attn_procs[name] = self_attn_class()
|
127 |
+
if unet.print_idx == 0:
|
128 |
+
logger.debug(
|
129 |
+
f"hack attn_processor of {name} to {attn_procs[name].__class__.__name__}"
|
130 |
+
)
|
131 |
+
elif name.endswith("attn2.processor") and cross_attn_class is not None:
|
132 |
+
attn_procs[name] = cross_attn_class()
|
133 |
+
if unet.print_idx == 0:
|
134 |
+
logger.debug(
|
135 |
+
f"hack attn_processor of {name} to {attn_procs[name].__class__.__name__}"
|
136 |
+
)
|
137 |
+
unet.set_attn_processor(attn_procs, strict=False)
|
138 |
+
|
139 |
+
|
140 |
+
def convert_2D_to_3D(
|
141 |
+
module_names,
|
142 |
+
valid_modules=(
|
143 |
+
"CrossAttnDownBlock2D",
|
144 |
+
"CrossAttnUpBlock2D",
|
145 |
+
"DownBlock2D",
|
146 |
+
"UNetMidBlock2DCrossAttn",
|
147 |
+
"UpBlock2D",
|
148 |
+
),
|
149 |
+
):
|
150 |
+
if not isinstance(module_names, list):
|
151 |
+
return module_names.replace("2D", "3D")
|
152 |
+
|
153 |
+
return_modules = []
|
154 |
+
for module_name in module_names:
|
155 |
+
if module_name in valid_modules:
|
156 |
+
return_modules.append(module_name.replace("2D", "3D"))
|
157 |
+
else:
|
158 |
+
return_modules.append(module_name)
|
159 |
+
return return_modules
|
160 |
+
|
161 |
+
|
162 |
+
def insert_spatial_self_attn_idx(unet):
|
163 |
+
pass
|
164 |
+
|
165 |
+
|
166 |
+
@dataclass
|
167 |
+
class UNet3DConditionOutput(BaseOutput):
|
168 |
+
"""
|
169 |
+
The output of [`UNet3DConditionModel`].
|
170 |
+
|
171 |
+
Args:
|
172 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
|
173 |
+
The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
|
174 |
+
"""
|
175 |
+
|
176 |
+
sample: torch.FloatTensor
|
177 |
+
|
178 |
+
|
179 |
+
class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
180 |
+
r"""
|
181 |
+
UNet3DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
|
182 |
+
and returns sample shaped output.
|
183 |
+
|
184 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
|
185 |
+
implements for all the models (such as downloading or saving, etc.)
|
186 |
+
|
187 |
+
Parameters:
|
188 |
+
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
189 |
+
Height and width of input/output sample.
|
190 |
+
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
|
191 |
+
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
|
192 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
193 |
+
The tuple of downsample blocks to use.
|
194 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
|
195 |
+
The tuple of upsample blocks to use.
|
196 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
197 |
+
The tuple of output channels for each block.
|
198 |
+
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
|
199 |
+
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
|
200 |
+
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
|
201 |
+
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
202 |
+
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
|
203 |
+
If `None`, it will skip the normalization and activation layers in post-processing
|
204 |
+
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
|
205 |
+
cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
|
206 |
+
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
207 |
+
"""
|
208 |
+
|
209 |
+
_supports_gradient_checkpointing = True
|
210 |
+
print_idx = 0
|
211 |
+
|
212 |
+
@register_to_config
|
213 |
+
def __init__(
|
214 |
+
self,
|
215 |
+
sample_size: Optional[int] = None,
|
216 |
+
in_channels: int = 4,
|
217 |
+
out_channels: int = 4,
|
218 |
+
down_block_types: Tuple[str] = (
|
219 |
+
"CrossAttnDownBlock3D",
|
220 |
+
"CrossAttnDownBlock3D",
|
221 |
+
"CrossAttnDownBlock3D",
|
222 |
+
"DownBlock3D",
|
223 |
+
),
|
224 |
+
up_block_types: Tuple[str] = (
|
225 |
+
"UpBlock3D",
|
226 |
+
"CrossAttnUpBlock3D",
|
227 |
+
"CrossAttnUpBlock3D",
|
228 |
+
"CrossAttnUpBlock3D",
|
229 |
+
),
|
230 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
231 |
+
layers_per_block: int = 2,
|
232 |
+
downsample_padding: int = 1,
|
233 |
+
mid_block_scale_factor: float = 1,
|
234 |
+
act_fn: str = "silu",
|
235 |
+
norm_num_groups: Optional[int] = 32,
|
236 |
+
norm_eps: float = 1e-5,
|
237 |
+
cross_attention_dim: int = 1024,
|
238 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
239 |
+
temporal_conv_block: str = "TemporalConvLayer",
|
240 |
+
temporal_transformer: str = "TransformerTemporalModel",
|
241 |
+
need_spatial_position_emb: bool = False,
|
242 |
+
need_transformer_in: bool = True,
|
243 |
+
need_t2i_ip_adapter: bool = False, # self_attn, t2i.attn1
|
244 |
+
need_adain_temporal_cond: bool = False,
|
245 |
+
t2i_ip_adapter_attn_processor: str = "NonParamT2ISelfReferenceXFormersAttnProcessor",
|
246 |
+
keep_vision_condtion: bool = False,
|
247 |
+
use_anivv1_cfg: bool = False,
|
248 |
+
resnet_2d_skip_time_act: bool = False,
|
249 |
+
need_zero_vis_cond_temb: bool = True,
|
250 |
+
norm_spatial_length: bool = False,
|
251 |
+
spatial_max_length: int = 2048,
|
252 |
+
need_refer_emb: bool = False,
|
253 |
+
ip_adapter_cross_attn: bool = False, # cross_attn, t2i.attn2
|
254 |
+
t2i_crossattn_ip_adapter_attn_processor: str = "T2IReferencenetIPAdapterXFormersAttnProcessor",
|
255 |
+
need_t2i_facein: bool = False,
|
256 |
+
need_t2i_ip_adapter_face: bool = False,
|
257 |
+
need_vis_cond_mask: bool = False,
|
258 |
+
):
|
259 |
+
"""_summary_
|
260 |
+
|
261 |
+
Args:
|
262 |
+
sample_size (Optional[int], optional): _description_. Defaults to None.
|
263 |
+
in_channels (int, optional): _description_. Defaults to 4.
|
264 |
+
out_channels (int, optional): _description_. Defaults to 4.
|
265 |
+
down_block_types (Tuple[str], optional): _description_. Defaults to ( "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D", ).
|
266 |
+
up_block_types (Tuple[str], optional): _description_. Defaults to ( "UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", ).
|
267 |
+
block_out_channels (Tuple[int], optional): _description_. Defaults to (320, 640, 1280, 1280).
|
268 |
+
layers_per_block (int, optional): _description_. Defaults to 2.
|
269 |
+
downsample_padding (int, optional): _description_. Defaults to 1.
|
270 |
+
mid_block_scale_factor (float, optional): _description_. Defaults to 1.
|
271 |
+
act_fn (str, optional): _description_. Defaults to "silu".
|
272 |
+
norm_num_groups (Optional[int], optional): _description_. Defaults to 32.
|
273 |
+
norm_eps (float, optional): _description_. Defaults to 1e-5.
|
274 |
+
cross_attention_dim (int, optional): _description_. Defaults to 1024.
|
275 |
+
attention_head_dim (Union[int, Tuple[int]], optional): _description_. Defaults to 8.
|
276 |
+
temporal_conv_block (str, optional): 3D卷积字符串,需要注册在 Model_Register. Defaults to "TemporalConvLayer".
|
277 |
+
temporal_transformer (str, optional): 时序 Transformer block字符串,需要定义在 Model_Register. Defaults to "TransformerTemporalModel".
|
278 |
+
need_spatial_position_emb (bool, optional): 是否需要 spatial hw 的emb,需要配合 thw attn使用. Defaults to False.
|
279 |
+
need_transformer_in (bool, optional): 是否需要 第一个 temporal_transformer_block. Defaults to True.
|
280 |
+
need_t2i_ip_adapter (bool, optional): T2I 模块是否需要面向视觉条件帧的 attn. Defaults to False.
|
281 |
+
need_adain_temporal_cond (bool, optional): 是否需要面向首帧 使用Adain. Defaults to False.
|
282 |
+
t2i_ip_adapter_attn_processor (str, optional):
|
283 |
+
t2i attn_processor的优化版,需配合need_t2i_ip_adapter使用,
|
284 |
+
有 NonParam 表示无参ReferenceOnly-attn,没有表示有参 IpAdapter.
|
285 |
+
Defaults to "NonParamT2ISelfReferenceXFormersAttnProcessor".
|
286 |
+
keep_vision_condtion (bool, optional): 是否对视觉条件帧不加 timestep emb. Defaults to False.
|
287 |
+
use_anivv1_cfg (bool, optional): 一些基本配置 是否延续AnivV设计. Defaults to False.
|
288 |
+
resnet_2d_skip_time_act (bool, optional): 配合use_anivv1_cfg,修改 transformer 2d block. Defaults to False.
|
289 |
+
need_zero_vis_cond_temb (bool, optional): 目前无效参数. Defaults to True.
|
290 |
+
norm_spatial_length (bool, optional): 是否需要 norm_spatial_length,只有当 need_spatial_position_emb= True时,才有效. Defaults to False.
|
291 |
+
spatial_max_length (int, optional): 归一化长度. Defaults to 2048.
|
292 |
+
|
293 |
+
Raises:
|
294 |
+
ValueError: _description_
|
295 |
+
ValueError: _description_
|
296 |
+
ValueError: _description_
|
297 |
+
"""
|
298 |
+
super(UNet3DConditionModel, self).__init__()
|
299 |
+
self.keep_vision_condtion = keep_vision_condtion
|
300 |
+
self.use_anivv1_cfg = use_anivv1_cfg
|
301 |
+
self.sample_size = sample_size
|
302 |
+
self.resnet_2d_skip_time_act = resnet_2d_skip_time_act
|
303 |
+
self.need_zero_vis_cond_temb = need_zero_vis_cond_temb
|
304 |
+
self.norm_spatial_length = norm_spatial_length
|
305 |
+
self.spatial_max_length = spatial_max_length
|
306 |
+
self.need_refer_emb = need_refer_emb
|
307 |
+
self.ip_adapter_cross_attn = ip_adapter_cross_attn
|
308 |
+
self.need_t2i_facein = need_t2i_facein
|
309 |
+
self.need_t2i_ip_adapter_face = need_t2i_ip_adapter_face
|
310 |
+
|
311 |
+
logger.debug(f"need_t2i_ip_adapter_face={need_t2i_ip_adapter_face}")
|
312 |
+
# Check inputs
|
313 |
+
if len(down_block_types) != len(up_block_types):
|
314 |
+
raise ValueError(
|
315 |
+
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
316 |
+
)
|
317 |
+
|
318 |
+
if len(block_out_channels) != len(down_block_types):
|
319 |
+
raise ValueError(
|
320 |
+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
321 |
+
)
|
322 |
+
|
323 |
+
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(
|
324 |
+
down_block_types
|
325 |
+
):
|
326 |
+
raise ValueError(
|
327 |
+
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
|
328 |
+
)
|
329 |
+
|
330 |
+
# input
|
331 |
+
conv_in_kernel = 3
|
332 |
+
conv_out_kernel = 3
|
333 |
+
conv_in_padding = (conv_in_kernel - 1) // 2
|
334 |
+
self.conv_in = nn.Conv2d(
|
335 |
+
in_channels,
|
336 |
+
block_out_channels[0],
|
337 |
+
kernel_size=conv_in_kernel,
|
338 |
+
padding=conv_in_padding,
|
339 |
+
)
|
340 |
+
|
341 |
+
# time
|
342 |
+
time_embed_dim = block_out_channels[0] * 4
|
343 |
+
self.time_proj = Timesteps(block_out_channels[0], True, 0)
|
344 |
+
timestep_input_dim = block_out_channels[0]
|
345 |
+
|
346 |
+
self.time_embedding = TimestepEmbedding(
|
347 |
+
timestep_input_dim,
|
348 |
+
time_embed_dim,
|
349 |
+
act_fn=act_fn,
|
350 |
+
)
|
351 |
+
if use_anivv1_cfg:
|
352 |
+
self.time_nonlinearity = nn.SiLU()
|
353 |
+
|
354 |
+
# frame
|
355 |
+
frame_embed_dim = block_out_channels[0] * 4
|
356 |
+
self.frame_proj = Timesteps(block_out_channels[0], True, 0)
|
357 |
+
frame_input_dim = block_out_channels[0]
|
358 |
+
if temporal_transformer is not None:
|
359 |
+
self.frame_embedding = TimestepEmbedding(
|
360 |
+
frame_input_dim,
|
361 |
+
frame_embed_dim,
|
362 |
+
act_fn=act_fn,
|
363 |
+
)
|
364 |
+
else:
|
365 |
+
self.frame_embedding = None
|
366 |
+
if use_anivv1_cfg:
|
367 |
+
self.femb_nonlinearity = nn.SiLU()
|
368 |
+
|
369 |
+
# spatial_position_emb
|
370 |
+
self.need_spatial_position_emb = need_spatial_position_emb
|
371 |
+
if need_spatial_position_emb:
|
372 |
+
self.spatial_position_input_dim = block_out_channels[0] * 2
|
373 |
+
self.spatial_position_embed_dim = block_out_channels[0] * 4
|
374 |
+
|
375 |
+
self.spatial_position_embedding = TimestepEmbedding(
|
376 |
+
self.spatial_position_input_dim,
|
377 |
+
self.spatial_position_embed_dim,
|
378 |
+
act_fn=act_fn,
|
379 |
+
)
|
380 |
+
|
381 |
+
# 从模型注册表中获取 模型类
|
382 |
+
temporal_conv_block = (
|
383 |
+
Model_Register[temporal_conv_block]
|
384 |
+
if isinstance(temporal_conv_block, str)
|
385 |
+
and temporal_conv_block.lower() != "none"
|
386 |
+
else None
|
387 |
+
)
|
388 |
+
self.need_transformer_in = need_transformer_in
|
389 |
+
|
390 |
+
temporal_transformer = (
|
391 |
+
Model_Register[temporal_transformer]
|
392 |
+
if isinstance(temporal_transformer, str)
|
393 |
+
and temporal_transformer.lower() != "none"
|
394 |
+
else None
|
395 |
+
)
|
396 |
+
self.need_vis_cond_mask = need_vis_cond_mask
|
397 |
+
|
398 |
+
if need_transformer_in and temporal_transformer is not None:
|
399 |
+
self.transformer_in = temporal_transformer(
|
400 |
+
num_attention_heads=attention_head_dim,
|
401 |
+
attention_head_dim=block_out_channels[0] // attention_head_dim,
|
402 |
+
in_channels=block_out_channels[0],
|
403 |
+
num_layers=1,
|
404 |
+
femb_channels=frame_embed_dim,
|
405 |
+
need_spatial_position_emb=need_spatial_position_emb,
|
406 |
+
cross_attention_dim=cross_attention_dim,
|
407 |
+
)
|
408 |
+
|
409 |
+
# class embedding
|
410 |
+
self.down_blocks = nn.ModuleList([])
|
411 |
+
self.up_blocks = nn.ModuleList([])
|
412 |
+
|
413 |
+
if isinstance(attention_head_dim, int):
|
414 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
415 |
+
|
416 |
+
self.need_t2i_ip_adapter = need_t2i_ip_adapter
|
417 |
+
# 确定T2I Attn 是否加入 ReferenceOnly机制或Ipadaper机制
|
418 |
+
# TODO:有待更好的实现机制,
|
419 |
+
need_t2i_ip_adapter_param = (
|
420 |
+
t2i_ip_adapter_attn_processor is not None
|
421 |
+
and "NonParam" not in t2i_ip_adapter_attn_processor
|
422 |
+
and need_t2i_ip_adapter
|
423 |
+
)
|
424 |
+
self.need_adain_temporal_cond = need_adain_temporal_cond
|
425 |
+
self.t2i_ip_adapter_attn_processor = t2i_ip_adapter_attn_processor
|
426 |
+
|
427 |
+
if need_refer_emb:
|
428 |
+
self.first_refer_emb_attns = ReferEmbFuseAttention(
|
429 |
+
query_dim=block_out_channels[0],
|
430 |
+
heads=attention_head_dim[0],
|
431 |
+
dim_head=block_out_channels[0] // attention_head_dim[0],
|
432 |
+
dropout=0,
|
433 |
+
bias=False,
|
434 |
+
cross_attention_dim=None,
|
435 |
+
upcast_attention=False,
|
436 |
+
)
|
437 |
+
self.mid_block_refer_emb_attns = ReferEmbFuseAttention(
|
438 |
+
query_dim=block_out_channels[-1],
|
439 |
+
heads=attention_head_dim[-1],
|
440 |
+
dim_head=block_out_channels[-1] // attention_head_dim[-1],
|
441 |
+
dropout=0,
|
442 |
+
bias=False,
|
443 |
+
cross_attention_dim=None,
|
444 |
+
upcast_attention=False,
|
445 |
+
)
|
446 |
+
else:
|
447 |
+
self.first_refer_emb_attns = None
|
448 |
+
self.mid_block_refer_emb_attns = None
|
449 |
+
# down
|
450 |
+
output_channel = block_out_channels[0]
|
451 |
+
self.layers_per_block = layers_per_block
|
452 |
+
self.block_out_channels = block_out_channels
|
453 |
+
for i, down_block_type in enumerate(down_block_types):
|
454 |
+
input_channel = output_channel
|
455 |
+
output_channel = block_out_channels[i]
|
456 |
+
is_final_block = i == len(block_out_channels) - 1
|
457 |
+
|
458 |
+
down_block = get_down_block(
|
459 |
+
down_block_type,
|
460 |
+
num_layers=layers_per_block,
|
461 |
+
in_channels=input_channel,
|
462 |
+
out_channels=output_channel,
|
463 |
+
temb_channels=time_embed_dim,
|
464 |
+
femb_channels=frame_embed_dim,
|
465 |
+
add_downsample=not is_final_block,
|
466 |
+
resnet_eps=norm_eps,
|
467 |
+
resnet_act_fn=act_fn,
|
468 |
+
resnet_groups=norm_num_groups,
|
469 |
+
cross_attention_dim=cross_attention_dim,
|
470 |
+
attn_num_head_channels=attention_head_dim[i],
|
471 |
+
downsample_padding=downsample_padding,
|
472 |
+
dual_cross_attention=False,
|
473 |
+
temporal_conv_block=temporal_conv_block,
|
474 |
+
temporal_transformer=temporal_transformer,
|
475 |
+
need_spatial_position_emb=need_spatial_position_emb,
|
476 |
+
need_t2i_ip_adapter=need_t2i_ip_adapter_param,
|
477 |
+
ip_adapter_cross_attn=ip_adapter_cross_attn,
|
478 |
+
need_t2i_facein=need_t2i_facein,
|
479 |
+
need_t2i_ip_adapter_face=need_t2i_ip_adapter_face,
|
480 |
+
need_adain_temporal_cond=need_adain_temporal_cond,
|
481 |
+
resnet_2d_skip_time_act=resnet_2d_skip_time_act,
|
482 |
+
need_refer_emb=need_refer_emb,
|
483 |
+
)
|
484 |
+
self.down_blocks.append(down_block)
|
485 |
+
# mid
|
486 |
+
self.mid_block = UNetMidBlock3DCrossAttn(
|
487 |
+
in_channels=block_out_channels[-1],
|
488 |
+
temb_channels=time_embed_dim,
|
489 |
+
femb_channels=frame_embed_dim,
|
490 |
+
resnet_eps=norm_eps,
|
491 |
+
resnet_act_fn=act_fn,
|
492 |
+
output_scale_factor=mid_block_scale_factor,
|
493 |
+
cross_attention_dim=cross_attention_dim,
|
494 |
+
attn_num_head_channels=attention_head_dim[-1],
|
495 |
+
resnet_groups=norm_num_groups,
|
496 |
+
dual_cross_attention=False,
|
497 |
+
temporal_conv_block=temporal_conv_block,
|
498 |
+
temporal_transformer=temporal_transformer,
|
499 |
+
need_spatial_position_emb=need_spatial_position_emb,
|
500 |
+
need_t2i_ip_adapter=need_t2i_ip_adapter_param,
|
501 |
+
ip_adapter_cross_attn=ip_adapter_cross_attn,
|
502 |
+
need_t2i_facein=need_t2i_facein,
|
503 |
+
need_t2i_ip_adapter_face=need_t2i_ip_adapter_face,
|
504 |
+
need_adain_temporal_cond=need_adain_temporal_cond,
|
505 |
+
resnet_2d_skip_time_act=resnet_2d_skip_time_act,
|
506 |
+
)
|
507 |
+
|
508 |
+
# count how many layers upsample the images
|
509 |
+
self.num_upsamplers = 0
|
510 |
+
|
511 |
+
# up
|
512 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
513 |
+
reversed_attention_head_dim = list(reversed(attention_head_dim))
|
514 |
+
|
515 |
+
output_channel = reversed_block_out_channels[0]
|
516 |
+
for i, up_block_type in enumerate(up_block_types):
|
517 |
+
is_final_block = i == len(block_out_channels) - 1
|
518 |
+
|
519 |
+
prev_output_channel = output_channel
|
520 |
+
output_channel = reversed_block_out_channels[i]
|
521 |
+
input_channel = reversed_block_out_channels[
|
522 |
+
min(i + 1, len(block_out_channels) - 1)
|
523 |
+
]
|
524 |
+
|
525 |
+
# add upsample block for all BUT final layer
|
526 |
+
if not is_final_block:
|
527 |
+
add_upsample = True
|
528 |
+
self.num_upsamplers += 1
|
529 |
+
else:
|
530 |
+
add_upsample = False
|
531 |
+
|
532 |
+
up_block = get_up_block(
|
533 |
+
up_block_type,
|
534 |
+
num_layers=layers_per_block + 1,
|
535 |
+
in_channels=input_channel,
|
536 |
+
out_channels=output_channel,
|
537 |
+
prev_output_channel=prev_output_channel,
|
538 |
+
temb_channels=time_embed_dim,
|
539 |
+
femb_channels=frame_embed_dim,
|
540 |
+
add_upsample=add_upsample,
|
541 |
+
resnet_eps=norm_eps,
|
542 |
+
resnet_act_fn=act_fn,
|
543 |
+
resnet_groups=norm_num_groups,
|
544 |
+
cross_attention_dim=cross_attention_dim,
|
545 |
+
attn_num_head_channels=reversed_attention_head_dim[i],
|
546 |
+
dual_cross_attention=False,
|
547 |
+
temporal_conv_block=temporal_conv_block,
|
548 |
+
temporal_transformer=temporal_transformer,
|
549 |
+
need_spatial_position_emb=need_spatial_position_emb,
|
550 |
+
need_t2i_ip_adapter=need_t2i_ip_adapter_param,
|
551 |
+
ip_adapter_cross_attn=ip_adapter_cross_attn,
|
552 |
+
need_t2i_facein=need_t2i_facein,
|
553 |
+
need_t2i_ip_adapter_face=need_t2i_ip_adapter_face,
|
554 |
+
need_adain_temporal_cond=need_adain_temporal_cond,
|
555 |
+
resnet_2d_skip_time_act=resnet_2d_skip_time_act,
|
556 |
+
)
|
557 |
+
self.up_blocks.append(up_block)
|
558 |
+
prev_output_channel = output_channel
|
559 |
+
|
560 |
+
# out
|
561 |
+
if norm_num_groups is not None:
|
562 |
+
self.conv_norm_out = nn.GroupNorm(
|
563 |
+
num_channels=block_out_channels[0],
|
564 |
+
num_groups=norm_num_groups,
|
565 |
+
eps=norm_eps,
|
566 |
+
)
|
567 |
+
self.conv_act = nn.SiLU()
|
568 |
+
else:
|
569 |
+
self.conv_norm_out = None
|
570 |
+
self.conv_act = None
|
571 |
+
|
572 |
+
conv_out_padding = (conv_out_kernel - 1) // 2
|
573 |
+
self.conv_out = nn.Conv2d(
|
574 |
+
block_out_channels[0],
|
575 |
+
out_channels,
|
576 |
+
kernel_size=conv_out_kernel,
|
577 |
+
padding=conv_out_padding,
|
578 |
+
)
|
579 |
+
self.insert_spatial_self_attn_idx()
|
580 |
+
|
581 |
+
# 根据需要hack attn_processor,实现ip_adapter等功能
|
582 |
+
if need_t2i_ip_adapter or ip_adapter_cross_attn:
|
583 |
+
hack_t2i_sd_layer_attn_with_ip(
|
584 |
+
self,
|
585 |
+
self_attn_class=Model_Register[t2i_ip_adapter_attn_processor]
|
586 |
+
if t2i_ip_adapter_attn_processor is not None and need_t2i_ip_adapter
|
587 |
+
else None,
|
588 |
+
cross_attn_class=Model_Register[t2i_crossattn_ip_adapter_attn_processor]
|
589 |
+
if t2i_crossattn_ip_adapter_attn_processor is not None
|
590 |
+
and (
|
591 |
+
ip_adapter_cross_attn or need_t2i_facein or need_t2i_ip_adapter_face
|
592 |
+
)
|
593 |
+
else None,
|
594 |
+
)
|
595 |
+
# logger.debug(pformat(self.attn_processors))
|
596 |
+
|
597 |
+
# 非参数AttnProcessor,就不需要to_k_ip、to_v_ip参数了
|
598 |
+
if (
|
599 |
+
t2i_ip_adapter_attn_processor is None
|
600 |
+
or "NonParam" in t2i_ip_adapter_attn_processor
|
601 |
+
):
|
602 |
+
need_t2i_ip_adapter = False
|
603 |
+
|
604 |
+
if self.print_idx == 0:
|
605 |
+
logger.debug("Unet3Model Parameters")
|
606 |
+
# logger.debug(pformat(self.__dict__))
|
607 |
+
|
608 |
+
# 会在 set_skip_temporal_layers 设置 skip_refer_downblock_emb
|
609 |
+
# 当为 True 时,会跳过 referencenet_block_emb的影响,主要用于首帧生成
|
610 |
+
self.skip_refer_downblock_emb = False
|
611 |
+
|
612 |
+
@property
|
613 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
614 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
615 |
+
r"""
|
616 |
+
Returns:
|
617 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
618 |
+
indexed by its weight name.
|
619 |
+
"""
|
620 |
+
# set recursively
|
621 |
+
processors = {}
|
622 |
+
|
623 |
+
def fn_recursive_add_processors(
|
624 |
+
name: str,
|
625 |
+
module: torch.nn.Module,
|
626 |
+
processors: Dict[str, AttentionProcessor],
|
627 |
+
):
|
628 |
+
if hasattr(module, "set_processor"):
|
629 |
+
processors[f"{name}.processor"] = module.processor
|
630 |
+
|
631 |
+
for sub_name, child in module.named_children():
|
632 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
633 |
+
|
634 |
+
return processors
|
635 |
+
|
636 |
+
for name, module in self.named_children():
|
637 |
+
fn_recursive_add_processors(name, module, processors)
|
638 |
+
|
639 |
+
return processors
|
640 |
+
|
641 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
642 |
+
def set_attention_slice(self, slice_size):
|
643 |
+
r"""
|
644 |
+
Enable sliced attention computation.
|
645 |
+
|
646 |
+
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
|
647 |
+
in several steps. This is useful to save some memory in exchange for a small speed decrease.
|
648 |
+
|
649 |
+
Args:
|
650 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
651 |
+
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
|
652 |
+
`"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
|
653 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
654 |
+
must be a multiple of `slice_size`.
|
655 |
+
"""
|
656 |
+
sliceable_head_dims = []
|
657 |
+
|
658 |
+
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
659 |
+
if hasattr(module, "set_attention_slice"):
|
660 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
661 |
+
|
662 |
+
for child in module.children():
|
663 |
+
fn_recursive_retrieve_sliceable_dims(child)
|
664 |
+
|
665 |
+
# retrieve number of attention layers
|
666 |
+
for module in self.children():
|
667 |
+
fn_recursive_retrieve_sliceable_dims(module)
|
668 |
+
|
669 |
+
num_sliceable_layers = len(sliceable_head_dims)
|
670 |
+
|
671 |
+
if slice_size == "auto":
|
672 |
+
# half the attention head size is usually a good trade-off between
|
673 |
+
# speed and memory
|
674 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
675 |
+
elif slice_size == "max":
|
676 |
+
# make smallest slice possible
|
677 |
+
slice_size = num_sliceable_layers * [1]
|
678 |
+
|
679 |
+
slice_size = (
|
680 |
+
num_sliceable_layers * [slice_size]
|
681 |
+
if not isinstance(slice_size, list)
|
682 |
+
else slice_size
|
683 |
+
)
|
684 |
+
|
685 |
+
if len(slice_size) != len(sliceable_head_dims):
|
686 |
+
raise ValueError(
|
687 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
688 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
689 |
+
)
|
690 |
+
|
691 |
+
for i in range(len(slice_size)):
|
692 |
+
size = slice_size[i]
|
693 |
+
dim = sliceable_head_dims[i]
|
694 |
+
if size is not None and size > dim:
|
695 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
696 |
+
|
697 |
+
# Recursively walk through all the children.
|
698 |
+
# Any children which exposes the set_attention_slice method
|
699 |
+
# gets the message
|
700 |
+
def fn_recursive_set_attention_slice(
|
701 |
+
module: torch.nn.Module, slice_size: List[int]
|
702 |
+
):
|
703 |
+
if hasattr(module, "set_attention_slice"):
|
704 |
+
module.set_attention_slice(slice_size.pop())
|
705 |
+
|
706 |
+
for child in module.children():
|
707 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
708 |
+
|
709 |
+
reversed_slice_size = list(reversed(slice_size))
|
710 |
+
for module in self.children():
|
711 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
712 |
+
|
713 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
714 |
+
def set_attn_processor(
|
715 |
+
self,
|
716 |
+
processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]],
|
717 |
+
strict: bool = True,
|
718 |
+
):
|
719 |
+
r"""
|
720 |
+
Parameters:
|
721 |
+
`processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
|
722 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
723 |
+
of **all** `Attention` layers.
|
724 |
+
In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.:
|
725 |
+
|
726 |
+
"""
|
727 |
+
count = len(self.attn_processors.keys())
|
728 |
+
|
729 |
+
if isinstance(processor, dict) and len(processor) != count and strict:
|
730 |
+
raise ValueError(
|
731 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
732 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
733 |
+
)
|
734 |
+
|
735 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
736 |
+
if hasattr(module, "set_processor"):
|
737 |
+
if not isinstance(processor, dict):
|
738 |
+
logger.debug(
|
739 |
+
f"module {name} set attn processor {processor.__class__.__name__}"
|
740 |
+
)
|
741 |
+
module.set_processor(processor)
|
742 |
+
else:
|
743 |
+
if f"{name}.processor" in processor:
|
744 |
+
logger.debug(
|
745 |
+
"module {} set attn processor {}".format(
|
746 |
+
name, processor[f"{name}.processor"].__class__.__name__
|
747 |
+
)
|
748 |
+
)
|
749 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
750 |
+
else:
|
751 |
+
logger.debug(
|
752 |
+
f"module {name} has no new target attn_processor, still use {module.processor.__class__.__name__} "
|
753 |
+
)
|
754 |
+
for sub_name, child in module.named_children():
|
755 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
756 |
+
|
757 |
+
for name, module in self.named_children():
|
758 |
+
fn_recursive_attn_processor(name, module, processor)
|
759 |
+
|
760 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
761 |
+
def set_default_attn_processor(self):
|
762 |
+
"""
|
763 |
+
Disables custom attention processors and sets the default attention implementation.
|
764 |
+
"""
|
765 |
+
self.set_attn_processor(AttnProcessor())
|
766 |
+
|
767 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
768 |
+
if isinstance(
|
769 |
+
module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)
|
770 |
+
):
|
771 |
+
module.gradient_checkpointing = value
|
772 |
+
|
773 |
+
def forward(
|
774 |
+
self,
|
775 |
+
sample: torch.FloatTensor,
|
776 |
+
timestep: Union[torch.Tensor, float, int],
|
777 |
+
encoder_hidden_states: torch.Tensor,
|
778 |
+
class_labels: Optional[torch.Tensor] = None,
|
779 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
780 |
+
attention_mask: Optional[torch.Tensor] = None,
|
781 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
782 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
783 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
784 |
+
return_dict: bool = True,
|
785 |
+
sample_index: torch.LongTensor = None,
|
786 |
+
vision_condition_frames_sample: torch.Tensor = None,
|
787 |
+
vision_conditon_frames_sample_index: torch.LongTensor = None,
|
788 |
+
sample_frame_rate: int = 10,
|
789 |
+
skip_temporal_layers: bool = None,
|
790 |
+
frame_index: torch.LongTensor = None,
|
791 |
+
down_block_refer_embs: Optional[Tuple[torch.Tensor]] = None,
|
792 |
+
mid_block_refer_emb: Optional[torch.Tensor] = None,
|
793 |
+
refer_self_attn_emb: Optional[List[torch.Tensor]] = None,
|
794 |
+
refer_self_attn_emb_mode: Literal["read", "write"] = "read",
|
795 |
+
vision_clip_emb: torch.Tensor = None,
|
796 |
+
ip_adapter_scale: float = 1.0,
|
797 |
+
face_emb: torch.Tensor = None,
|
798 |
+
facein_scale: float = 1.0,
|
799 |
+
ip_adapter_face_emb: torch.Tensor = None,
|
800 |
+
ip_adapter_face_scale: float = 1.0,
|
801 |
+
do_classifier_free_guidance: bool = False,
|
802 |
+
pose_guider_emb: torch.Tensor = None,
|
803 |
+
) -> Union[UNet3DConditionOutput, Tuple]:
|
804 |
+
"""_summary_
|
805 |
+
|
806 |
+
Args:
|
807 |
+
sample (torch.FloatTensor): _description_
|
808 |
+
timestep (Union[torch.Tensor, float, int]): _description_
|
809 |
+
encoder_hidden_states (torch.Tensor): _description_
|
810 |
+
class_labels (Optional[torch.Tensor], optional): _description_. Defaults to None.
|
811 |
+
timestep_cond (Optional[torch.Tensor], optional): _description_. Defaults to None.
|
812 |
+
attention_mask (Optional[torch.Tensor], optional): _description_. Defaults to None.
|
813 |
+
cross_attention_kwargs (Optional[Dict[str, Any]], optional): _description_. Defaults to None.
|
814 |
+
down_block_additional_residuals (Optional[Tuple[torch.Tensor]], optional): _description_. Defaults to None.
|
815 |
+
mid_block_additional_residual (Optional[torch.Tensor], optional): _description_. Defaults to None.
|
816 |
+
return_dict (bool, optional): _description_. Defaults to True.
|
817 |
+
sample_index (torch.LongTensor, optional): _description_. Defaults to None.
|
818 |
+
vision_condition_frames_sample (torch.Tensor, optional): _description_. Defaults to None.
|
819 |
+
vision_conditon_frames_sample_index (torch.LongTensor, optional): _description_. Defaults to None.
|
820 |
+
sample_frame_rate (int, optional): _description_. Defaults to 10.
|
821 |
+
skip_temporal_layers (bool, optional): _description_. Defaults to None.
|
822 |
+
frame_index (torch.LongTensor, optional): _description_. Defaults to None.
|
823 |
+
up_block_additional_residual (Optional[torch.Tensor], optional): 用于up_block的 参考latent. Defaults to None.
|
824 |
+
down_block_refer_embs (Optional[torch.Tensor], optional): 用于 download 的 参考latent. Defaults to None.
|
825 |
+
how_fuse_referencenet_emb (Literal, optional): 如何融合 参考 latent. Defaults to ["add", "attn"]="add".
|
826 |
+
add: 要求 additional_latent 和 latent hw 同尺寸. hw of addtional_latent should be same as of latent
|
827 |
+
attn: concat bt*h1w1*c and bt*h2w2*c into bt*(h1w1+h2w2)*c, and then as key,value into attn
|
828 |
+
Raises:
|
829 |
+
ValueError: _description_
|
830 |
+
|
831 |
+
Returns:
|
832 |
+
Union[UNet3DConditionOutput, Tuple]: _description_
|
833 |
+
"""
|
834 |
+
|
835 |
+
if skip_temporal_layers is not None:
|
836 |
+
self.set_skip_temporal_layers(skip_temporal_layers)
|
837 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
838 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
|
839 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
840 |
+
# on the fly if necessary.
|
841 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
842 |
+
|
843 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
844 |
+
forward_upsample_size = False
|
845 |
+
upsample_size = None
|
846 |
+
|
847 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
848 |
+
# logger.debug("Forward upsample size to force interpolation output size.")
|
849 |
+
forward_upsample_size = True
|
850 |
+
|
851 |
+
# prepare attention_mask
|
852 |
+
if attention_mask is not None:
|
853 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
854 |
+
attention_mask = attention_mask.unsqueeze(1)
|
855 |
+
|
856 |
+
# 1. time
|
857 |
+
timesteps = timestep
|
858 |
+
if not torch.is_tensor(timesteps):
|
859 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
860 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
861 |
+
is_mps = sample.device.type == "mps"
|
862 |
+
if isinstance(timestep, float):
|
863 |
+
dtype = torch.float32 if is_mps else torch.float64
|
864 |
+
else:
|
865 |
+
dtype = torch.int32 if is_mps else torch.int64
|
866 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
867 |
+
elif len(timesteps.shape) == 0:
|
868 |
+
timesteps = timesteps[None].to(sample.device)
|
869 |
+
|
870 |
+
batch_size = sample.shape[0]
|
871 |
+
|
872 |
+
# when vision_condition_frames_sample is not None and vision_conditon_frames_sample_index is not None
|
873 |
+
# if not None, b c t h w -> b c (t + n_content ) h w
|
874 |
+
|
875 |
+
if vision_condition_frames_sample is not None:
|
876 |
+
sample = batch_concat_two_tensor_with_index(
|
877 |
+
sample,
|
878 |
+
sample_index,
|
879 |
+
vision_condition_frames_sample,
|
880 |
+
vision_conditon_frames_sample_index,
|
881 |
+
dim=2,
|
882 |
+
)
|
883 |
+
|
884 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
885 |
+
batch_size, channel, num_frames, height, width = sample.shape
|
886 |
+
|
887 |
+
# 准备 timestep emb
|
888 |
+
timesteps = timesteps.expand(sample.shape[0])
|
889 |
+
temb = self.time_proj(timesteps)
|
890 |
+
temb = temb.to(dtype=self.dtype)
|
891 |
+
emb = self.time_embedding(temb, timestep_cond)
|
892 |
+
if self.use_anivv1_cfg:
|
893 |
+
emb = self.time_nonlinearity(emb)
|
894 |
+
emb = emb.repeat_interleave(repeats=num_frames, dim=0)
|
895 |
+
|
896 |
+
# 一致性保持,使条件时序帧的 首帧 timesteps emb 为 0,即不影响视觉条件帧
|
897 |
+
# keep consistent with the first frame of vision condition frames
|
898 |
+
if (
|
899 |
+
self.keep_vision_condtion
|
900 |
+
and num_frames > 1
|
901 |
+
and sample_index is not None
|
902 |
+
and vision_conditon_frames_sample_index is not None
|
903 |
+
):
|
904 |
+
emb = rearrange(emb, "(b t) d -> b t d", t=num_frames)
|
905 |
+
emb[:, vision_conditon_frames_sample_index, :] = 0
|
906 |
+
emb = rearrange(emb, "b t d->(b t) d")
|
907 |
+
|
908 |
+
# temporal positional embedding
|
909 |
+
femb = None
|
910 |
+
if self.temporal_transformer is not None:
|
911 |
+
if frame_index is None:
|
912 |
+
frame_index = torch.arange(
|
913 |
+
num_frames, dtype=torch.long, device=sample.device
|
914 |
+
)
|
915 |
+
if self.use_anivv1_cfg:
|
916 |
+
frame_index = (frame_index * sample_frame_rate).to(dtype=torch.long)
|
917 |
+
femb = self.frame_proj(frame_index)
|
918 |
+
if self.print_idx == 0:
|
919 |
+
logger.debug(
|
920 |
+
f"unet prepare frame_index, {femb.shape}, {batch_size}"
|
921 |
+
)
|
922 |
+
femb = repeat(femb, "t d-> b t d", b=batch_size)
|
923 |
+
else:
|
924 |
+
# b t -> b t d
|
925 |
+
assert frame_index.ndim == 2, ValueError(
|
926 |
+
"ndim of given frame_index should be 2, but {frame_index.ndim}"
|
927 |
+
)
|
928 |
+
femb = torch.stack(
|
929 |
+
[self.frame_proj(frame_index[i]) for i in range(batch_size)], dim=0
|
930 |
+
)
|
931 |
+
if self.temporal_transformer is not None:
|
932 |
+
femb = femb.to(dtype=self.dtype)
|
933 |
+
femb = self.frame_embedding(
|
934 |
+
femb,
|
935 |
+
)
|
936 |
+
if self.use_anivv1_cfg:
|
937 |
+
femb = self.femb_nonlinearity(femb)
|
938 |
+
if encoder_hidden_states.ndim == 3:
|
939 |
+
encoder_hidden_states = align_repeat_tensor_single_dim(
|
940 |
+
encoder_hidden_states, target_length=emb.shape[0], dim=0
|
941 |
+
)
|
942 |
+
elif encoder_hidden_states.ndim == 4:
|
943 |
+
encoder_hidden_states = rearrange(
|
944 |
+
encoder_hidden_states, "b t n q-> (b t) n q"
|
945 |
+
)
|
946 |
+
else:
|
947 |
+
raise ValueError(
|
948 |
+
f"only support ndim in [3, 4], but given {encoder_hidden_states.ndim}"
|
949 |
+
)
|
950 |
+
if vision_clip_emb is not None:
|
951 |
+
if vision_clip_emb.ndim == 4:
|
952 |
+
vision_clip_emb = rearrange(vision_clip_emb, "b t n q-> (b t) n q")
|
953 |
+
# 准备 hw 层面的 spatial positional embedding
|
954 |
+
# prepare spatial_position_emb
|
955 |
+
if self.need_spatial_position_emb:
|
956 |
+
# height * width, self.spatial_position_input_dim
|
957 |
+
spatial_position_emb = get_2d_sincos_pos_embed(
|
958 |
+
embed_dim=self.spatial_position_input_dim,
|
959 |
+
grid_size_w=width,
|
960 |
+
grid_size_h=height,
|
961 |
+
cls_token=False,
|
962 |
+
norm_length=self.norm_spatial_length,
|
963 |
+
max_length=self.spatial_max_length,
|
964 |
+
)
|
965 |
+
spatial_position_emb = torch.from_numpy(spatial_position_emb).to(
|
966 |
+
device=sample.device, dtype=self.dtype
|
967 |
+
)
|
968 |
+
# height * width, self.spatial_position_embed_dim
|
969 |
+
spatial_position_emb = self.spatial_position_embedding(spatial_position_emb)
|
970 |
+
else:
|
971 |
+
spatial_position_emb = None
|
972 |
+
|
973 |
+
# prepare cross_attention_kwargs,ReferenceOnly/IpAdapter的attn_processor需要这些参数 进行 latenst和viscond_latents拆分运算
|
974 |
+
if (
|
975 |
+
self.need_t2i_ip_adapter
|
976 |
+
or self.ip_adapter_cross_attn
|
977 |
+
or self.need_t2i_facein
|
978 |
+
or self.need_t2i_ip_adapter_face
|
979 |
+
):
|
980 |
+
if cross_attention_kwargs is None:
|
981 |
+
cross_attention_kwargs = {}
|
982 |
+
cross_attention_kwargs["num_frames"] = num_frames
|
983 |
+
cross_attention_kwargs[
|
984 |
+
"do_classifier_free_guidance"
|
985 |
+
] = do_classifier_free_guidance
|
986 |
+
cross_attention_kwargs["sample_index"] = sample_index
|
987 |
+
cross_attention_kwargs[
|
988 |
+
"vision_conditon_frames_sample_index"
|
989 |
+
] = vision_conditon_frames_sample_index
|
990 |
+
if self.ip_adapter_cross_attn:
|
991 |
+
cross_attention_kwargs["vision_clip_emb"] = vision_clip_emb
|
992 |
+
cross_attention_kwargs["ip_adapter_scale"] = ip_adapter_scale
|
993 |
+
if self.need_t2i_facein:
|
994 |
+
if self.print_idx == 0:
|
995 |
+
logger.debug(
|
996 |
+
f"face_emb={type(face_emb)}, facein_scale={facein_scale}"
|
997 |
+
)
|
998 |
+
cross_attention_kwargs["face_emb"] = face_emb
|
999 |
+
cross_attention_kwargs["facein_scale"] = facein_scale
|
1000 |
+
if self.need_t2i_ip_adapter_face:
|
1001 |
+
if self.print_idx == 0:
|
1002 |
+
logger.debug(
|
1003 |
+
f"ip_adapter_face_emb={type(ip_adapter_face_emb)}, ip_adapter_face_scale={ip_adapter_face_scale}"
|
1004 |
+
)
|
1005 |
+
cross_attention_kwargs["ip_adapter_face_emb"] = ip_adapter_face_emb
|
1006 |
+
cross_attention_kwargs["ip_adapter_face_scale"] = ip_adapter_face_scale
|
1007 |
+
# 2. pre-process
|
1008 |
+
sample = rearrange(sample, "b c t h w -> (b t) c h w")
|
1009 |
+
sample = self.conv_in(sample)
|
1010 |
+
|
1011 |
+
if pose_guider_emb is not None:
|
1012 |
+
if self.print_idx == 0:
|
1013 |
+
logger.debug(
|
1014 |
+
f"sample={sample.shape}, pose_guider_emb={pose_guider_emb.shape}"
|
1015 |
+
)
|
1016 |
+
sample = sample + pose_guider_emb
|
1017 |
+
|
1018 |
+
if self.print_idx == 0:
|
1019 |
+
logger.debug(f"after conv in sample={sample.mean()}")
|
1020 |
+
if spatial_position_emb is not None:
|
1021 |
+
if self.print_idx == 0:
|
1022 |
+
logger.debug(
|
1023 |
+
f"unet3d, transformer_in, spatial_position_emb={spatial_position_emb.shape}"
|
1024 |
+
)
|
1025 |
+
if self.print_idx == 0:
|
1026 |
+
logger.debug(
|
1027 |
+
f"unet vision_conditon_frames_sample_index, {type(vision_conditon_frames_sample_index)}",
|
1028 |
+
)
|
1029 |
+
if vision_conditon_frames_sample_index is not None:
|
1030 |
+
if self.print_idx == 0:
|
1031 |
+
logger.debug(
|
1032 |
+
f"vision_conditon_frames_sample_index shape {vision_conditon_frames_sample_index.shape}",
|
1033 |
+
)
|
1034 |
+
if self.print_idx == 0:
|
1035 |
+
logger.debug(f"unet sample_index {type(sample_index)}")
|
1036 |
+
if sample_index is not None:
|
1037 |
+
if self.print_idx == 0:
|
1038 |
+
logger.debug(f"sample_index shape {sample_index.shape}")
|
1039 |
+
if self.need_transformer_in:
|
1040 |
+
if self.print_idx == 0:
|
1041 |
+
logger.debug(f"unet3d, transformer_in, sample={sample.shape}")
|
1042 |
+
sample = self.transformer_in(
|
1043 |
+
sample,
|
1044 |
+
femb=femb,
|
1045 |
+
num_frames=num_frames,
|
1046 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1047 |
+
encoder_hidden_states=encoder_hidden_states,
|
1048 |
+
sample_index=sample_index,
|
1049 |
+
vision_conditon_frames_sample_index=vision_conditon_frames_sample_index,
|
1050 |
+
spatial_position_emb=spatial_position_emb,
|
1051 |
+
).sample
|
1052 |
+
if (
|
1053 |
+
self.need_refer_emb
|
1054 |
+
and down_block_refer_embs is not None
|
1055 |
+
and not self.skip_refer_downblock_emb
|
1056 |
+
):
|
1057 |
+
if self.print_idx == 0:
|
1058 |
+
logger.debug(
|
1059 |
+
f"self.first_refer_emb_attns, {self.first_refer_emb_attns.__class__.__name__} {down_block_refer_embs[0].shape}"
|
1060 |
+
)
|
1061 |
+
sample = self.first_refer_emb_attns(
|
1062 |
+
sample, down_block_refer_embs[0], num_frames=num_frames
|
1063 |
+
)
|
1064 |
+
if self.print_idx == 0:
|
1065 |
+
logger.debug(
|
1066 |
+
f"first_refer_emb_attns, sample is_leaf={sample.is_leaf}, requires_grad={sample.requires_grad}, down_block_refer_embs, {down_block_refer_embs[0].is_leaf}, {down_block_refer_embs[0].requires_grad},"
|
1067 |
+
)
|
1068 |
+
else:
|
1069 |
+
if self.print_idx == 0:
|
1070 |
+
logger.debug(f"first_refer_emb_attns, no this step")
|
1071 |
+
# 将 refer_self_attn_emb 转化成字典,增加一个当前index,表示block 的对应关系
|
1072 |
+
# convert refer_self_attn_emb to dict, add a current index to represent the corresponding relationship of the block
|
1073 |
+
|
1074 |
+
# 3. down
|
1075 |
+
down_block_res_samples = (sample,)
|
1076 |
+
for i_down_block, downsample_block in enumerate(self.down_blocks):
|
1077 |
+
# 使用 attn 的方式 来融合 refer_emb,这里是准备 downblock 对应的 refer_emb
|
1078 |
+
# fuse refer_emb with attn, here is to prepare the refer_emb corresponding to downblock
|
1079 |
+
if (
|
1080 |
+
not self.need_refer_emb
|
1081 |
+
or down_block_refer_embs is None
|
1082 |
+
or self.skip_refer_downblock_emb
|
1083 |
+
):
|
1084 |
+
this_down_block_refer_embs = None
|
1085 |
+
if self.print_idx == 0:
|
1086 |
+
logger.debug(
|
1087 |
+
f"{i_down_block}, prepare this_down_block_refer_embs, is None"
|
1088 |
+
)
|
1089 |
+
else:
|
1090 |
+
is_final_block = i_down_block == len(self.block_out_channels) - 1
|
1091 |
+
num_block = self.layers_per_block + int(not is_final_block * 1)
|
1092 |
+
this_downblock_start_idx = 1 + num_block * i_down_block
|
1093 |
+
this_down_block_refer_embs = down_block_refer_embs[
|
1094 |
+
this_downblock_start_idx : this_downblock_start_idx + num_block
|
1095 |
+
]
|
1096 |
+
if self.print_idx == 0:
|
1097 |
+
logger.debug(
|
1098 |
+
f"prepare this_down_block_refer_embs, {len(this_down_block_refer_embs)}, {this_down_block_refer_embs[0].shape}"
|
1099 |
+
)
|
1100 |
+
if self.print_idx == 0:
|
1101 |
+
logger.debug(f"downsample_block {i_down_block}, sample={sample.mean()}")
|
1102 |
+
if (
|
1103 |
+
hasattr(downsample_block, "has_cross_attention")
|
1104 |
+
and downsample_block.has_cross_attention
|
1105 |
+
):
|
1106 |
+
sample, res_samples = downsample_block(
|
1107 |
+
hidden_states=sample,
|
1108 |
+
temb=emb,
|
1109 |
+
femb=femb,
|
1110 |
+
encoder_hidden_states=encoder_hidden_states,
|
1111 |
+
attention_mask=attention_mask,
|
1112 |
+
num_frames=num_frames,
|
1113 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1114 |
+
sample_index=sample_index,
|
1115 |
+
vision_conditon_frames_sample_index=vision_conditon_frames_sample_index,
|
1116 |
+
spatial_position_emb=spatial_position_emb,
|
1117 |
+
refer_embs=this_down_block_refer_embs,
|
1118 |
+
refer_self_attn_emb=refer_self_attn_emb,
|
1119 |
+
refer_self_attn_emb_mode=refer_self_attn_emb_mode,
|
1120 |
+
)
|
1121 |
+
else:
|
1122 |
+
sample, res_samples = downsample_block(
|
1123 |
+
hidden_states=sample,
|
1124 |
+
temb=emb,
|
1125 |
+
femb=femb,
|
1126 |
+
num_frames=num_frames,
|
1127 |
+
sample_index=sample_index,
|
1128 |
+
vision_conditon_frames_sample_index=vision_conditon_frames_sample_index,
|
1129 |
+
spatial_position_emb=spatial_position_emb,
|
1130 |
+
refer_embs=this_down_block_refer_embs,
|
1131 |
+
refer_self_attn_emb=refer_self_attn_emb,
|
1132 |
+
refer_self_attn_emb_mode=refer_self_attn_emb_mode,
|
1133 |
+
)
|
1134 |
+
|
1135 |
+
# resize spatial_position_emb
|
1136 |
+
if self.need_spatial_position_emb:
|
1137 |
+
has_downblock = i_down_block < len(self.down_blocks) - 1
|
1138 |
+
if has_downblock:
|
1139 |
+
spatial_position_emb = resize_spatial_position_emb(
|
1140 |
+
spatial_position_emb,
|
1141 |
+
scale=0.5,
|
1142 |
+
height=sample.shape[2] * 2,
|
1143 |
+
width=sample.shape[3] * 2,
|
1144 |
+
)
|
1145 |
+
down_block_res_samples += res_samples
|
1146 |
+
if down_block_additional_residuals is not None:
|
1147 |
+
new_down_block_res_samples = ()
|
1148 |
+
for down_block_res_sample, down_block_additional_residual in zip(
|
1149 |
+
down_block_res_samples, down_block_additional_residuals
|
1150 |
+
):
|
1151 |
+
down_block_res_sample = (
|
1152 |
+
down_block_res_sample + down_block_additional_residual
|
1153 |
+
)
|
1154 |
+
new_down_block_res_samples += (down_block_res_sample,)
|
1155 |
+
|
1156 |
+
down_block_res_samples = new_down_block_res_samples
|
1157 |
+
|
1158 |
+
# 4. mid
|
1159 |
+
if self.mid_block is not None:
|
1160 |
+
sample = self.mid_block(
|
1161 |
+
hidden_states=sample,
|
1162 |
+
temb=emb,
|
1163 |
+
femb=femb,
|
1164 |
+
encoder_hidden_states=encoder_hidden_states,
|
1165 |
+
attention_mask=attention_mask,
|
1166 |
+
num_frames=num_frames,
|
1167 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1168 |
+
sample_index=sample_index,
|
1169 |
+
vision_conditon_frames_sample_index=vision_conditon_frames_sample_index,
|
1170 |
+
spatial_position_emb=spatial_position_emb,
|
1171 |
+
refer_self_attn_emb=refer_self_attn_emb,
|
1172 |
+
refer_self_attn_emb_mode=refer_self_attn_emb_mode,
|
1173 |
+
)
|
1174 |
+
# 使用 attn 的方式 来融合 mid_block_refer_emb
|
1175 |
+
# fuse mid_block_refer_emb with attn
|
1176 |
+
if (
|
1177 |
+
self.mid_block_refer_emb_attns is not None
|
1178 |
+
and mid_block_refer_emb is not None
|
1179 |
+
and not self.skip_refer_downblock_emb
|
1180 |
+
):
|
1181 |
+
if self.print_idx == 0:
|
1182 |
+
logger.debug(
|
1183 |
+
f"self.mid_block_refer_emb_attns={self.mid_block_refer_emb_attns}, mid_block_refer_emb={mid_block_refer_emb.shape}"
|
1184 |
+
)
|
1185 |
+
sample = self.mid_block_refer_emb_attns(
|
1186 |
+
sample, mid_block_refer_emb, num_frames=num_frames
|
1187 |
+
)
|
1188 |
+
if self.print_idx == 0:
|
1189 |
+
logger.debug(
|
1190 |
+
f"mid_block_refer_emb_attns, sample is_leaf={sample.is_leaf}, requires_grad={sample.requires_grad}, mid_block_refer_emb, {mid_block_refer_emb[0].is_leaf}, {mid_block_refer_emb[0].requires_grad},"
|
1191 |
+
)
|
1192 |
+
else:
|
1193 |
+
if self.print_idx == 0:
|
1194 |
+
logger.debug(f"mid_block_refer_emb_attns, no this step")
|
1195 |
+
if mid_block_additional_residual is not None:
|
1196 |
+
sample = sample + mid_block_additional_residual
|
1197 |
+
|
1198 |
+
# 5. up
|
1199 |
+
for i_up_block, upsample_block in enumerate(self.up_blocks):
|
1200 |
+
is_final_block = i_up_block == len(self.up_blocks) - 1
|
1201 |
+
|
1202 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
1203 |
+
down_block_res_samples = down_block_res_samples[
|
1204 |
+
: -len(upsample_block.resnets)
|
1205 |
+
]
|
1206 |
+
|
1207 |
+
# if we have not reached the final block and need to forward the
|
1208 |
+
# upsample size, we do it here
|
1209 |
+
if not is_final_block and forward_upsample_size:
|
1210 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
1211 |
+
|
1212 |
+
if (
|
1213 |
+
hasattr(upsample_block, "has_cross_attention")
|
1214 |
+
and upsample_block.has_cross_attention
|
1215 |
+
):
|
1216 |
+
sample = upsample_block(
|
1217 |
+
hidden_states=sample,
|
1218 |
+
temb=emb,
|
1219 |
+
femb=femb,
|
1220 |
+
res_hidden_states_tuple=res_samples,
|
1221 |
+
encoder_hidden_states=encoder_hidden_states,
|
1222 |
+
upsample_size=upsample_size,
|
1223 |
+
attention_mask=attention_mask,
|
1224 |
+
num_frames=num_frames,
|
1225 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1226 |
+
sample_index=sample_index,
|
1227 |
+
vision_conditon_frames_sample_index=vision_conditon_frames_sample_index,
|
1228 |
+
spatial_position_emb=spatial_position_emb,
|
1229 |
+
refer_self_attn_emb=refer_self_attn_emb,
|
1230 |
+
refer_self_attn_emb_mode=refer_self_attn_emb_mode,
|
1231 |
+
)
|
1232 |
+
else:
|
1233 |
+
sample = upsample_block(
|
1234 |
+
hidden_states=sample,
|
1235 |
+
temb=emb,
|
1236 |
+
femb=femb,
|
1237 |
+
res_hidden_states_tuple=res_samples,
|
1238 |
+
upsample_size=upsample_size,
|
1239 |
+
num_frames=num_frames,
|
1240 |
+
sample_index=sample_index,
|
1241 |
+
vision_conditon_frames_sample_index=vision_conditon_frames_sample_index,
|
1242 |
+
spatial_position_emb=spatial_position_emb,
|
1243 |
+
refer_self_attn_emb=refer_self_attn_emb,
|
1244 |
+
refer_self_attn_emb_mode=refer_self_attn_emb_mode,
|
1245 |
+
)
|
1246 |
+
# resize spatial_position_emb
|
1247 |
+
if self.need_spatial_position_emb:
|
1248 |
+
has_upblock = i_up_block < len(self.up_blocks) - 1
|
1249 |
+
if has_upblock:
|
1250 |
+
spatial_position_emb = resize_spatial_position_emb(
|
1251 |
+
spatial_position_emb,
|
1252 |
+
scale=2,
|
1253 |
+
height=int(sample.shape[2] / 2),
|
1254 |
+
width=int(sample.shape[3] / 2),
|
1255 |
+
)
|
1256 |
+
|
1257 |
+
# 6. post-process
|
1258 |
+
if self.conv_norm_out:
|
1259 |
+
sample = self.conv_norm_out(sample)
|
1260 |
+
sample = self.conv_act(sample)
|
1261 |
+
|
1262 |
+
sample = self.conv_out(sample)
|
1263 |
+
sample = rearrange(sample, "(b t) c h w -> b c t h w", t=num_frames)
|
1264 |
+
|
1265 |
+
# if self.need_adain_temporal_cond and num_frames > 1:
|
1266 |
+
# sample = batch_adain_conditioned_tensor(
|
1267 |
+
# sample,
|
1268 |
+
# num_frames=num_frames,
|
1269 |
+
# need_style_fidelity=False,
|
1270 |
+
# src_index=sample_index,
|
1271 |
+
# dst_index=vision_conditon_frames_sample_index,
|
1272 |
+
# )
|
1273 |
+
self.print_idx += 1
|
1274 |
+
|
1275 |
+
if skip_temporal_layers is not None:
|
1276 |
+
self.set_skip_temporal_layers(not skip_temporal_layers)
|
1277 |
+
if not return_dict:
|
1278 |
+
return (sample,)
|
1279 |
+
else:
|
1280 |
+
return UNet3DConditionOutput(sample=sample)
|
1281 |
+
|
1282 |
+
# from https://github.com/huggingface/diffusers/blob/v0.16.1/src/diffusers/models/modeling_utils.py#L328
|
1283 |
+
@classmethod
|
1284 |
+
def from_pretrained_2d(
|
1285 |
+
cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs
|
1286 |
+
):
|
1287 |
+
r"""
|
1288 |
+
Instantiate a pretrained pytorch model from a pre-trained model configuration.
|
1289 |
+
|
1290 |
+
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
|
1291 |
+
the model, you should first set it back in training mode with `model.train()`.
|
1292 |
+
|
1293 |
+
The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
|
1294 |
+
pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
|
1295 |
+
task.
|
1296 |
+
|
1297 |
+
The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
|
1298 |
+
weights are discarded.
|
1299 |
+
|
1300 |
+
Parameters:
|
1301 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
1302 |
+
Can be either:
|
1303 |
+
|
1304 |
+
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
|
1305 |
+
Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
|
1306 |
+
- A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
|
1307 |
+
`./my_model_directory/`.
|
1308 |
+
|
1309 |
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
1310 |
+
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
1311 |
+
standard cache should not be used.
|
1312 |
+
torch_dtype (`str` or `torch.dtype`, *optional*):
|
1313 |
+
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
|
1314 |
+
will be automatically derived from the model's weights.
|
1315 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
1316 |
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
1317 |
+
cached versions if they exist.
|
1318 |
+
resume_download (`bool`, *optional*, defaults to `False`):
|
1319 |
+
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
|
1320 |
+
file exists.
|
1321 |
+
proxies (`Dict[str, str]`, *optional*):
|
1322 |
+
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
1323 |
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
1324 |
+
output_loading_info(`bool`, *optional*, defaults to `False`):
|
1325 |
+
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
1326 |
+
local_files_only(`bool`, *optional*, defaults to `False`):
|
1327 |
+
Whether or not to only look at local files (i.e., do not try to download the model).
|
1328 |
+
use_auth_token (`str` or *bool*, *optional*):
|
1329 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
1330 |
+
when running `diffusers-cli login` (stored in `~/.huggingface`).
|
1331 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
1332 |
+
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
1333 |
+
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
1334 |
+
identifier allowed by git.
|
1335 |
+
from_flax (`bool`, *optional*, defaults to `False`):
|
1336 |
+
Load the model weights from a Flax checkpoint save file.
|
1337 |
+
subfolder (`str`, *optional*, defaults to `""`):
|
1338 |
+
In case the relevant files are located inside a subfolder of the model repo (either remote in
|
1339 |
+
huggingface.co or downloaded locally), you can specify the folder name here.
|
1340 |
+
|
1341 |
+
mirror (`str`, *optional*):
|
1342 |
+
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
|
1343 |
+
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
|
1344 |
+
Please refer to the mirror site for more information.
|
1345 |
+
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
1346 |
+
A map that specifies where each submodule should go. It doesn't need to be refined to each
|
1347 |
+
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
|
1348 |
+
same device.
|
1349 |
+
|
1350 |
+
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
|
1351 |
+
more information about each option see [designing a device
|
1352 |
+
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
|
1353 |
+
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
1354 |
+
Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
|
1355 |
+
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
|
1356 |
+
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
|
1357 |
+
setting this argument to `True` will raise an error.
|
1358 |
+
variant (`str`, *optional*):
|
1359 |
+
If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
|
1360 |
+
ignored when using `from_flax`.
|
1361 |
+
use_safetensors (`bool`, *optional* ):
|
1362 |
+
If set to `True`, the pipeline will forcibly load the models from `safetensors` weights. If set to
|
1363 |
+
`None` (the default). The pipeline will load using `safetensors` if safetensors weights are available
|
1364 |
+
*and* if `safetensors` is installed. If the to `False` the pipeline will *not* use `safetensors`.
|
1365 |
+
|
1366 |
+
<Tip>
|
1367 |
+
|
1368 |
+
It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
|
1369 |
+
models](https://huggingface.co/docs/hub/models-gated#gated-models).
|
1370 |
+
|
1371 |
+
</Tip>
|
1372 |
+
|
1373 |
+
<Tip>
|
1374 |
+
|
1375 |
+
Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
|
1376 |
+
this method in a firewalled environment.
|
1377 |
+
|
1378 |
+
</Tip>
|
1379 |
+
|
1380 |
+
"""
|
1381 |
+
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
1382 |
+
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
|
1383 |
+
force_download = kwargs.pop("force_download", False)
|
1384 |
+
from_flax = kwargs.pop("from_flax", False)
|
1385 |
+
resume_download = kwargs.pop("resume_download", False)
|
1386 |
+
proxies = kwargs.pop("proxies", None)
|
1387 |
+
output_loading_info = kwargs.pop("output_loading_info", False)
|
1388 |
+
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
|
1389 |
+
use_auth_token = kwargs.pop("use_auth_token", None)
|
1390 |
+
revision = kwargs.pop("revision", None)
|
1391 |
+
torch_dtype = kwargs.pop("torch_dtype", None)
|
1392 |
+
subfolder = kwargs.pop("subfolder", None)
|
1393 |
+
device_map = kwargs.pop("device_map", None)
|
1394 |
+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
1395 |
+
variant = kwargs.pop("variant", None)
|
1396 |
+
use_safetensors = kwargs.pop("use_safetensors", None)
|
1397 |
+
strict = kwargs.pop("strict", True)
|
1398 |
+
|
1399 |
+
allow_pickle = False
|
1400 |
+
if use_safetensors is None:
|
1401 |
+
allow_pickle = True
|
1402 |
+
|
1403 |
+
if low_cpu_mem_usage and not is_accelerate_available():
|
1404 |
+
low_cpu_mem_usage = False
|
1405 |
+
logger.warning(
|
1406 |
+
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
1407 |
+
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
1408 |
+
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
1409 |
+
" install accelerate\n```\n."
|
1410 |
+
)
|
1411 |
+
|
1412 |
+
if device_map is not None and not is_accelerate_available():
|
1413 |
+
raise NotImplementedError(
|
1414 |
+
"Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
|
1415 |
+
" `device_map=None`. You can install accelerate with `pip install accelerate`."
|
1416 |
+
)
|
1417 |
+
|
1418 |
+
# Check if we can handle device_map and dispatching the weights
|
1419 |
+
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
1420 |
+
raise NotImplementedError(
|
1421 |
+
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
1422 |
+
" `device_map=None`."
|
1423 |
+
)
|
1424 |
+
|
1425 |
+
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
1426 |
+
raise NotImplementedError(
|
1427 |
+
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
1428 |
+
" `low_cpu_mem_usage=False`."
|
1429 |
+
)
|
1430 |
+
|
1431 |
+
if low_cpu_mem_usage is False and device_map is not None:
|
1432 |
+
raise ValueError(
|
1433 |
+
f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
|
1434 |
+
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
|
1435 |
+
)
|
1436 |
+
|
1437 |
+
# Load config if we don't provide a configuration
|
1438 |
+
config_path = pretrained_model_name_or_path
|
1439 |
+
|
1440 |
+
user_agent = {
|
1441 |
+
"diffusers": __version__,
|
1442 |
+
"file_type": "model",
|
1443 |
+
"framework": "pytorch",
|
1444 |
+
}
|
1445 |
+
|
1446 |
+
# load config
|
1447 |
+
config, unused_kwargs, commit_hash = cls.load_config(
|
1448 |
+
config_path,
|
1449 |
+
cache_dir=cache_dir,
|
1450 |
+
return_unused_kwargs=True,
|
1451 |
+
return_commit_hash=True,
|
1452 |
+
force_download=force_download,
|
1453 |
+
resume_download=resume_download,
|
1454 |
+
proxies=proxies,
|
1455 |
+
local_files_only=local_files_only,
|
1456 |
+
use_auth_token=use_auth_token,
|
1457 |
+
revision=revision,
|
1458 |
+
subfolder=subfolder,
|
1459 |
+
device_map=device_map,
|
1460 |
+
user_agent=user_agent,
|
1461 |
+
**kwargs,
|
1462 |
+
)
|
1463 |
+
|
1464 |
+
config["_class_name"] = cls.__name__
|
1465 |
+
config["down_block_types"] = convert_2D_to_3D(config["down_block_types"])
|
1466 |
+
if "mid_block_type" in config:
|
1467 |
+
config["mid_block_type"] = convert_2D_to_3D(config["mid_block_type"])
|
1468 |
+
else:
|
1469 |
+
config["mid_block_type"] = "UNetMidBlock3DCrossAttn"
|
1470 |
+
config["up_block_types"] = convert_2D_to_3D(config["up_block_types"])
|
1471 |
+
|
1472 |
+
# load model
|
1473 |
+
model_file = None
|
1474 |
+
if from_flax:
|
1475 |
+
model_file = _get_model_file(
|
1476 |
+
pretrained_model_name_or_path,
|
1477 |
+
weights_name=FLAX_WEIGHTS_NAME,
|
1478 |
+
cache_dir=cache_dir,
|
1479 |
+
force_download=force_download,
|
1480 |
+
resume_download=resume_download,
|
1481 |
+
proxies=proxies,
|
1482 |
+
local_files_only=local_files_only,
|
1483 |
+
use_auth_token=use_auth_token,
|
1484 |
+
revision=revision,
|
1485 |
+
subfolder=subfolder,
|
1486 |
+
user_agent=user_agent,
|
1487 |
+
commit_hash=commit_hash,
|
1488 |
+
)
|
1489 |
+
model = cls.from_config(config, **unused_kwargs)
|
1490 |
+
|
1491 |
+
# Convert the weights
|
1492 |
+
from diffusers.models.modeling_pytorch_flax_utils import (
|
1493 |
+
load_flax_checkpoint_in_pytorch_model,
|
1494 |
+
)
|
1495 |
+
|
1496 |
+
model = load_flax_checkpoint_in_pytorch_model(model, model_file)
|
1497 |
+
else:
|
1498 |
+
try:
|
1499 |
+
model_file = _get_model_file(
|
1500 |
+
pretrained_model_name_or_path,
|
1501 |
+
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
|
1502 |
+
cache_dir=cache_dir,
|
1503 |
+
force_download=force_download,
|
1504 |
+
resume_download=resume_download,
|
1505 |
+
proxies=proxies,
|
1506 |
+
local_files_only=local_files_only,
|
1507 |
+
use_auth_token=use_auth_token,
|
1508 |
+
revision=revision,
|
1509 |
+
subfolder=subfolder,
|
1510 |
+
user_agent=user_agent,
|
1511 |
+
commit_hash=commit_hash,
|
1512 |
+
)
|
1513 |
+
except IOError as e:
|
1514 |
+
if not allow_pickle:
|
1515 |
+
raise e
|
1516 |
+
pass
|
1517 |
+
if model_file is None:
|
1518 |
+
model_file = _get_model_file(
|
1519 |
+
pretrained_model_name_or_path,
|
1520 |
+
weights_name=_add_variant(WEIGHTS_NAME, variant),
|
1521 |
+
cache_dir=cache_dir,
|
1522 |
+
force_download=force_download,
|
1523 |
+
resume_download=resume_download,
|
1524 |
+
proxies=proxies,
|
1525 |
+
local_files_only=local_files_only,
|
1526 |
+
use_auth_token=use_auth_token,
|
1527 |
+
revision=revision,
|
1528 |
+
subfolder=subfolder,
|
1529 |
+
user_agent=user_agent,
|
1530 |
+
commit_hash=commit_hash,
|
1531 |
+
)
|
1532 |
+
|
1533 |
+
if low_cpu_mem_usage:
|
1534 |
+
# Instantiate model with empty weights
|
1535 |
+
with accelerate.init_empty_weights():
|
1536 |
+
model = cls.from_config(config, **unused_kwargs)
|
1537 |
+
|
1538 |
+
# if device_map is None, load the state dict and move the params from meta device to the cpu
|
1539 |
+
if device_map is None:
|
1540 |
+
param_device = "cpu"
|
1541 |
+
state_dict = load_state_dict(model_file, variant=variant)
|
1542 |
+
# move the params from meta device to cpu
|
1543 |
+
missing_keys = set(model.state_dict().keys()) - set(
|
1544 |
+
state_dict.keys()
|
1545 |
+
)
|
1546 |
+
if len(missing_keys) > 0:
|
1547 |
+
if strict:
|
1548 |
+
raise ValueError(
|
1549 |
+
f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
|
1550 |
+
f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
|
1551 |
+
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
|
1552 |
+
" those weights or else make sure your checkpoint file is correct."
|
1553 |
+
)
|
1554 |
+
else:
|
1555 |
+
logger.warning(
|
1556 |
+
f"model{cls} has no target pretrained paramter from {pretrained_model_name_or_path}, {', '.join(missing_keys)}"
|
1557 |
+
)
|
1558 |
+
|
1559 |
+
empty_state_dict = model.state_dict()
|
1560 |
+
for param_name, param in state_dict.items():
|
1561 |
+
accepts_dtype = "dtype" in set(
|
1562 |
+
inspect.signature(
|
1563 |
+
set_module_tensor_to_device
|
1564 |
+
).parameters.keys()
|
1565 |
+
)
|
1566 |
+
|
1567 |
+
if empty_state_dict[param_name].shape != param.shape:
|
1568 |
+
raise ValueError(
|
1569 |
+
f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
|
1570 |
+
)
|
1571 |
+
|
1572 |
+
if accepts_dtype:
|
1573 |
+
set_module_tensor_to_device(
|
1574 |
+
model,
|
1575 |
+
param_name,
|
1576 |
+
param_device,
|
1577 |
+
value=param,
|
1578 |
+
dtype=torch_dtype,
|
1579 |
+
)
|
1580 |
+
else:
|
1581 |
+
set_module_tensor_to_device(
|
1582 |
+
model, param_name, param_device, value=param
|
1583 |
+
)
|
1584 |
+
else: # else let accelerate handle loading and dispatching.
|
1585 |
+
# Load weights and dispatch according to the device_map
|
1586 |
+
# by default the device_map is None and the weights are loaded on the CPU
|
1587 |
+
accelerate.load_checkpoint_and_dispatch(
|
1588 |
+
model, model_file, device_map, dtype=torch_dtype
|
1589 |
+
)
|
1590 |
+
|
1591 |
+
loading_info = {
|
1592 |
+
"missing_keys": [],
|
1593 |
+
"unexpected_keys": [],
|
1594 |
+
"mismatched_keys": [],
|
1595 |
+
"error_msgs": [],
|
1596 |
+
}
|
1597 |
+
else:
|
1598 |
+
model = cls.from_config(config, **unused_kwargs)
|
1599 |
+
|
1600 |
+
state_dict = load_state_dict(model_file, variant=variant)
|
1601 |
+
|
1602 |
+
(
|
1603 |
+
model,
|
1604 |
+
missing_keys,
|
1605 |
+
unexpected_keys,
|
1606 |
+
mismatched_keys,
|
1607 |
+
error_msgs,
|
1608 |
+
) = cls._load_pretrained_model(
|
1609 |
+
model,
|
1610 |
+
state_dict,
|
1611 |
+
model_file,
|
1612 |
+
pretrained_model_name_or_path,
|
1613 |
+
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
1614 |
+
)
|
1615 |
+
|
1616 |
+
loading_info = {
|
1617 |
+
"missing_keys": missing_keys,
|
1618 |
+
"unexpected_keys": unexpected_keys,
|
1619 |
+
"mismatched_keys": mismatched_keys,
|
1620 |
+
"error_msgs": error_msgs,
|
1621 |
+
}
|
1622 |
+
|
1623 |
+
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
1624 |
+
raise ValueError(
|
1625 |
+
f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
|
1626 |
+
)
|
1627 |
+
elif torch_dtype is not None:
|
1628 |
+
model = model.to(torch_dtype)
|
1629 |
+
|
1630 |
+
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
|
1631 |
+
|
1632 |
+
# Set model in evaluation mode to deactivate DropOut modules by default
|
1633 |
+
model.eval()
|
1634 |
+
if output_loading_info:
|
1635 |
+
return model, loading_info
|
1636 |
+
|
1637 |
+
return model
|
1638 |
+
|
1639 |
+
def set_skip_temporal_layers(
|
1640 |
+
self,
|
1641 |
+
valid: bool,
|
1642 |
+
) -> None: # turn 3Dunet to 2Dunet
|
1643 |
+
# Recursively walk through all the children.
|
1644 |
+
# Any children which exposes the skip_temporal_layers parameter gets the message
|
1645 |
+
|
1646 |
+
# 推断时使用参数控制refer_image和ip_adapter_image来控制,不需要这里了
|
1647 |
+
# if hasattr(self, "skip_refer_downblock_emb"):
|
1648 |
+
# self.skip_refer_downblock_emb = valid
|
1649 |
+
|
1650 |
+
def fn_recursive_set_mem_eff(module: torch.nn.Module):
|
1651 |
+
if hasattr(module, "skip_temporal_layers"):
|
1652 |
+
module.skip_temporal_layers = valid
|
1653 |
+
# if hasattr(module, "skip_refer_downblock_emb"):
|
1654 |
+
# module.skip_refer_downblock_emb = valid
|
1655 |
+
|
1656 |
+
for child in module.children():
|
1657 |
+
fn_recursive_set_mem_eff(child)
|
1658 |
+
|
1659 |
+
for module in self.children():
|
1660 |
+
if isinstance(module, torch.nn.Module):
|
1661 |
+
fn_recursive_set_mem_eff(module)
|
1662 |
+
|
1663 |
+
def insert_spatial_self_attn_idx(self):
|
1664 |
+
attns, basic_transformers = self.spatial_self_attns
|
1665 |
+
self.self_attn_num = len(attns)
|
1666 |
+
for i, (name, layer) in enumerate(attns):
|
1667 |
+
logger.debug(
|
1668 |
+
f"{self.__class__.__name__}, {i}, {name}, {layer.__class__.__name__}"
|
1669 |
+
)
|
1670 |
+
layer.spatial_self_attn_idx = i
|
1671 |
+
for i, (name, layer) in enumerate(basic_transformers):
|
1672 |
+
logger.debug(
|
1673 |
+
f"{self.__class__.__name__}, {i}, {name}, {layer.__class__.__name__}"
|
1674 |
+
)
|
1675 |
+
layer.spatial_self_attn_idx = i
|
1676 |
+
|
1677 |
+
@property
|
1678 |
+
def spatial_self_attns(
|
1679 |
+
self,
|
1680 |
+
) -> List[Tuple[str, Attention]]:
|
1681 |
+
attns, spatial_transformers = self.get_attns(
|
1682 |
+
include="attentions", exclude="temp_attentions", attn_name="attn1"
|
1683 |
+
)
|
1684 |
+
attns = sorted(attns)
|
1685 |
+
spatial_transformers = sorted(spatial_transformers)
|
1686 |
+
return attns, spatial_transformers
|
1687 |
+
|
1688 |
+
@property
|
1689 |
+
def spatial_cross_attns(
|
1690 |
+
self,
|
1691 |
+
) -> List[Tuple[str, Attention]]:
|
1692 |
+
attns, spatial_transformers = self.get_attns(
|
1693 |
+
include="attentions", exclude="temp_attentions", attn_name="attn2"
|
1694 |
+
)
|
1695 |
+
attns = sorted(attns)
|
1696 |
+
spatial_transformers = sorted(spatial_transformers)
|
1697 |
+
return attns, spatial_transformers
|
1698 |
+
|
1699 |
+
def get_attns(
|
1700 |
+
self,
|
1701 |
+
attn_name: str,
|
1702 |
+
include: str = None,
|
1703 |
+
exclude: str = None,
|
1704 |
+
) -> List[Tuple[str, Attention]]:
|
1705 |
+
r"""
|
1706 |
+
Returns:
|
1707 |
+
`dict` of attention attns: A dictionary containing all attention attns used in the model with
|
1708 |
+
indexed by its weight name.
|
1709 |
+
"""
|
1710 |
+
# set recursively
|
1711 |
+
attns = []
|
1712 |
+
spatial_transformers = []
|
1713 |
+
|
1714 |
+
def fn_recursive_add_attns(
|
1715 |
+
name: str,
|
1716 |
+
module: torch.nn.Module,
|
1717 |
+
attns: List[Tuple[str, Attention]],
|
1718 |
+
spatial_transformers: List[Tuple[str, BasicTransformerBlock]],
|
1719 |
+
):
|
1720 |
+
is_target = False
|
1721 |
+
if isinstance(module, BasicTransformerBlock) and hasattr(module, attn_name):
|
1722 |
+
is_target = True
|
1723 |
+
if include is not None:
|
1724 |
+
is_target = include in name
|
1725 |
+
if exclude is not None:
|
1726 |
+
is_target = exclude not in name
|
1727 |
+
if is_target:
|
1728 |
+
attns.append([f"{name}.{attn_name}", getattr(module, attn_name)])
|
1729 |
+
spatial_transformers.append([f"{name}", module])
|
1730 |
+
for sub_name, child in module.named_children():
|
1731 |
+
fn_recursive_add_attns(
|
1732 |
+
f"{name}.{sub_name}", child, attns, spatial_transformers
|
1733 |
+
)
|
1734 |
+
|
1735 |
+
return attns
|
1736 |
+
|
1737 |
+
for name, module in self.named_children():
|
1738 |
+
fn_recursive_add_attns(name, module, attns, spatial_transformers)
|
1739 |
+
|
1740 |
+
return attns, spatial_transformers
|
musev/models/unet_loader.py
ADDED
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
from typing import Any, Callable, Dict, Iterable, Union
|
3 |
+
import PIL
|
4 |
+
import cv2
|
5 |
+
import torch
|
6 |
+
import argparse
|
7 |
+
import datetime
|
8 |
+
import logging
|
9 |
+
import inspect
|
10 |
+
import math
|
11 |
+
import os
|
12 |
+
import shutil
|
13 |
+
from typing import Dict, List, Optional, Tuple
|
14 |
+
from pprint import pprint
|
15 |
+
from collections import OrderedDict
|
16 |
+
from dataclasses import dataclass
|
17 |
+
import gc
|
18 |
+
import time
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
from omegaconf import OmegaConf
|
22 |
+
from omegaconf import SCMode
|
23 |
+
import torch
|
24 |
+
from torch import nn
|
25 |
+
import torch.nn.functional as F
|
26 |
+
import torch.utils.checkpoint
|
27 |
+
from einops import rearrange, repeat
|
28 |
+
import pandas as pd
|
29 |
+
import h5py
|
30 |
+
from diffusers.models.modeling_utils import load_state_dict
|
31 |
+
from diffusers.utils import (
|
32 |
+
logging,
|
33 |
+
)
|
34 |
+
from diffusers.utils.import_utils import is_xformers_available
|
35 |
+
|
36 |
+
from ..models.unet_3d_condition import UNet3DConditionModel
|
37 |
+
|
38 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
39 |
+
|
40 |
+
|
41 |
+
def update_unet_with_sd(
|
42 |
+
unet: nn.Module, sd_model: Tuple[str, nn.Module], subfolder: str = "unet"
|
43 |
+
):
|
44 |
+
"""更新T2V模型中的T2I参数. update t2i parameters in t2v model
|
45 |
+
|
46 |
+
Args:
|
47 |
+
unet (nn.Module): _description_
|
48 |
+
sd_model (Tuple[str, nn.Module]): _description_
|
49 |
+
|
50 |
+
Returns:
|
51 |
+
_type_: _description_
|
52 |
+
"""
|
53 |
+
# dtype = unet.dtype
|
54 |
+
# TODO: in this way, sd_model_path must be absolute path, to be more dynamic
|
55 |
+
if isinstance(sd_model, str):
|
56 |
+
if os.path.isdir(sd_model):
|
57 |
+
unet_state_dict = load_state_dict(
|
58 |
+
os.path.join(sd_model, subfolder, "diffusion_pytorch_model.bin"),
|
59 |
+
)
|
60 |
+
elif os.path.isfile(sd_model):
|
61 |
+
if sd_model.endswith("pth"):
|
62 |
+
unet_state_dict = torch.load(sd_model, map_location="cpu")
|
63 |
+
print(f"referencenet successful load ={sd_model} with torch.load")
|
64 |
+
else:
|
65 |
+
try:
|
66 |
+
unet_state_dict = load_state_dict(sd_model)
|
67 |
+
print(
|
68 |
+
f"referencenet successful load with {sd_model} with load_state_dict"
|
69 |
+
)
|
70 |
+
except Exception as e:
|
71 |
+
print(e)
|
72 |
+
|
73 |
+
elif isinstance(sd_model, nn.Module):
|
74 |
+
unet_state_dict = sd_model.state_dict()
|
75 |
+
else:
|
76 |
+
raise ValueError(f"given {type(sd_model)}, but only support nn.Module or str")
|
77 |
+
missing, unexpected = unet.load_state_dict(unet_state_dict, strict=False)
|
78 |
+
assert len(unexpected) == 0, f"unet load_state_dict error, unexpected={unexpected}"
|
79 |
+
# unet.to(dtype=dtype)
|
80 |
+
return unet
|
81 |
+
|
82 |
+
|
83 |
+
def load_unet(
|
84 |
+
sd_unet_model: Tuple[str, nn.Module],
|
85 |
+
sd_model: Tuple[str, nn.Module] = None,
|
86 |
+
cross_attention_dim: int = 768,
|
87 |
+
temporal_transformer: str = "TransformerTemporalModel",
|
88 |
+
temporal_conv_block: str = "TemporalConvLayer",
|
89 |
+
need_spatial_position_emb: bool = False,
|
90 |
+
need_transformer_in: bool = True,
|
91 |
+
need_t2i_ip_adapter: bool = False,
|
92 |
+
need_adain_temporal_cond: bool = False,
|
93 |
+
t2i_ip_adapter_attn_processor: str = "IPXFormersAttnProcessor",
|
94 |
+
keep_vision_condtion: bool = False,
|
95 |
+
use_anivv1_cfg: bool = False,
|
96 |
+
resnet_2d_skip_time_act: bool = False,
|
97 |
+
dtype: torch.dtype = torch.float16,
|
98 |
+
need_zero_vis_cond_temb: bool = True,
|
99 |
+
norm_spatial_length: bool = True,
|
100 |
+
spatial_max_length: int = 2048,
|
101 |
+
need_refer_emb: bool = False,
|
102 |
+
ip_adapter_cross_attn=False,
|
103 |
+
t2i_crossattn_ip_adapter_attn_processor="T2IReferencenetIPAdapterXFormersAttnProcessor",
|
104 |
+
need_t2i_facein: bool = False,
|
105 |
+
need_t2i_ip_adapter_face: bool = False,
|
106 |
+
strict: bool = True,
|
107 |
+
):
|
108 |
+
"""通过模型名字 初始化Unet,载入预训练参数. init unet with model_name.
|
109 |
+
该部分都是通过 models.unet_3d_condition.py:UNet3DConditionModel 定义、训练的模型
|
110 |
+
model is defined and trained in models.unet_3d_condition.py:UNet3DConditionModel
|
111 |
+
|
112 |
+
Args:
|
113 |
+
sd_unet_model (Tuple[str, nn.Module]): _description_
|
114 |
+
sd_model (Tuple[str, nn.Module]): _description_
|
115 |
+
cross_attention_dim (int, optional): _description_. Defaults to 768.
|
116 |
+
temporal_transformer (str, optional): _description_. Defaults to "TransformerTemporalModel".
|
117 |
+
temporal_conv_block (str, optional): _description_. Defaults to "TemporalConvLayer".
|
118 |
+
need_spatial_position_emb (bool, optional): _description_. Defaults to False.
|
119 |
+
need_transformer_in (bool, optional): _description_. Defaults to True.
|
120 |
+
need_t2i_ip_adapter (bool, optional): _description_. Defaults to False.
|
121 |
+
need_adain_temporal_cond (bool, optional): _description_. Defaults to False.
|
122 |
+
t2i_ip_adapter_attn_processor (str, optional): _description_. Defaults to "IPXFormersAttnProcessor".
|
123 |
+
keep_vision_condtion (bool, optional): _description_. Defaults to False.
|
124 |
+
use_anivv1_cfg (bool, optional): _description_. Defaults to False.
|
125 |
+
resnet_2d_skip_time_act (bool, optional): _description_. Defaults to False.
|
126 |
+
dtype (torch.dtype, optional): _description_. Defaults to torch.float16.
|
127 |
+
need_zero_vis_cond_temb (bool, optional): _description_. Defaults to True.
|
128 |
+
norm_spatial_length (bool, optional): _description_. Defaults to True.
|
129 |
+
spatial_max_length (int, optional): _description_. Defaults to 2048.
|
130 |
+
|
131 |
+
Returns:
|
132 |
+
_type_: _description_
|
133 |
+
"""
|
134 |
+
if isinstance(sd_unet_model, str):
|
135 |
+
unet = UNet3DConditionModel.from_pretrained_2d(
|
136 |
+
sd_unet_model,
|
137 |
+
subfolder="unet",
|
138 |
+
temporal_transformer=temporal_transformer,
|
139 |
+
temporal_conv_block=temporal_conv_block,
|
140 |
+
cross_attention_dim=cross_attention_dim,
|
141 |
+
need_spatial_position_emb=need_spatial_position_emb,
|
142 |
+
need_transformer_in=need_transformer_in,
|
143 |
+
need_t2i_ip_adapter=need_t2i_ip_adapter,
|
144 |
+
need_adain_temporal_cond=need_adain_temporal_cond,
|
145 |
+
t2i_ip_adapter_attn_processor=t2i_ip_adapter_attn_processor,
|
146 |
+
keep_vision_condtion=keep_vision_condtion,
|
147 |
+
use_anivv1_cfg=use_anivv1_cfg,
|
148 |
+
resnet_2d_skip_time_act=resnet_2d_skip_time_act,
|
149 |
+
torch_dtype=dtype,
|
150 |
+
need_zero_vis_cond_temb=need_zero_vis_cond_temb,
|
151 |
+
norm_spatial_length=norm_spatial_length,
|
152 |
+
spatial_max_length=spatial_max_length,
|
153 |
+
need_refer_emb=need_refer_emb,
|
154 |
+
ip_adapter_cross_attn=ip_adapter_cross_attn,
|
155 |
+
t2i_crossattn_ip_adapter_attn_processor=t2i_crossattn_ip_adapter_attn_processor,
|
156 |
+
need_t2i_facein=need_t2i_facein,
|
157 |
+
strict=strict,
|
158 |
+
need_t2i_ip_adapter_face=need_t2i_ip_adapter_face,
|
159 |
+
)
|
160 |
+
elif isinstance(sd_unet_model, nn.Module):
|
161 |
+
unet = sd_unet_model
|
162 |
+
if sd_model is not None:
|
163 |
+
unet = update_unet_with_sd(unet, sd_model)
|
164 |
+
return unet
|
165 |
+
|
166 |
+
|
167 |
+
def load_unet_custom_unet(
|
168 |
+
sd_unet_model: Tuple[str, nn.Module],
|
169 |
+
sd_model: Tuple[str, nn.Module],
|
170 |
+
unet_class: nn.Module,
|
171 |
+
):
|
172 |
+
"""
|
173 |
+
通过模型名字 初始化Unet,载入预训练参数. init unet with model_name.
|
174 |
+
该部分都是通过 不通过models.unet_3d_condition.py:UNet3DConditionModel 定义、训练的模型
|
175 |
+
model is not defined in models.unet_3d_condition.py:UNet3DConditionModel
|
176 |
+
Args:
|
177 |
+
sd_unet_model (Tuple[str, nn.Module]): _description_
|
178 |
+
sd_model (Tuple[str, nn.Module]): _description_
|
179 |
+
unet_class (nn.Module): _description_
|
180 |
+
|
181 |
+
Returns:
|
182 |
+
_type_: _description_
|
183 |
+
"""
|
184 |
+
if isinstance(sd_unet_model, str):
|
185 |
+
unet = unet_class.from_pretrained(
|
186 |
+
sd_unet_model,
|
187 |
+
subfolder="unet",
|
188 |
+
)
|
189 |
+
elif isinstance(sd_unet_model, nn.Module):
|
190 |
+
unet = sd_unet_model
|
191 |
+
|
192 |
+
# TODO: in this way, sd_model_path must be absolute path, to be more dynamic
|
193 |
+
if isinstance(sd_model, str):
|
194 |
+
unet_state_dict = load_state_dict(
|
195 |
+
os.path.join(sd_model, "unet/diffusion_pytorch_model.bin"),
|
196 |
+
)
|
197 |
+
elif isinstance(sd_model, nn.Module):
|
198 |
+
unet_state_dict = sd_model.state_dict()
|
199 |
+
missing, unexpected = unet.load_state_dict(unet_state_dict, strict=False)
|
200 |
+
assert (
|
201 |
+
len(unexpected) == 0
|
202 |
+
), "unet load_state_dict error" # Load scheduler, tokenizer and models.
|
203 |
+
return unet
|
204 |
+
|
205 |
+
|
206 |
+
def load_unet_by_name(
|
207 |
+
model_name: str,
|
208 |
+
sd_unet_model: Tuple[str, nn.Module],
|
209 |
+
sd_model: Tuple[str, nn.Module] = None,
|
210 |
+
cross_attention_dim: int = 768,
|
211 |
+
dtype: torch.dtype = torch.float16,
|
212 |
+
need_t2i_facein: bool = False,
|
213 |
+
need_t2i_ip_adapter_face: bool = False,
|
214 |
+
strict: bool = True,
|
215 |
+
) -> nn.Module:
|
216 |
+
"""通过模型名字 初始化Unet,载入预训练参数. init unet with model_name.
|
217 |
+
如希望后续通过简单名字就可以使用预训练模型,需要在这里完成定义
|
218 |
+
if you want to use pretrained model with simple name, you need to define it here.
|
219 |
+
Args:
|
220 |
+
model_name (str): _description_
|
221 |
+
sd_unet_model (Tuple[str, nn.Module]): _description_
|
222 |
+
sd_model (Tuple[str, nn.Module]): _description_
|
223 |
+
cross_attention_dim (int, optional): _description_. Defaults to 768.
|
224 |
+
dtype (torch.dtype, optional): _description_. Defaults to torch.float16.
|
225 |
+
|
226 |
+
Raises:
|
227 |
+
ValueError: _description_
|
228 |
+
|
229 |
+
Returns:
|
230 |
+
nn.Module: _description_
|
231 |
+
"""
|
232 |
+
if model_name in ["musev"]:
|
233 |
+
unet = load_unet(
|
234 |
+
sd_unet_model=sd_unet_model,
|
235 |
+
sd_model=sd_model,
|
236 |
+
need_spatial_position_emb=False,
|
237 |
+
cross_attention_dim=cross_attention_dim,
|
238 |
+
need_t2i_ip_adapter=True,
|
239 |
+
need_adain_temporal_cond=True,
|
240 |
+
t2i_ip_adapter_attn_processor="NonParamReferenceIPXFormersAttnProcessor",
|
241 |
+
dtype=dtype,
|
242 |
+
)
|
243 |
+
elif model_name in [
|
244 |
+
"musev_referencenet",
|
245 |
+
"musev_referencenet_pose",
|
246 |
+
]:
|
247 |
+
unet = load_unet(
|
248 |
+
sd_unet_model=sd_unet_model,
|
249 |
+
sd_model=sd_model,
|
250 |
+
cross_attention_dim=cross_attention_dim,
|
251 |
+
temporal_conv_block="TemporalConvLayer",
|
252 |
+
need_transformer_in=False,
|
253 |
+
temporal_transformer="TransformerTemporalModel",
|
254 |
+
use_anivv1_cfg=True,
|
255 |
+
resnet_2d_skip_time_act=True,
|
256 |
+
need_t2i_ip_adapter=True,
|
257 |
+
need_adain_temporal_cond=True,
|
258 |
+
keep_vision_condtion=True,
|
259 |
+
t2i_ip_adapter_attn_processor="NonParamReferenceIPXFormersAttnProcessor",
|
260 |
+
dtype=dtype,
|
261 |
+
need_refer_emb=True,
|
262 |
+
need_zero_vis_cond_temb=True,
|
263 |
+
ip_adapter_cross_attn=True,
|
264 |
+
t2i_crossattn_ip_adapter_attn_processor="T2IReferencenetIPAdapterXFormersAttnProcessor",
|
265 |
+
need_t2i_facein=need_t2i_facein,
|
266 |
+
strict=strict,
|
267 |
+
need_t2i_ip_adapter_face=need_t2i_ip_adapter_face,
|
268 |
+
)
|
269 |
+
else:
|
270 |
+
raise ValueError(
|
271 |
+
f"unsupport model_name={model_name}, only support musev, musev_referencenet, musev_referencenet_pose"
|
272 |
+
)
|
273 |
+
return unet
|
musev/pipelines/__init__.py
ADDED
File without changes
|
musev/pipelines/context.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# TODO: Adapted from cli
|
2 |
+
import math
|
3 |
+
from typing import Callable, List, Optional
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from mmcm.utils.itertools_util import generate_sample_idxs
|
8 |
+
|
9 |
+
# copy from https://github.com/MooreThreads/Moore-AnimateAnyone/blob/master/src/pipelines/context.py
|
10 |
+
|
11 |
+
|
12 |
+
def ordered_halving(val):
|
13 |
+
bin_str = f"{val:064b}"
|
14 |
+
bin_flip = bin_str[::-1]
|
15 |
+
as_int = int(bin_flip, 2)
|
16 |
+
|
17 |
+
return as_int / (1 << 64)
|
18 |
+
|
19 |
+
|
20 |
+
# TODO: closed_loop not work, to fix it
|
21 |
+
def uniform(
|
22 |
+
step: int = ...,
|
23 |
+
num_steps: Optional[int] = None,
|
24 |
+
num_frames: int = ...,
|
25 |
+
context_size: Optional[int] = None,
|
26 |
+
context_stride: int = 3,
|
27 |
+
context_overlap: int = 4,
|
28 |
+
closed_loop: bool = True,
|
29 |
+
):
|
30 |
+
if num_frames <= context_size:
|
31 |
+
yield list(range(num_frames))
|
32 |
+
return
|
33 |
+
|
34 |
+
context_stride = min(
|
35 |
+
context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1
|
36 |
+
)
|
37 |
+
|
38 |
+
for context_step in 1 << np.arange(context_stride):
|
39 |
+
pad = int(round(num_frames * ordered_halving(step)))
|
40 |
+
for j in range(
|
41 |
+
int(ordered_halving(step) * context_step) + pad,
|
42 |
+
num_frames + pad + (0 if closed_loop else -context_overlap),
|
43 |
+
(context_size * context_step - context_overlap),
|
44 |
+
):
|
45 |
+
yield [
|
46 |
+
e % num_frames
|
47 |
+
for e in range(j, j + context_size * context_step, context_step)
|
48 |
+
]
|
49 |
+
|
50 |
+
|
51 |
+
def uniform_v2(
|
52 |
+
step: int = ...,
|
53 |
+
num_steps: Optional[int] = None,
|
54 |
+
num_frames: int = ...,
|
55 |
+
context_size: Optional[int] = None,
|
56 |
+
context_stride: int = 3,
|
57 |
+
context_overlap: int = 4,
|
58 |
+
closed_loop: bool = True,
|
59 |
+
):
|
60 |
+
return generate_sample_idxs(
|
61 |
+
total=num_frames,
|
62 |
+
window_size=context_size,
|
63 |
+
step=context_size - context_overlap,
|
64 |
+
sample_rate=1,
|
65 |
+
drop_last=False,
|
66 |
+
)
|
67 |
+
|
68 |
+
|
69 |
+
def get_context_scheduler(name: str) -> Callable:
|
70 |
+
if name == "uniform":
|
71 |
+
return uniform
|
72 |
+
elif name == "uniform_v2":
|
73 |
+
return uniform_v2
|
74 |
+
else:
|
75 |
+
raise ValueError(f"Unknown context_overlap policy {name}")
|
76 |
+
|
77 |
+
|
78 |
+
def get_total_steps(
|
79 |
+
scheduler,
|
80 |
+
timesteps: List[int],
|
81 |
+
num_steps: Optional[int] = None,
|
82 |
+
num_frames: int = ...,
|
83 |
+
context_size: Optional[int] = None,
|
84 |
+
context_stride: int = 3,
|
85 |
+
context_overlap: int = 4,
|
86 |
+
closed_loop: bool = True,
|
87 |
+
):
|
88 |
+
return sum(
|
89 |
+
len(
|
90 |
+
list(
|
91 |
+
scheduler(
|
92 |
+
i,
|
93 |
+
num_steps,
|
94 |
+
num_frames,
|
95 |
+
context_size,
|
96 |
+
context_stride,
|
97 |
+
context_overlap,
|
98 |
+
)
|
99 |
+
)
|
100 |
+
)
|
101 |
+
for i in range(len(timesteps))
|
102 |
+
)
|
103 |
+
|
104 |
+
|
105 |
+
def drop_last_repeat_context(contexts: List[List[int]]) -> List[List[int]]:
|
106 |
+
"""if len(contexts)>=2 and the max value the oenultimate list same as of the last list
|
107 |
+
|
108 |
+
Args:
|
109 |
+
List (_type_): _description_
|
110 |
+
|
111 |
+
Returns:
|
112 |
+
List[List[int]]: _description_
|
113 |
+
"""
|
114 |
+
if len(contexts) >= 2 and contexts[-1][-1] == contexts[-2][-1]:
|
115 |
+
return contexts[:-1]
|
116 |
+
else:
|
117 |
+
return contexts
|
118 |
+
|
119 |
+
|
120 |
+
def prepare_global_context(
|
121 |
+
context_schedule: str,
|
122 |
+
num_inference_steps: int,
|
123 |
+
time_size: int,
|
124 |
+
context_frames: int,
|
125 |
+
context_stride: int,
|
126 |
+
context_overlap: int,
|
127 |
+
context_batch_size: int,
|
128 |
+
):
|
129 |
+
context_scheduler = get_context_scheduler(context_schedule)
|
130 |
+
context_queue = list(
|
131 |
+
context_scheduler(
|
132 |
+
step=0,
|
133 |
+
num_steps=num_inference_steps,
|
134 |
+
num_frames=time_size,
|
135 |
+
context_size=context_frames,
|
136 |
+
context_stride=context_stride,
|
137 |
+
context_overlap=context_overlap,
|
138 |
+
)
|
139 |
+
)
|
140 |
+
# 如果context_queue的最后一个索引最大值和倒数第二个索引最大值相同,说明最后一个列表就是因为step带来的冗余项,可以去掉
|
141 |
+
# remove the last context if max index of the last context is the same as the max index of the second last context
|
142 |
+
context_queue = drop_last_repeat_context(context_queue)
|
143 |
+
num_context_batches = math.ceil(len(context_queue) / context_batch_size)
|
144 |
+
global_context = []
|
145 |
+
for i_tmp in range(num_context_batches):
|
146 |
+
global_context.append(
|
147 |
+
context_queue[i_tmp * context_batch_size : (i_tmp + 1) * context_batch_size]
|
148 |
+
)
|
149 |
+
return global_context
|
musev/pipelines/pipeline_controlnet.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
musev/pipelines/pipeline_controlnet_predictor.py
ADDED
@@ -0,0 +1,1290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
from typing import Any, Callable, Dict, Iterable, Union
|
3 |
+
import PIL
|
4 |
+
import cv2
|
5 |
+
import torch
|
6 |
+
import argparse
|
7 |
+
import datetime
|
8 |
+
import logging
|
9 |
+
import inspect
|
10 |
+
import math
|
11 |
+
import os
|
12 |
+
import shutil
|
13 |
+
from typing import Dict, List, Optional, Tuple
|
14 |
+
from pprint import pformat, pprint
|
15 |
+
from collections import OrderedDict
|
16 |
+
from dataclasses import dataclass
|
17 |
+
import gc
|
18 |
+
import time
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
from omegaconf import OmegaConf
|
22 |
+
from omegaconf import SCMode
|
23 |
+
import torch
|
24 |
+
from torch import nn
|
25 |
+
import torch.nn.functional as F
|
26 |
+
import torch.utils.checkpoint
|
27 |
+
from einops import rearrange, repeat
|
28 |
+
import pandas as pd
|
29 |
+
import h5py
|
30 |
+
from diffusers.models.autoencoder_kl import AutoencoderKL
|
31 |
+
|
32 |
+
from diffusers.models.modeling_utils import load_state_dict
|
33 |
+
from diffusers.utils import (
|
34 |
+
logging,
|
35 |
+
BaseOutput,
|
36 |
+
logging,
|
37 |
+
)
|
38 |
+
from diffusers.utils.dummy_pt_objects import ConsistencyDecoderVAE
|
39 |
+
from diffusers.utils.import_utils import is_xformers_available
|
40 |
+
|
41 |
+
from mmcm.utils.seed_util import set_all_seed
|
42 |
+
from mmcm.vision.data.video_dataset import DecordVideoDataset
|
43 |
+
from mmcm.vision.process.correct_color import hist_match_video_bcthw
|
44 |
+
from mmcm.vision.process.image_process import (
|
45 |
+
batch_dynamic_crop_resize_images,
|
46 |
+
batch_dynamic_crop_resize_images_v2,
|
47 |
+
)
|
48 |
+
from mmcm.vision.utils.data_type_util import is_video
|
49 |
+
from mmcm.vision.feature_extractor.controlnet import load_controlnet_model
|
50 |
+
|
51 |
+
from ..schedulers import (
|
52 |
+
EulerDiscreteScheduler,
|
53 |
+
LCMScheduler,
|
54 |
+
DDIMScheduler,
|
55 |
+
DDPMScheduler,
|
56 |
+
)
|
57 |
+
from ..models.unet_3d_condition import UNet3DConditionModel
|
58 |
+
from .pipeline_controlnet import (
|
59 |
+
MusevControlNetPipeline,
|
60 |
+
VideoPipelineOutput as PipelineVideoPipelineOutput,
|
61 |
+
)
|
62 |
+
from ..utils.util import save_videos_grid_with_opencv
|
63 |
+
from ..utils.model_util import (
|
64 |
+
update_pipeline_basemodel,
|
65 |
+
update_pipeline_lora_model,
|
66 |
+
update_pipeline_lora_models,
|
67 |
+
update_pipeline_model_parameters,
|
68 |
+
)
|
69 |
+
|
70 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
71 |
+
|
72 |
+
|
73 |
+
@dataclass
|
74 |
+
class VideoPipelineOutput(BaseOutput):
|
75 |
+
videos: Union[torch.Tensor, np.ndarray]
|
76 |
+
latents: Union[torch.Tensor, np.ndarray]
|
77 |
+
videos_mid: Union[torch.Tensor, np.ndarray]
|
78 |
+
controlnet_cond: Union[torch.Tensor, np.ndarray]
|
79 |
+
generated_videos: Union[torch.Tensor, np.ndarray]
|
80 |
+
|
81 |
+
|
82 |
+
def update_controlnet_processor_params(
|
83 |
+
src: Union[Dict, List[Dict]], dst: Union[Dict, List[Dict]]
|
84 |
+
):
|
85 |
+
"""merge dst into src"""
|
86 |
+
if isinstance(src, list) and not isinstance(dst, List):
|
87 |
+
dst = [dst] * len(src)
|
88 |
+
if isinstance(src, list) and isinstance(dst, list):
|
89 |
+
return [
|
90 |
+
update_controlnet_processor_params(src[i], dst[i]) for i in range(len(src))
|
91 |
+
]
|
92 |
+
if src is None:
|
93 |
+
dct = {}
|
94 |
+
else:
|
95 |
+
dct = copy.deepcopy(src)
|
96 |
+
if dst is None:
|
97 |
+
dst = {}
|
98 |
+
dct.update(dst)
|
99 |
+
return dct
|
100 |
+
|
101 |
+
|
102 |
+
class DiffusersPipelinePredictor(object):
|
103 |
+
"""wraper of diffusers pipeline, support generation function interface. support
|
104 |
+
1. text2video: inputs include text, image(optional), refer_image(optional)
|
105 |
+
2. video2video:
|
106 |
+
1. use controlnet to control spatial
|
107 |
+
2. or use video fuse noise to denoise
|
108 |
+
"""
|
109 |
+
|
110 |
+
def __init__(
|
111 |
+
self,
|
112 |
+
sd_model_path: str,
|
113 |
+
unet: nn.Module,
|
114 |
+
controlnet_name: Union[str, List[str]] = None,
|
115 |
+
controlnet: nn.Module = None,
|
116 |
+
lora_dict: Dict[str, Dict] = None,
|
117 |
+
requires_safety_checker: bool = False,
|
118 |
+
device: str = "cuda",
|
119 |
+
dtype: torch.dtype = torch.float16,
|
120 |
+
# controlnet parameters start
|
121 |
+
need_controlnet_processor: bool = True,
|
122 |
+
need_controlnet: bool = True,
|
123 |
+
image_resolution: int = 512,
|
124 |
+
detect_resolution: int = 512,
|
125 |
+
include_body: bool = True,
|
126 |
+
hand_and_face: bool = None,
|
127 |
+
include_face: bool = False,
|
128 |
+
include_hand: bool = True,
|
129 |
+
negative_embedding: List = None,
|
130 |
+
# controlnet parameters end
|
131 |
+
enable_xformers_memory_efficient_attention: bool = True,
|
132 |
+
lcm_lora_dct: Dict = None,
|
133 |
+
referencenet: nn.Module = None,
|
134 |
+
ip_adapter_image_proj: nn.Module = None,
|
135 |
+
vision_clip_extractor: nn.Module = None,
|
136 |
+
face_emb_extractor: nn.Module = None,
|
137 |
+
facein_image_proj: nn.Module = None,
|
138 |
+
ip_adapter_face_emb_extractor: nn.Module = None,
|
139 |
+
ip_adapter_face_image_proj: nn.Module = None,
|
140 |
+
vae_model: Optional[Tuple[nn.Module, str]] = None,
|
141 |
+
pose_guider: Optional[nn.Module] = None,
|
142 |
+
enable_zero_snr: bool = False,
|
143 |
+
) -> None:
|
144 |
+
self.sd_model_path = sd_model_path
|
145 |
+
self.unet = unet
|
146 |
+
self.controlnet_name = controlnet_name
|
147 |
+
self.controlnet = controlnet
|
148 |
+
self.requires_safety_checker = requires_safety_checker
|
149 |
+
self.device = device
|
150 |
+
self.dtype = dtype
|
151 |
+
self.need_controlnet_processor = need_controlnet_processor
|
152 |
+
self.need_controlnet = need_controlnet
|
153 |
+
self.need_controlnet_processor = need_controlnet_processor
|
154 |
+
self.image_resolution = image_resolution
|
155 |
+
self.detect_resolution = detect_resolution
|
156 |
+
self.include_body = include_body
|
157 |
+
self.hand_and_face = hand_and_face
|
158 |
+
self.include_face = include_face
|
159 |
+
self.include_hand = include_hand
|
160 |
+
self.negative_embedding = negative_embedding
|
161 |
+
self.device = device
|
162 |
+
self.dtype = dtype
|
163 |
+
self.lcm_lora_dct = lcm_lora_dct
|
164 |
+
if controlnet is None and controlnet_name is not None:
|
165 |
+
controlnet, controlnet_processor, processor_params = load_controlnet_model(
|
166 |
+
controlnet_name,
|
167 |
+
device=device,
|
168 |
+
dtype=dtype,
|
169 |
+
need_controlnet_processor=need_controlnet_processor,
|
170 |
+
need_controlnet=need_controlnet,
|
171 |
+
image_resolution=image_resolution,
|
172 |
+
detect_resolution=detect_resolution,
|
173 |
+
include_body=include_body,
|
174 |
+
include_face=include_face,
|
175 |
+
hand_and_face=hand_and_face,
|
176 |
+
include_hand=include_hand,
|
177 |
+
)
|
178 |
+
self.controlnet_processor = controlnet_processor
|
179 |
+
self.controlnet_processor_params = processor_params
|
180 |
+
logger.debug(f"init controlnet controlnet_name={controlnet_name}")
|
181 |
+
|
182 |
+
if controlnet is not None:
|
183 |
+
controlnet = controlnet.to(device=device, dtype=dtype)
|
184 |
+
controlnet.eval()
|
185 |
+
if pose_guider is not None:
|
186 |
+
pose_guider = pose_guider.to(device=device, dtype=dtype)
|
187 |
+
pose_guider.eval()
|
188 |
+
unet.to(device=device, dtype=dtype)
|
189 |
+
unet.eval()
|
190 |
+
if referencenet is not None:
|
191 |
+
referencenet.to(device=device, dtype=dtype)
|
192 |
+
referencenet.eval()
|
193 |
+
if ip_adapter_image_proj is not None:
|
194 |
+
ip_adapter_image_proj.to(device=device, dtype=dtype)
|
195 |
+
ip_adapter_image_proj.eval()
|
196 |
+
if vision_clip_extractor is not None:
|
197 |
+
vision_clip_extractor.to(device=device, dtype=dtype)
|
198 |
+
vision_clip_extractor.eval()
|
199 |
+
if face_emb_extractor is not None:
|
200 |
+
face_emb_extractor.to(device=device, dtype=dtype)
|
201 |
+
face_emb_extractor.eval()
|
202 |
+
if facein_image_proj is not None:
|
203 |
+
facein_image_proj.to(device=device, dtype=dtype)
|
204 |
+
facein_image_proj.eval()
|
205 |
+
|
206 |
+
if isinstance(vae_model, str):
|
207 |
+
# TODO: poor implementation, to improve
|
208 |
+
if "consistency" in vae_model:
|
209 |
+
vae = ConsistencyDecoderVAE.from_pretrained(vae_model)
|
210 |
+
else:
|
211 |
+
vae = AutoencoderKL.from_pretrained(vae_model)
|
212 |
+
elif isinstance(vae_model, nn.Module):
|
213 |
+
vae = vae_model
|
214 |
+
else:
|
215 |
+
vae = None
|
216 |
+
if vae is not None:
|
217 |
+
vae.to(device=device, dtype=dtype)
|
218 |
+
vae.eval()
|
219 |
+
if ip_adapter_face_emb_extractor is not None:
|
220 |
+
ip_adapter_face_emb_extractor.to(device=device, dtype=dtype)
|
221 |
+
ip_adapter_face_emb_extractor.eval()
|
222 |
+
if ip_adapter_face_image_proj is not None:
|
223 |
+
ip_adapter_face_image_proj.to(device=device, dtype=dtype)
|
224 |
+
ip_adapter_face_image_proj.eval()
|
225 |
+
params = {
|
226 |
+
"pretrained_model_name_or_path": sd_model_path,
|
227 |
+
"controlnet": controlnet,
|
228 |
+
"unet": unet,
|
229 |
+
"requires_safety_checker": requires_safety_checker,
|
230 |
+
"torch_dtype": dtype,
|
231 |
+
"torch_device": device,
|
232 |
+
"referencenet": referencenet,
|
233 |
+
"ip_adapter_image_proj": ip_adapter_image_proj,
|
234 |
+
"vision_clip_extractor": vision_clip_extractor,
|
235 |
+
"facein_image_proj": facein_image_proj,
|
236 |
+
"face_emb_extractor": face_emb_extractor,
|
237 |
+
"ip_adapter_face_emb_extractor": ip_adapter_face_emb_extractor,
|
238 |
+
"ip_adapter_face_image_proj": ip_adapter_face_image_proj,
|
239 |
+
"pose_guider": pose_guider,
|
240 |
+
}
|
241 |
+
if vae is not None:
|
242 |
+
params["vae"] = vae
|
243 |
+
pipeline = MusevControlNetPipeline.from_pretrained(**params)
|
244 |
+
pipeline = pipeline.to(torch_device=device, torch_dtype=dtype)
|
245 |
+
logger.debug(
|
246 |
+
f"init pipeline from sd_model_path={sd_model_path}, device={device}, dtype={dtype}"
|
247 |
+
)
|
248 |
+
if (
|
249 |
+
negative_embedding is not None
|
250 |
+
and pipeline.text_encoder is not None
|
251 |
+
and pipeline.tokenizer is not None
|
252 |
+
):
|
253 |
+
for neg_emb_path, neg_token in negative_embedding:
|
254 |
+
pipeline.load_textual_inversion(neg_emb_path, token=neg_token)
|
255 |
+
|
256 |
+
# pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
257 |
+
# pipe.enable_model_cpu_offload()
|
258 |
+
if not enable_zero_snr:
|
259 |
+
pipeline.scheduler = EulerDiscreteScheduler.from_config(
|
260 |
+
pipeline.scheduler.config
|
261 |
+
)
|
262 |
+
# pipeline.scheduler = DDIMScheduler.from_config(
|
263 |
+
# pipeline.scheduler.config,
|
264 |
+
# 该部分会影响生成视频的亮度,不适用于首���给定的视频生成
|
265 |
+
# this part will change brightness of video, not suitable for image2video mode
|
266 |
+
# rescale_betas_zero_snr affect the brightness of the generated video, not suitable for vision condition images mode
|
267 |
+
# # rescale_betas_zero_snr=True,
|
268 |
+
# )
|
269 |
+
# pipeline.scheduler = DDPMScheduler.from_config(pipeline.scheduler.config)
|
270 |
+
else:
|
271 |
+
# moore scheduler, just for codetest
|
272 |
+
pipeline.scheduler = DDIMScheduler(
|
273 |
+
beta_start=0.00085,
|
274 |
+
beta_end=0.012,
|
275 |
+
beta_schedule="linear",
|
276 |
+
clip_sample=False,
|
277 |
+
steps_offset=1,
|
278 |
+
### Zero-SNR params
|
279 |
+
prediction_type="v_prediction",
|
280 |
+
rescale_betas_zero_snr=True,
|
281 |
+
timestep_spacing="trailing",
|
282 |
+
)
|
283 |
+
|
284 |
+
pipeline.enable_vae_slicing()
|
285 |
+
self.enable_xformers_memory_efficient_attention = (
|
286 |
+
enable_xformers_memory_efficient_attention
|
287 |
+
)
|
288 |
+
if enable_xformers_memory_efficient_attention:
|
289 |
+
if is_xformers_available():
|
290 |
+
pipeline.enable_xformers_memory_efficient_attention()
|
291 |
+
else:
|
292 |
+
raise ValueError(
|
293 |
+
"xformers is not available. Make sure it is installed correctly"
|
294 |
+
)
|
295 |
+
self.pipeline = pipeline
|
296 |
+
self.unload_dict = [] # keep lora state
|
297 |
+
if lora_dict is not None:
|
298 |
+
self.load_lora(lora_dict=lora_dict)
|
299 |
+
logger.debug("load lora {}".format(" ".join(list(lora_dict.keys()))))
|
300 |
+
|
301 |
+
if lcm_lora_dct is not None:
|
302 |
+
self.pipeline.scheduler = LCMScheduler.from_config(
|
303 |
+
self.pipeline.scheduler.config
|
304 |
+
)
|
305 |
+
self.load_lora(lora_dict=lcm_lora_dct)
|
306 |
+
logger.debug("load lcm lora {}".format(" ".join(list(lcm_lora_dct.keys()))))
|
307 |
+
|
308 |
+
# logger.debug("Unet3Model Parameters")
|
309 |
+
# logger.debug(pformat(self.__dict__))
|
310 |
+
|
311 |
+
def load_lora(
|
312 |
+
self,
|
313 |
+
lora_dict: Dict[str, Dict],
|
314 |
+
):
|
315 |
+
self.pipeline, unload_dict = update_pipeline_lora_models(
|
316 |
+
self.pipeline, lora_dict, device=self.device
|
317 |
+
)
|
318 |
+
self.unload_dict += unload_dict
|
319 |
+
|
320 |
+
def unload_lora(self):
|
321 |
+
for layer_data in self.unload_dict:
|
322 |
+
layer = layer_data["layer"]
|
323 |
+
added_weight = layer_data["added_weight"]
|
324 |
+
layer.weight.data -= added_weight
|
325 |
+
self.unload_dict = []
|
326 |
+
gc.collect()
|
327 |
+
torch.cuda.empty_cache()
|
328 |
+
|
329 |
+
def update_unet(self, unet: nn.Module):
|
330 |
+
self.pipeline.unet = unet.to(device=self.device, dtype=self.dtype)
|
331 |
+
|
332 |
+
def update_sd_model(self, model_path: str, text_model_path: str):
|
333 |
+
self.pipeline = update_pipeline_basemodel(
|
334 |
+
self.pipeline,
|
335 |
+
model_path,
|
336 |
+
text_sd_model_path=text_model_path,
|
337 |
+
device=self.device,
|
338 |
+
)
|
339 |
+
|
340 |
+
def update_sd_model_and_unet(
|
341 |
+
self, lora_sd_path: str, lora_path: str, sd_model_path: str = None
|
342 |
+
):
|
343 |
+
self.pipeline = update_pipeline_model_parameters(
|
344 |
+
self.pipeline,
|
345 |
+
model_path=lora_sd_path,
|
346 |
+
lora_path=lora_path,
|
347 |
+
text_model_path=sd_model_path,
|
348 |
+
device=self.device,
|
349 |
+
)
|
350 |
+
|
351 |
+
def update_controlnet(self, controlnet_name=Union[str, List[str]]):
|
352 |
+
self.pipeline.controlnet = load_controlnet_model(controlnet_name).to(
|
353 |
+
device=self.device, dtype=self.dtype
|
354 |
+
)
|
355 |
+
|
356 |
+
def run_pipe_text2video(
|
357 |
+
self,
|
358 |
+
video_length: int,
|
359 |
+
prompt: Union[str, List[str]] = None,
|
360 |
+
# b c t h w
|
361 |
+
height: Optional[int] = None,
|
362 |
+
width: Optional[int] = None,
|
363 |
+
video_num_inference_steps: int = 50,
|
364 |
+
video_guidance_scale: float = 7.5,
|
365 |
+
video_guidance_scale_end: float = 3.5,
|
366 |
+
video_guidance_scale_method: str = "linear",
|
367 |
+
strength: float = 0.8,
|
368 |
+
video_negative_prompt: Optional[Union[str, List[str]]] = None,
|
369 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
370 |
+
num_videos_per_prompt: Optional[int] = 1,
|
371 |
+
eta: float = 0.0,
|
372 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
373 |
+
same_seed: Optional[Union[int, List[int]]] = None,
|
374 |
+
# b c t(1) ho wo
|
375 |
+
condition_latents: Optional[torch.FloatTensor] = None,
|
376 |
+
latents: Optional[torch.FloatTensor] = None,
|
377 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
378 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
379 |
+
guidance_scale: float = 7.5,
|
380 |
+
num_inference_steps: int = 50,
|
381 |
+
output_type: Optional[str] = "tensor",
|
382 |
+
return_dict: bool = True,
|
383 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
384 |
+
callback_steps: int = 1,
|
385 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
386 |
+
need_middle_latents: bool = False,
|
387 |
+
w_ind_noise: float = 0.5,
|
388 |
+
initial_common_latent: Optional[torch.FloatTensor] = None,
|
389 |
+
latent_index: torch.LongTensor = None,
|
390 |
+
vision_condition_latent_index: torch.LongTensor = None,
|
391 |
+
n_vision_condition: int = 1,
|
392 |
+
noise_type: str = "random",
|
393 |
+
max_batch_num: int = 30,
|
394 |
+
need_img_based_video_noise: bool = False,
|
395 |
+
condition_images: torch.Tensor = None,
|
396 |
+
fix_condition_images: bool = False,
|
397 |
+
redraw_condition_image: bool = False,
|
398 |
+
img_weight: float = 1e-3,
|
399 |
+
motion_speed: float = 8.0,
|
400 |
+
need_hist_match: bool = False,
|
401 |
+
refer_image: Optional[
|
402 |
+
Tuple[np.ndarray, torch.Tensor, List[str], List[np.ndarray]]
|
403 |
+
] = None,
|
404 |
+
ip_adapter_image: Optional[Tuple[torch.Tensor, np.array]] = None,
|
405 |
+
fixed_refer_image: bool = True,
|
406 |
+
fixed_ip_adapter_image: bool = True,
|
407 |
+
redraw_condition_image_with_ipdapter: bool = True,
|
408 |
+
redraw_condition_image_with_referencenet: bool = True,
|
409 |
+
refer_face_image: Optional[Tuple[torch.Tensor, np.array]] = None,
|
410 |
+
fixed_refer_face_image: bool = True,
|
411 |
+
redraw_condition_image_with_facein: bool = True,
|
412 |
+
ip_adapter_scale: float = 1.0,
|
413 |
+
redraw_condition_image_with_ip_adapter_face: bool = True,
|
414 |
+
facein_scale: float = 1.0,
|
415 |
+
ip_adapter_face_scale: float = 1.0,
|
416 |
+
prompt_only_use_image_prompt: bool = False,
|
417 |
+
# serial_denoise parameter start
|
418 |
+
record_mid_video_noises: bool = False,
|
419 |
+
record_mid_video_latents: bool = False,
|
420 |
+
video_overlap: int = 1,
|
421 |
+
# serial_denoise parameter end
|
422 |
+
# parallel_denoise parameter start
|
423 |
+
context_schedule="uniform",
|
424 |
+
context_frames=12,
|
425 |
+
context_stride=1,
|
426 |
+
context_overlap=4,
|
427 |
+
context_batch_size=1,
|
428 |
+
interpolation_factor=1,
|
429 |
+
# parallel_denoise parameter end
|
430 |
+
):
|
431 |
+
"""
|
432 |
+
generate long video with end2end mode
|
433 |
+
1. prepare vision condition image by assingning, redraw, or generation with text2image module with skip_temporal_layer=True;
|
434 |
+
2. use image or latest of vision condition image to generate first shot;
|
435 |
+
3. use last n (1) image or last latent of last shot as new vision condition latent to generate next shot
|
436 |
+
4. repeat n_batch times between 2 and 3
|
437 |
+
|
438 |
+
类似img2img pipeline
|
439 |
+
refer_image和ip_adapter_image的来源:
|
440 |
+
1. 输入给定;
|
441 |
+
2. 当未输入时,纯text2video生成首帧,并赋值更新refer_image和ip_adapter_image;
|
442 |
+
3. 当有输入,但是因为redraw更新了首帧时,也需要赋值更新refer_image和ip_adapter_image;
|
443 |
+
|
444 |
+
refer_image和ip_adapter_image的作用:
|
445 |
+
1. 当无首帧图像时,用于生成首帧;
|
446 |
+
2. 用于生成视频。
|
447 |
+
|
448 |
+
|
449 |
+
similar to diffusers img2img pipeline.
|
450 |
+
three ways to prepare refer_image and ip_adapter_image
|
451 |
+
1. from input parameter
|
452 |
+
2. when input paramter is None, use text2video to generate vis cond image, and use as refer_image and ip_adapter_image too.
|
453 |
+
3. given from input paramter, but still redraw, update with redrawn vis cond image.
|
454 |
+
"""
|
455 |
+
# crop resize images
|
456 |
+
if condition_images is not None:
|
457 |
+
logger.debug(
|
458 |
+
f"center crop resize condition_images={condition_images.shape}, to height={height}, width={width}"
|
459 |
+
)
|
460 |
+
condition_images = batch_dynamic_crop_resize_images_v2(
|
461 |
+
condition_images,
|
462 |
+
target_height=height,
|
463 |
+
target_width=width,
|
464 |
+
)
|
465 |
+
if refer_image is not None:
|
466 |
+
logger.debug(
|
467 |
+
f"center crop resize refer_image to height={height}, width={width}"
|
468 |
+
)
|
469 |
+
refer_image = batch_dynamic_crop_resize_images_v2(
|
470 |
+
refer_image,
|
471 |
+
target_height=height,
|
472 |
+
target_width=width,
|
473 |
+
)
|
474 |
+
if ip_adapter_image is not None:
|
475 |
+
logger.debug(
|
476 |
+
f"center crop resize ip_adapter_image to height={height}, width={width}"
|
477 |
+
)
|
478 |
+
ip_adapter_image = batch_dynamic_crop_resize_images_v2(
|
479 |
+
ip_adapter_image,
|
480 |
+
target_height=height,
|
481 |
+
target_width=width,
|
482 |
+
)
|
483 |
+
if refer_face_image is not None:
|
484 |
+
logger.debug(
|
485 |
+
f"center crop resize refer_face_image to height={height}, width={width}"
|
486 |
+
)
|
487 |
+
refer_face_image = batch_dynamic_crop_resize_images_v2(
|
488 |
+
refer_face_image,
|
489 |
+
target_height=height,
|
490 |
+
target_width=width,
|
491 |
+
)
|
492 |
+
run_video_length = video_length
|
493 |
+
# generate vision condition frame start
|
494 |
+
# if condition_images is None, generate with refer_image, ip_adapter_image
|
495 |
+
# if condition_images not None and need redraw, according to redraw_condition_image_with_ipdapter, redraw_condition_image_with_referencenet, refer_image, ip_adapter_image
|
496 |
+
if n_vision_condition > 0:
|
497 |
+
if condition_images is None and condition_latents is None:
|
498 |
+
logger.debug("run_pipe_text2video, generate first_image")
|
499 |
+
(
|
500 |
+
condition_images,
|
501 |
+
condition_latents,
|
502 |
+
_,
|
503 |
+
_,
|
504 |
+
_,
|
505 |
+
) = self.pipeline(
|
506 |
+
prompt=prompt,
|
507 |
+
num_inference_steps=num_inference_steps,
|
508 |
+
guidance_scale=guidance_scale,
|
509 |
+
negative_prompt=negative_prompt,
|
510 |
+
video_length=1,
|
511 |
+
height=height,
|
512 |
+
width=width,
|
513 |
+
return_dict=False,
|
514 |
+
skip_temporal_layer=True,
|
515 |
+
output_type="np",
|
516 |
+
generator=generator,
|
517 |
+
w_ind_noise=w_ind_noise,
|
518 |
+
need_img_based_video_noise=need_img_based_video_noise,
|
519 |
+
refer_image=refer_image
|
520 |
+
if redraw_condition_image_with_referencenet
|
521 |
+
else None,
|
522 |
+
ip_adapter_image=ip_adapter_image
|
523 |
+
if redraw_condition_image_with_ipdapter
|
524 |
+
else None,
|
525 |
+
refer_face_image=refer_face_image
|
526 |
+
if redraw_condition_image_with_facein
|
527 |
+
else None,
|
528 |
+
ip_adapter_scale=ip_adapter_scale,
|
529 |
+
facein_scale=facein_scale,
|
530 |
+
ip_adapter_face_scale=ip_adapter_face_scale,
|
531 |
+
ip_adapter_face_image=refer_face_image
|
532 |
+
if redraw_condition_image_with_ip_adapter_face
|
533 |
+
else None,
|
534 |
+
prompt_only_use_image_prompt=prompt_only_use_image_prompt,
|
535 |
+
)
|
536 |
+
run_video_length = video_length - 1
|
537 |
+
elif (
|
538 |
+
condition_images is not None
|
539 |
+
and redraw_condition_image
|
540 |
+
and condition_latents is None
|
541 |
+
):
|
542 |
+
logger.debug("run_pipe_text2video, redraw first_image")
|
543 |
+
|
544 |
+
(
|
545 |
+
condition_images,
|
546 |
+
condition_latents,
|
547 |
+
_,
|
548 |
+
_,
|
549 |
+
_,
|
550 |
+
) = self.pipeline(
|
551 |
+
prompt=prompt,
|
552 |
+
image=condition_images,
|
553 |
+
num_inference_steps=num_inference_steps,
|
554 |
+
guidance_scale=guidance_scale,
|
555 |
+
negative_prompt=negative_prompt,
|
556 |
+
strength=strength,
|
557 |
+
video_length=condition_images.shape[2],
|
558 |
+
height=height,
|
559 |
+
width=width,
|
560 |
+
return_dict=False,
|
561 |
+
skip_temporal_layer=True,
|
562 |
+
output_type="np",
|
563 |
+
generator=generator,
|
564 |
+
w_ind_noise=w_ind_noise,
|
565 |
+
need_img_based_video_noise=need_img_based_video_noise,
|
566 |
+
refer_image=refer_image
|
567 |
+
if redraw_condition_image_with_referencenet
|
568 |
+
else None,
|
569 |
+
ip_adapter_image=ip_adapter_image
|
570 |
+
if redraw_condition_image_with_ipdapter
|
571 |
+
else None,
|
572 |
+
refer_face_image=refer_face_image
|
573 |
+
if redraw_condition_image_with_facein
|
574 |
+
else None,
|
575 |
+
ip_adapter_scale=ip_adapter_scale,
|
576 |
+
facein_scale=facein_scale,
|
577 |
+
ip_adapter_face_scale=ip_adapter_face_scale,
|
578 |
+
ip_adapter_face_image=refer_face_image
|
579 |
+
if redraw_condition_image_with_ip_adapter_face
|
580 |
+
else None,
|
581 |
+
prompt_only_use_image_prompt=prompt_only_use_image_prompt,
|
582 |
+
)
|
583 |
+
else:
|
584 |
+
condition_images = None
|
585 |
+
condition_latents = None
|
586 |
+
# generate vision condition frame end
|
587 |
+
|
588 |
+
# refer_image and ip_adapter_image, update mode from 2 and 3 as mentioned above start
|
589 |
+
if (
|
590 |
+
refer_image is not None
|
591 |
+
and redraw_condition_image
|
592 |
+
and condition_images is not None
|
593 |
+
):
|
594 |
+
refer_image = condition_images * 255.0
|
595 |
+
logger.debug(f"update refer_image because of redraw_condition_image")
|
596 |
+
elif (
|
597 |
+
refer_image is None
|
598 |
+
and self.pipeline.referencenet is not None
|
599 |
+
and condition_images is not None
|
600 |
+
):
|
601 |
+
refer_image = condition_images * 255.0
|
602 |
+
logger.debug(f"update refer_image because of generate first_image")
|
603 |
+
|
604 |
+
# ipadapter_image
|
605 |
+
if (
|
606 |
+
ip_adapter_image is not None
|
607 |
+
and redraw_condition_image
|
608 |
+
and condition_images is not None
|
609 |
+
):
|
610 |
+
ip_adapter_image = condition_images * 255.0
|
611 |
+
logger.debug(f"update ip_adapter_image because of redraw_condition_image")
|
612 |
+
elif (
|
613 |
+
ip_adapter_image is None
|
614 |
+
and self.pipeline.ip_adapter_image_proj is not None
|
615 |
+
and condition_images is not None
|
616 |
+
):
|
617 |
+
ip_adapter_image = condition_images * 255.0
|
618 |
+
logger.debug(f"update ip_adapter_image because of generate first_image")
|
619 |
+
# refer_image and ip_adapter_image, update mode from 2 and 3 as mentioned above end
|
620 |
+
|
621 |
+
# refer_face_image, update mode from 2 and 3 as mentioned above start
|
622 |
+
if (
|
623 |
+
refer_face_image is not None
|
624 |
+
and redraw_condition_image
|
625 |
+
and condition_images is not None
|
626 |
+
):
|
627 |
+
refer_face_image = condition_images * 255.0
|
628 |
+
logger.debug(f"update refer_face_image because of redraw_condition_image")
|
629 |
+
elif (
|
630 |
+
refer_face_image is None
|
631 |
+
and self.pipeline.facein_image_proj is not None
|
632 |
+
and condition_images is not None
|
633 |
+
):
|
634 |
+
refer_face_image = condition_images * 255.0
|
635 |
+
logger.debug(f"update face_image because of generate first_image")
|
636 |
+
# refer_face_image, update mode from 2 and 3 as mentioned above end
|
637 |
+
|
638 |
+
last_mid_video_noises = None
|
639 |
+
last_mid_video_latents = None
|
640 |
+
initial_common_latent = None
|
641 |
+
|
642 |
+
out_videos = []
|
643 |
+
for i_batch in range(max_batch_num):
|
644 |
+
logger.debug(f"sd_pipeline_predictor, run_pipe_text2video: {i_batch}")
|
645 |
+
if max_batch_num is not None and i_batch == max_batch_num:
|
646 |
+
break
|
647 |
+
|
648 |
+
if i_batch == 0:
|
649 |
+
result_overlap = 0
|
650 |
+
else:
|
651 |
+
if n_vision_condition > 0:
|
652 |
+
# ignore condition_images if condition_latents is not None in pipeline
|
653 |
+
if not fix_condition_images:
|
654 |
+
logger.debug(f"{i_batch}, update condition_latents")
|
655 |
+
condition_latents = out_latents_batch[
|
656 |
+
:, :, -n_vision_condition:, :, :
|
657 |
+
]
|
658 |
+
else:
|
659 |
+
logger.debug(f"{i_batch}, do not update condition_latents")
|
660 |
+
result_overlap = n_vision_condition
|
661 |
+
|
662 |
+
if not fixed_refer_image and n_vision_condition > 0:
|
663 |
+
logger.debug("ref_image use last frame of last generated out video")
|
664 |
+
refer_image = out_batch[:, :, -n_vision_condition:, :, :] * 255.0
|
665 |
+
else:
|
666 |
+
logger.debug("use given fixed ref_image")
|
667 |
+
|
668 |
+
if not fixed_ip_adapter_image and n_vision_condition > 0:
|
669 |
+
logger.debug(
|
670 |
+
"ip_adapter_image use last frame of last generated out video"
|
671 |
+
)
|
672 |
+
ip_adapter_image = (
|
673 |
+
out_batch[:, :, -n_vision_condition:, :, :] * 255.0
|
674 |
+
)
|
675 |
+
else:
|
676 |
+
logger.debug("use given fixed ip_adapter_image")
|
677 |
+
|
678 |
+
if not fixed_refer_face_image and n_vision_condition > 0:
|
679 |
+
logger.debug(
|
680 |
+
"refer_face_image use last frame of last generated out video"
|
681 |
+
)
|
682 |
+
refer_face_image = (
|
683 |
+
out_batch[:, :, -n_vision_condition:, :, :] * 255.0
|
684 |
+
)
|
685 |
+
else:
|
686 |
+
logger.debug("use given fixed ip_adapter_image")
|
687 |
+
|
688 |
+
run_video_length = video_length
|
689 |
+
if same_seed is not None:
|
690 |
+
_, generator = set_all_seed(same_seed)
|
691 |
+
|
692 |
+
out = self.pipeline(
|
693 |
+
video_length=run_video_length, # int
|
694 |
+
prompt=prompt,
|
695 |
+
num_inference_steps=video_num_inference_steps,
|
696 |
+
height=height,
|
697 |
+
width=width,
|
698 |
+
generator=generator,
|
699 |
+
condition_images=condition_images,
|
700 |
+
condition_latents=condition_latents, # b co t(1) ho wo
|
701 |
+
skip_temporal_layer=False,
|
702 |
+
output_type="np",
|
703 |
+
noise_type=noise_type,
|
704 |
+
negative_prompt=video_negative_prompt,
|
705 |
+
guidance_scale=video_guidance_scale,
|
706 |
+
guidance_scale_end=video_guidance_scale_end,
|
707 |
+
guidance_scale_method=video_guidance_scale_method,
|
708 |
+
w_ind_noise=w_ind_noise,
|
709 |
+
need_img_based_video_noise=need_img_based_video_noise,
|
710 |
+
img_weight=img_weight,
|
711 |
+
motion_speed=motion_speed,
|
712 |
+
vision_condition_latent_index=vision_condition_latent_index,
|
713 |
+
refer_image=refer_image,
|
714 |
+
ip_adapter_image=ip_adapter_image,
|
715 |
+
refer_face_image=refer_face_image,
|
716 |
+
ip_adapter_scale=ip_adapter_scale,
|
717 |
+
facein_scale=facein_scale,
|
718 |
+
ip_adapter_face_scale=ip_adapter_face_scale,
|
719 |
+
ip_adapter_face_image=refer_face_image,
|
720 |
+
prompt_only_use_image_prompt=prompt_only_use_image_prompt,
|
721 |
+
initial_common_latent=initial_common_latent,
|
722 |
+
# serial_denoise parameter start
|
723 |
+
record_mid_video_noises=record_mid_video_noises,
|
724 |
+
last_mid_video_noises=last_mid_video_noises,
|
725 |
+
record_mid_video_latents=record_mid_video_latents,
|
726 |
+
last_mid_video_latents=last_mid_video_latents,
|
727 |
+
video_overlap=video_overlap,
|
728 |
+
# serial_denoise parameter end
|
729 |
+
# parallel_denoise parameter start
|
730 |
+
context_schedule=context_schedule,
|
731 |
+
context_frames=context_frames,
|
732 |
+
context_stride=context_stride,
|
733 |
+
context_overlap=context_overlap,
|
734 |
+
context_batch_size=context_batch_size,
|
735 |
+
interpolation_factor=interpolation_factor,
|
736 |
+
# parallel_denoise parameter end
|
737 |
+
)
|
738 |
+
logger.debug(
|
739 |
+
f"run_pipe_text2video, out.videos.shape, i_batch={i_batch}, videos={out.videos.shape}, result_overlap={result_overlap}"
|
740 |
+
)
|
741 |
+
out_batch = out.videos[:, :, result_overlap:, :, :]
|
742 |
+
out_latents_batch = out.latents[:, :, result_overlap:, :, :]
|
743 |
+
out_videos.append(out_batch)
|
744 |
+
|
745 |
+
out_videos = np.concatenate(out_videos, axis=2)
|
746 |
+
if need_hist_match:
|
747 |
+
out_videos[:, :, 1:, :, :] = hist_match_video_bcthw(
|
748 |
+
out_videos[:, :, 1:, :, :], out_videos[:, :, :1, :, :], value=255.0
|
749 |
+
)
|
750 |
+
return out_videos
|
751 |
+
|
752 |
+
def run_pipe_with_latent_input(
|
753 |
+
self,
|
754 |
+
):
|
755 |
+
pass
|
756 |
+
|
757 |
+
def run_pipe_middle2video_with_middle(self, middle: Tuple[str, Iterable]):
|
758 |
+
pass
|
759 |
+
|
760 |
+
def run_pipe_video2video(
|
761 |
+
self,
|
762 |
+
video: Tuple[str, Iterable],
|
763 |
+
time_size: int = None,
|
764 |
+
sample_rate: int = None,
|
765 |
+
overlap: int = None,
|
766 |
+
step: int = None,
|
767 |
+
prompt: Union[str, List[str]] = None,
|
768 |
+
# b c t h w
|
769 |
+
height: Optional[int] = None,
|
770 |
+
width: Optional[int] = None,
|
771 |
+
num_inference_steps: int = 50,
|
772 |
+
video_num_inference_steps: int = 50,
|
773 |
+
guidance_scale: float = 7.5,
|
774 |
+
video_guidance_scale: float = 7.5,
|
775 |
+
video_guidance_scale_end: float = 3.5,
|
776 |
+
video_guidance_scale_method: str = "linear",
|
777 |
+
video_negative_prompt: Optional[Union[str, List[str]]] = None,
|
778 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
779 |
+
num_videos_per_prompt: Optional[int] = 1,
|
780 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
781 |
+
eta: float = 0.0,
|
782 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
783 |
+
controlnet_latents: Union[torch.FloatTensor, np.ndarray] = None,
|
784 |
+
# b c t(1) hi wi
|
785 |
+
controlnet_condition_images: Optional[torch.FloatTensor] = None,
|
786 |
+
# b c t(1) ho wo
|
787 |
+
controlnet_condition_latents: Optional[torch.FloatTensor] = None,
|
788 |
+
# b c t(1) ho wo
|
789 |
+
condition_latents: Optional[torch.FloatTensor] = None,
|
790 |
+
condition_images: Optional[torch.FloatTensor] = None,
|
791 |
+
fix_condition_images: bool = False,
|
792 |
+
latents: Optional[torch.FloatTensor] = None,
|
793 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
794 |
+
output_type: Optional[str] = "tensor",
|
795 |
+
return_dict: bool = True,
|
796 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
797 |
+
callback_steps: int = 1,
|
798 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
799 |
+
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
800 |
+
guess_mode: bool = False,
|
801 |
+
control_guidance_start: Union[float, List[float]] = 0.0,
|
802 |
+
control_guidance_end: Union[float, List[float]] = 1.0,
|
803 |
+
need_middle_latents: bool = False,
|
804 |
+
w_ind_noise: float = 0.5,
|
805 |
+
img_weight: float = 0.001,
|
806 |
+
initial_common_latent: Optional[torch.FloatTensor] = None,
|
807 |
+
latent_index: torch.LongTensor = None,
|
808 |
+
vision_condition_latent_index: torch.LongTensor = None,
|
809 |
+
noise_type: str = "random",
|
810 |
+
controlnet_processor_params: Dict = None,
|
811 |
+
need_return_videos: bool = False,
|
812 |
+
need_return_condition: bool = False,
|
813 |
+
max_batch_num: int = 30,
|
814 |
+
strength: float = 0.8,
|
815 |
+
video_strength: float = 0.8,
|
816 |
+
need_video2video: bool = False,
|
817 |
+
need_img_based_video_noise: bool = False,
|
818 |
+
need_hist_match: bool = False,
|
819 |
+
end_to_end: bool = True,
|
820 |
+
refer_image: Optional[
|
821 |
+
Tuple[np.ndarray, torch.Tensor, List[str], List[np.ndarray]]
|
822 |
+
] = None,
|
823 |
+
ip_adapter_image: Optional[Tuple[torch.Tensor, np.array]] = None,
|
824 |
+
fixed_refer_image: bool = True,
|
825 |
+
fixed_ip_adapter_image: bool = True,
|
826 |
+
redraw_condition_image: bool = False,
|
827 |
+
redraw_condition_image_with_ipdapter: bool = True,
|
828 |
+
redraw_condition_image_with_referencenet: bool = True,
|
829 |
+
refer_face_image: Optional[Tuple[torch.Tensor, np.array]] = None,
|
830 |
+
fixed_refer_face_image: bool = True,
|
831 |
+
redraw_condition_image_with_facein: bool = True,
|
832 |
+
ip_adapter_scale: float = 1.0,
|
833 |
+
facein_scale: float = 1.0,
|
834 |
+
ip_adapter_face_scale: float = 1.0,
|
835 |
+
redraw_condition_image_with_ip_adapter_face: bool = False,
|
836 |
+
n_vision_condition: int = 1,
|
837 |
+
prompt_only_use_image_prompt: bool = False,
|
838 |
+
motion_speed: float = 8.0,
|
839 |
+
# serial_denoise parameter start
|
840 |
+
record_mid_video_noises: bool = False,
|
841 |
+
record_mid_video_latents: bool = False,
|
842 |
+
video_overlap: int = 1,
|
843 |
+
# serial_denoise parameter end
|
844 |
+
# parallel_denoise parameter start
|
845 |
+
context_schedule="uniform",
|
846 |
+
context_frames=12,
|
847 |
+
context_stride=1,
|
848 |
+
context_overlap=4,
|
849 |
+
context_batch_size=1,
|
850 |
+
interpolation_factor=1,
|
851 |
+
# parallel_denoise parameter end
|
852 |
+
# 支持 video_path 时多种输入
|
853 |
+
# TODO:// video_has_condition =False,当且仅支持 video_is_middle=True, 待后续重构
|
854 |
+
# TODO:// when video_has_condition =False, video_is_middle should be True.
|
855 |
+
video_is_middle: bool = False,
|
856 |
+
video_has_condition: bool = True,
|
857 |
+
):
|
858 |
+
"""
|
859 |
+
类似controlnet text2img pipeline。 输入视频,用视频得到controlnet condition。
|
860 |
+
目前仅支持time_size == step,overlap=0
|
861 |
+
输出视频长度=输入视频长度
|
862 |
+
|
863 |
+
similar to controlnet text2image pipeline, generate video with controlnet condition from given video.
|
864 |
+
By now, sliding window only support time_size == step, overlap = 0.
|
865 |
+
"""
|
866 |
+
if isinstance(video, str):
|
867 |
+
video_reader = DecordVideoDataset(
|
868 |
+
video,
|
869 |
+
time_size=time_size,
|
870 |
+
step=step,
|
871 |
+
overlap=overlap,
|
872 |
+
sample_rate=sample_rate,
|
873 |
+
device="cpu",
|
874 |
+
data_type="rgb",
|
875 |
+
channels_order="c t h w",
|
876 |
+
drop_last=True,
|
877 |
+
)
|
878 |
+
else:
|
879 |
+
video_reader = video
|
880 |
+
videos = [] if need_return_videos else None
|
881 |
+
out_videos = []
|
882 |
+
out_condition = (
|
883 |
+
[]
|
884 |
+
if need_return_condition and self.pipeline.controlnet is not None
|
885 |
+
else None
|
886 |
+
)
|
887 |
+
# crop resize images
|
888 |
+
if condition_images is not None:
|
889 |
+
logger.debug(
|
890 |
+
f"center crop resize condition_images={condition_images.shape}, to height={height}, width={width}"
|
891 |
+
)
|
892 |
+
condition_images = batch_dynamic_crop_resize_images_v2(
|
893 |
+
condition_images,
|
894 |
+
target_height=height,
|
895 |
+
target_width=width,
|
896 |
+
)
|
897 |
+
if refer_image is not None:
|
898 |
+
logger.debug(
|
899 |
+
f"center crop resize refer_image to height={height}, width={width}"
|
900 |
+
)
|
901 |
+
refer_image = batch_dynamic_crop_resize_images_v2(
|
902 |
+
refer_image,
|
903 |
+
target_height=height,
|
904 |
+
target_width=width,
|
905 |
+
)
|
906 |
+
if ip_adapter_image is not None:
|
907 |
+
logger.debug(
|
908 |
+
f"center crop resize ip_adapter_image to height={height}, width={width}"
|
909 |
+
)
|
910 |
+
ip_adapter_image = batch_dynamic_crop_resize_images_v2(
|
911 |
+
ip_adapter_image,
|
912 |
+
target_height=height,
|
913 |
+
target_width=width,
|
914 |
+
)
|
915 |
+
if refer_face_image is not None:
|
916 |
+
logger.debug(
|
917 |
+
f"center crop resize refer_face_image to height={height}, width={width}"
|
918 |
+
)
|
919 |
+
refer_face_image = batch_dynamic_crop_resize_images_v2(
|
920 |
+
refer_face_image,
|
921 |
+
target_height=height,
|
922 |
+
target_width=width,
|
923 |
+
)
|
924 |
+
first_image = None
|
925 |
+
last_mid_video_noises = None
|
926 |
+
last_mid_video_latents = None
|
927 |
+
initial_common_latent = None
|
928 |
+
# initial_common_latent = torch.randn((1, 4, 1, 112, 64)).to(
|
929 |
+
# device=self.device, dtype=self.dtype
|
930 |
+
# )
|
931 |
+
|
932 |
+
for i_batch, item in enumerate(video_reader):
|
933 |
+
logger.debug(f"\n sd_pipeline_predictor, run_pipe_video2video: {i_batch}")
|
934 |
+
if max_batch_num is not None and i_batch == max_batch_num:
|
935 |
+
break
|
936 |
+
# read and prepare video batch
|
937 |
+
batch = item.data
|
938 |
+
batch = batch_dynamic_crop_resize_images(
|
939 |
+
batch,
|
940 |
+
target_height=height,
|
941 |
+
target_width=width,
|
942 |
+
)
|
943 |
+
|
944 |
+
batch = batch[np.newaxis, ...]
|
945 |
+
batch_size, channel, video_length, video_height, video_width = batch.shape
|
946 |
+
# extract controlnet middle
|
947 |
+
if self.pipeline.controlnet is not None:
|
948 |
+
batch = rearrange(batch, "b c t h w-> (b t) h w c")
|
949 |
+
controlnet_processor_params = update_controlnet_processor_params(
|
950 |
+
src=self.controlnet_processor_params,
|
951 |
+
dst=controlnet_processor_params,
|
952 |
+
)
|
953 |
+
if not video_is_middle:
|
954 |
+
batch_condition = self.controlnet_processor(
|
955 |
+
data=batch,
|
956 |
+
data_channel_order="b h w c",
|
957 |
+
target_height=height,
|
958 |
+
target_width=width,
|
959 |
+
return_type="np",
|
960 |
+
return_data_channel_order="b c h w",
|
961 |
+
input_rgb_order="rgb",
|
962 |
+
processor_params=controlnet_processor_params,
|
963 |
+
)
|
964 |
+
else:
|
965 |
+
# TODO: 临时用于可视化输入的 controlnet middle 序列,后续待拆到 middl2video中,也可以增加参数支持
|
966 |
+
# TODO: only use video_path is controlnet middle output, to improved
|
967 |
+
batch_condition = rearrange(
|
968 |
+
copy.deepcopy(batch), " b h w c-> b c h w"
|
969 |
+
)
|
970 |
+
|
971 |
+
# 当前仅当 输入是 middle、condition_image的pose在middle首帧之前,需要重新生成condition_images的pose并绑定到middle_batch上
|
972 |
+
# when video_path is middle seq and condition_image is not aligned with middle seq,
|
973 |
+
# regenerate codntion_images pose, and then concat into middle_batch,
|
974 |
+
if (
|
975 |
+
i_batch == 0
|
976 |
+
and not video_has_condition
|
977 |
+
and video_is_middle
|
978 |
+
and condition_images is not None
|
979 |
+
):
|
980 |
+
condition_images_reshape = rearrange(
|
981 |
+
condition_images, "b c t h w-> (b t) h w c"
|
982 |
+
)
|
983 |
+
condition_images_condition = self.controlnet_processor(
|
984 |
+
data=condition_images_reshape,
|
985 |
+
data_channel_order="b h w c",
|
986 |
+
target_height=height,
|
987 |
+
target_width=width,
|
988 |
+
return_type="np",
|
989 |
+
return_data_channel_order="b c h w",
|
990 |
+
input_rgb_order="rgb",
|
991 |
+
processor_params=controlnet_processor_params,
|
992 |
+
)
|
993 |
+
condition_images_condition = rearrange(
|
994 |
+
condition_images_condition,
|
995 |
+
"(b t) c h w-> b c t h w",
|
996 |
+
b=batch_size,
|
997 |
+
)
|
998 |
+
else:
|
999 |
+
condition_images_condition = None
|
1000 |
+
if not isinstance(batch_condition, list):
|
1001 |
+
batch_condition = rearrange(
|
1002 |
+
batch_condition, "(b t) c h w-> b c t h w", b=batch_size
|
1003 |
+
)
|
1004 |
+
if condition_images_condition is not None:
|
1005 |
+
batch_condition = np.concatenate(
|
1006 |
+
[
|
1007 |
+
condition_images_condition,
|
1008 |
+
batch_condition,
|
1009 |
+
],
|
1010 |
+
axis=2,
|
1011 |
+
)
|
1012 |
+
# 此时 batch_condition 比 batch 多了一帧,为了最终视频能 concat 存储,替换下
|
1013 |
+
# 当前仅适用于 condition_images_condition 不为None
|
1014 |
+
# when condition_images_condition is not None, batch_condition has more frames than batch
|
1015 |
+
batch = rearrange(batch_condition, "b c t h w ->(b t) h w c")
|
1016 |
+
else:
|
1017 |
+
batch_condition = [
|
1018 |
+
rearrange(x, "(b t) c h w-> b c t h w", b=batch_size)
|
1019 |
+
for x in batch_condition
|
1020 |
+
]
|
1021 |
+
if condition_images_condition is not None:
|
1022 |
+
batch_condition = [
|
1023 |
+
np.concatenate(
|
1024 |
+
[condition_images_condition, batch_condition_tmp],
|
1025 |
+
axis=2,
|
1026 |
+
)
|
1027 |
+
for batch_condition_tmp in batch_condition
|
1028 |
+
]
|
1029 |
+
batch = rearrange(batch, "(b t) h w c -> b c t h w", b=batch_size)
|
1030 |
+
else:
|
1031 |
+
batch_condition = None
|
1032 |
+
# condition [0,255]
|
1033 |
+
# latent: [0,1]
|
1034 |
+
# 按需求生成多个片段,
|
1035 |
+
# generate multi video_shot
|
1036 |
+
# 第一个片段 会特殊处理,需要生成首帧
|
1037 |
+
# first shot is special because of first frame.
|
1038 |
+
# 后续片段根据拿前一个片段结果,首尾相连的方式生成。
|
1039 |
+
# use last frame of last shot as the first frame of the current shot
|
1040 |
+
# TODO: 当前独立拆开实现,待后续融合到一起实现
|
1041 |
+
# TODO: to optimize implementation way
|
1042 |
+
if n_vision_condition == 0:
|
1043 |
+
actual_video_length = video_length
|
1044 |
+
control_image = batch_condition
|
1045 |
+
first_image_controlnet_condition = None
|
1046 |
+
first_image_latents = None
|
1047 |
+
if need_video2video:
|
1048 |
+
video = batch
|
1049 |
+
else:
|
1050 |
+
video = None
|
1051 |
+
result_overlap = 0
|
1052 |
+
else:
|
1053 |
+
if i_batch == 0:
|
1054 |
+
if self.pipeline.controlnet is not None:
|
1055 |
+
if not isinstance(batch_condition, list):
|
1056 |
+
first_image_controlnet_condition = batch_condition[
|
1057 |
+
:, :, :1, :, :
|
1058 |
+
]
|
1059 |
+
else:
|
1060 |
+
first_image_controlnet_condition = [
|
1061 |
+
x[:, :, :1, :, :] for x in batch_condition
|
1062 |
+
]
|
1063 |
+
else:
|
1064 |
+
first_image_controlnet_condition = None
|
1065 |
+
if need_video2video:
|
1066 |
+
if condition_images is None:
|
1067 |
+
video = batch[:, :, :1, :, :]
|
1068 |
+
else:
|
1069 |
+
video = condition_images
|
1070 |
+
else:
|
1071 |
+
video = None
|
1072 |
+
if condition_images is not None and not redraw_condition_image:
|
1073 |
+
first_image = condition_images
|
1074 |
+
first_image_latents = None
|
1075 |
+
else:
|
1076 |
+
(
|
1077 |
+
first_image,
|
1078 |
+
first_image_latents,
|
1079 |
+
_,
|
1080 |
+
_,
|
1081 |
+
_,
|
1082 |
+
) = self.pipeline(
|
1083 |
+
prompt=prompt,
|
1084 |
+
image=video,
|
1085 |
+
control_image=first_image_controlnet_condition,
|
1086 |
+
num_inference_steps=num_inference_steps,
|
1087 |
+
video_length=1,
|
1088 |
+
height=height,
|
1089 |
+
width=width,
|
1090 |
+
return_dict=False,
|
1091 |
+
skip_temporal_layer=True,
|
1092 |
+
output_type="np",
|
1093 |
+
generator=generator,
|
1094 |
+
negative_prompt=negative_prompt,
|
1095 |
+
controlnet_conditioning_scale=controlnet_conditioning_scale,
|
1096 |
+
control_guidance_start=control_guidance_start,
|
1097 |
+
control_guidance_end=control_guidance_end,
|
1098 |
+
w_ind_noise=w_ind_noise,
|
1099 |
+
strength=strength,
|
1100 |
+
refer_image=refer_image
|
1101 |
+
if redraw_condition_image_with_referencenet
|
1102 |
+
else None,
|
1103 |
+
ip_adapter_image=ip_adapter_image
|
1104 |
+
if redraw_condition_image_with_ipdapter
|
1105 |
+
else None,
|
1106 |
+
refer_face_image=refer_face_image
|
1107 |
+
if redraw_condition_image_with_facein
|
1108 |
+
else None,
|
1109 |
+
ip_adapter_scale=ip_adapter_scale,
|
1110 |
+
facein_scale=facein_scale,
|
1111 |
+
ip_adapter_face_scale=ip_adapter_face_scale,
|
1112 |
+
ip_adapter_face_image=refer_face_image
|
1113 |
+
if redraw_condition_image_with_ip_adapter_face
|
1114 |
+
else None,
|
1115 |
+
prompt_only_use_image_prompt=prompt_only_use_image_prompt,
|
1116 |
+
)
|
1117 |
+
if refer_image is not None:
|
1118 |
+
refer_image = first_image * 255.0
|
1119 |
+
if ip_adapter_image is not None:
|
1120 |
+
ip_adapter_image = first_image * 255.0
|
1121 |
+
# 首帧用于后续推断可以直接用first_image_latent不需要 first_image了
|
1122 |
+
first_image = None
|
1123 |
+
if self.pipeline.controlnet is not None:
|
1124 |
+
if not isinstance(batch_condition, list):
|
1125 |
+
control_image = batch_condition[:, :, 1:, :, :]
|
1126 |
+
logger.debug(f"control_image={control_image.shape}")
|
1127 |
+
else:
|
1128 |
+
control_image = [x[:, :, 1:, :, :] for x in batch_condition]
|
1129 |
+
else:
|
1130 |
+
control_image = None
|
1131 |
+
|
1132 |
+
actual_video_length = time_size - int(video_has_condition)
|
1133 |
+
if need_video2video:
|
1134 |
+
video = batch[:, :, 1:, :, :]
|
1135 |
+
else:
|
1136 |
+
video = None
|
1137 |
+
|
1138 |
+
result_overlap = 0
|
1139 |
+
else:
|
1140 |
+
actual_video_length = time_size
|
1141 |
+
if self.pipeline.controlnet is not None:
|
1142 |
+
if not fix_condition_images:
|
1143 |
+
logger.debug(
|
1144 |
+
f"{i_batch}, update first_image_controlnet_condition"
|
1145 |
+
)
|
1146 |
+
|
1147 |
+
if not isinstance(last_batch_condition, list):
|
1148 |
+
first_image_controlnet_condition = last_batch_condition[
|
1149 |
+
:, :, -1:, :, :
|
1150 |
+
]
|
1151 |
+
else:
|
1152 |
+
first_image_controlnet_condition = [
|
1153 |
+
x[:, :, -1:, :, :] for x in last_batch_condition
|
1154 |
+
]
|
1155 |
+
else:
|
1156 |
+
logger.debug(
|
1157 |
+
f"{i_batch}, do not update first_image_controlnet_condition"
|
1158 |
+
)
|
1159 |
+
control_image = batch_condition
|
1160 |
+
else:
|
1161 |
+
control_image = None
|
1162 |
+
first_image_controlnet_condition = None
|
1163 |
+
if not fix_condition_images:
|
1164 |
+
logger.debug(f"{i_batch}, update condition_images")
|
1165 |
+
first_image_latents = out_latents_batch[:, :, -1:, :, :]
|
1166 |
+
else:
|
1167 |
+
logger.debug(f"{i_batch}, do not update condition_images")
|
1168 |
+
|
1169 |
+
if need_video2video:
|
1170 |
+
video = batch
|
1171 |
+
else:
|
1172 |
+
video = None
|
1173 |
+
result_overlap = 1
|
1174 |
+
|
1175 |
+
# 更新 ref_image和 ipadapter_image
|
1176 |
+
if not fixed_refer_image:
|
1177 |
+
logger.debug(
|
1178 |
+
"ref_image use last frame of last generated out video"
|
1179 |
+
)
|
1180 |
+
refer_image = (
|
1181 |
+
out_batch[:, :, -n_vision_condition:, :, :] * 255.0
|
1182 |
+
)
|
1183 |
+
else:
|
1184 |
+
logger.debug("use given fixed ref_image")
|
1185 |
+
|
1186 |
+
if not fixed_ip_adapter_image:
|
1187 |
+
logger.debug(
|
1188 |
+
"ip_adapter_image use last frame of last generated out video"
|
1189 |
+
)
|
1190 |
+
ip_adapter_image = (
|
1191 |
+
out_batch[:, :, -n_vision_condition:, :, :] * 255.0
|
1192 |
+
)
|
1193 |
+
else:
|
1194 |
+
logger.debug("use given fixed ip_adapter_image")
|
1195 |
+
|
1196 |
+
# face image
|
1197 |
+
if not fixed_ip_adapter_image:
|
1198 |
+
logger.debug(
|
1199 |
+
"refer_face_image use last frame of last generated out video"
|
1200 |
+
)
|
1201 |
+
refer_face_image = (
|
1202 |
+
out_batch[:, :, -n_vision_condition:, :, :] * 255.0
|
1203 |
+
)
|
1204 |
+
else:
|
1205 |
+
logger.debug("use given fixed ip_adapter_image")
|
1206 |
+
|
1207 |
+
out = self.pipeline(
|
1208 |
+
video_length=actual_video_length, # int
|
1209 |
+
prompt=prompt,
|
1210 |
+
num_inference_steps=video_num_inference_steps,
|
1211 |
+
height=height,
|
1212 |
+
width=width,
|
1213 |
+
generator=generator,
|
1214 |
+
image=video,
|
1215 |
+
control_image=control_image, # b ci(3) t hi wi
|
1216 |
+
controlnet_condition_images=first_image_controlnet_condition, # b ci(3) t(1) hi wi
|
1217 |
+
# controlnet_condition_images=np.zeros_like(
|
1218 |
+
# first_image_controlnet_condition
|
1219 |
+
# ), # b ci(3) t(1) hi wi
|
1220 |
+
condition_images=first_image,
|
1221 |
+
condition_latents=first_image_latents, # b co t(1) ho wo
|
1222 |
+
skip_temporal_layer=False,
|
1223 |
+
output_type="np",
|
1224 |
+
noise_type=noise_type,
|
1225 |
+
negative_prompt=video_negative_prompt,
|
1226 |
+
need_img_based_video_noise=need_img_based_video_noise,
|
1227 |
+
controlnet_conditioning_scale=controlnet_conditioning_scale,
|
1228 |
+
control_guidance_start=control_guidance_start,
|
1229 |
+
control_guidance_end=control_guidance_end,
|
1230 |
+
w_ind_noise=w_ind_noise,
|
1231 |
+
img_weight=img_weight,
|
1232 |
+
motion_speed=video_reader.sample_rate,
|
1233 |
+
guidance_scale=video_guidance_scale,
|
1234 |
+
guidance_scale_end=video_guidance_scale_end,
|
1235 |
+
guidance_scale_method=video_guidance_scale_method,
|
1236 |
+
strength=video_strength,
|
1237 |
+
refer_image=refer_image,
|
1238 |
+
ip_adapter_image=ip_adapter_image,
|
1239 |
+
refer_face_image=refer_face_image,
|
1240 |
+
ip_adapter_scale=ip_adapter_scale,
|
1241 |
+
facein_scale=facein_scale,
|
1242 |
+
ip_adapter_face_scale=ip_adapter_face_scale,
|
1243 |
+
ip_adapter_face_image=refer_face_image,
|
1244 |
+
prompt_only_use_image_prompt=prompt_only_use_image_prompt,
|
1245 |
+
initial_common_latent=initial_common_latent,
|
1246 |
+
# serial_denoise parameter start
|
1247 |
+
record_mid_video_noises=record_mid_video_noises,
|
1248 |
+
last_mid_video_noises=last_mid_video_noises,
|
1249 |
+
record_mid_video_latents=record_mid_video_latents,
|
1250 |
+
last_mid_video_latents=last_mid_video_latents,
|
1251 |
+
video_overlap=video_overlap,
|
1252 |
+
# serial_denoise parameter end
|
1253 |
+
# parallel_denoise parameter start
|
1254 |
+
context_schedule=context_schedule,
|
1255 |
+
context_frames=context_frames,
|
1256 |
+
context_stride=context_stride,
|
1257 |
+
context_overlap=context_overlap,
|
1258 |
+
context_batch_size=context_batch_size,
|
1259 |
+
interpolation_factor=interpolation_factor,
|
1260 |
+
# parallel_denoise parameter end
|
1261 |
+
)
|
1262 |
+
last_batch = batch
|
1263 |
+
last_batch_condition = batch_condition
|
1264 |
+
last_mid_video_latents = out.mid_video_latents
|
1265 |
+
last_mid_video_noises = out.mid_video_noises
|
1266 |
+
out_batch = out.videos[:, :, result_overlap:, :, :]
|
1267 |
+
out_latents_batch = out.latents[:, :, result_overlap:, :, :]
|
1268 |
+
out_videos.append(out_batch)
|
1269 |
+
if need_return_videos:
|
1270 |
+
videos.append(batch)
|
1271 |
+
if out_condition is not None:
|
1272 |
+
out_condition.append(batch_condition)
|
1273 |
+
|
1274 |
+
out_videos = np.concatenate(out_videos, axis=2)
|
1275 |
+
if need_return_videos:
|
1276 |
+
videos = np.concatenate(videos, axis=2)
|
1277 |
+
if out_condition is not None:
|
1278 |
+
if not isinstance(out_condition[0], list):
|
1279 |
+
out_condition = np.concatenate(out_condition, axis=2)
|
1280 |
+
else:
|
1281 |
+
out_condition = [
|
1282 |
+
[out_condition[j][i] for j in range(len(out_condition))]
|
1283 |
+
for i in range(len(out_condition[0]))
|
1284 |
+
]
|
1285 |
+
out_condition = [np.concatenate(x, axis=2) for x in out_condition]
|
1286 |
+
if need_hist_match:
|
1287 |
+
videos[:, :, 1:, :, :] = hist_match_video_bcthw(
|
1288 |
+
videos[:, :, 1:, :, :], videos[:, :, :1, :, :], value=255.0
|
1289 |
+
)
|
1290 |
+
return out_videos, out_condition, videos
|
musev/schedulers/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
|
2 |
+
from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
|
3 |
+
from .scheduling_euler_discrete import EulerDiscreteScheduler
|
4 |
+
from .scheduling_lcm import LCMScheduler
|
5 |
+
from .scheduling_ddim import DDIMScheduler
|
6 |
+
from .scheduling_ddpm import DDPMScheduler
|
musev/schedulers/scheduling_ddim.py
ADDED
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
|
16 |
+
# and https://github.com/hojonathanho/diffusion
|
17 |
+
|
18 |
+
from __future__ import annotations
|
19 |
+
|
20 |
+
import math
|
21 |
+
from dataclasses import dataclass
|
22 |
+
from typing import List, Optional, Tuple, Union
|
23 |
+
|
24 |
+
import numpy as np
|
25 |
+
from numpy import ndarray
|
26 |
+
import torch
|
27 |
+
|
28 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
29 |
+
from diffusers.utils import BaseOutput
|
30 |
+
from diffusers.utils.torch_utils import randn_tensor
|
31 |
+
from diffusers.schedulers.scheduling_utils import (
|
32 |
+
KarrasDiffusionSchedulers,
|
33 |
+
SchedulerMixin,
|
34 |
+
)
|
35 |
+
from diffusers.schedulers.scheduling_ddim import (
|
36 |
+
DDIMSchedulerOutput,
|
37 |
+
rescale_zero_terminal_snr,
|
38 |
+
betas_for_alpha_bar,
|
39 |
+
DDIMScheduler as DiffusersDDIMScheduler,
|
40 |
+
)
|
41 |
+
from ..utils.noise_util import video_fusion_noise
|
42 |
+
|
43 |
+
|
44 |
+
class DDIMScheduler(DiffusersDDIMScheduler):
|
45 |
+
"""
|
46 |
+
`DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
|
47 |
+
non-Markovian guidance.
|
48 |
+
|
49 |
+
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
50 |
+
methods the library implements for all schedulers such as loading and saving.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
num_train_timesteps (`int`, defaults to 1000):
|
54 |
+
The number of diffusion steps to train the model.
|
55 |
+
beta_start (`float`, defaults to 0.0001):
|
56 |
+
The starting `beta` value of inference.
|
57 |
+
beta_end (`float`, defaults to 0.02):
|
58 |
+
The final `beta` value.
|
59 |
+
beta_schedule (`str`, defaults to `"linear"`):
|
60 |
+
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
61 |
+
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
|
62 |
+
trained_betas (`np.ndarray`, *optional*):
|
63 |
+
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
|
64 |
+
clip_sample (`bool`, defaults to `True`):
|
65 |
+
Clip the predicted sample for numerical stability.
|
66 |
+
clip_sample_range (`float`, defaults to 1.0):
|
67 |
+
The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
|
68 |
+
set_alpha_to_one (`bool`, defaults to `True`):
|
69 |
+
Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
|
70 |
+
there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
|
71 |
+
otherwise it uses the alpha value at step 0.
|
72 |
+
steps_offset (`int`, defaults to 0):
|
73 |
+
An offset added to the inference steps. You can use a combination of `offset=1` and
|
74 |
+
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
|
75 |
+
Diffusion.
|
76 |
+
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
77 |
+
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
78 |
+
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
|
79 |
+
Video](https://imagen.research.google/video/paper.pdf) paper).
|
80 |
+
thresholding (`bool`, defaults to `False`):
|
81 |
+
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
|
82 |
+
as Stable Diffusion.
|
83 |
+
dynamic_thresholding_ratio (`float`, defaults to 0.995):
|
84 |
+
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
|
85 |
+
sample_max_value (`float`, defaults to 1.0):
|
86 |
+
The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
|
87 |
+
timestep_spacing (`str`, defaults to `"leading"`):
|
88 |
+
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
89 |
+
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
90 |
+
rescale_betas_zero_snr (`bool`, defaults to `False`):
|
91 |
+
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
|
92 |
+
dark samples instead of limiting it to samples with medium brightness. Loosely related to
|
93 |
+
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
|
94 |
+
"""
|
95 |
+
|
96 |
+
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
97 |
+
order = 1
|
98 |
+
|
99 |
+
@register_to_config
|
100 |
+
def __init__(
|
101 |
+
self,
|
102 |
+
num_train_timesteps: int = 1000,
|
103 |
+
beta_start: float = 0.0001,
|
104 |
+
beta_end: float = 0.02,
|
105 |
+
beta_schedule: str = "linear",
|
106 |
+
trained_betas: ndarray | List[float] | None = None,
|
107 |
+
clip_sample: bool = True,
|
108 |
+
set_alpha_to_one: bool = True,
|
109 |
+
steps_offset: int = 0,
|
110 |
+
prediction_type: str = "epsilon",
|
111 |
+
thresholding: bool = False,
|
112 |
+
dynamic_thresholding_ratio: float = 0.995,
|
113 |
+
clip_sample_range: float = 1,
|
114 |
+
sample_max_value: float = 1,
|
115 |
+
timestep_spacing: str = "leading",
|
116 |
+
rescale_betas_zero_snr: bool = False,
|
117 |
+
):
|
118 |
+
super().__init__(
|
119 |
+
num_train_timesteps,
|
120 |
+
beta_start,
|
121 |
+
beta_end,
|
122 |
+
beta_schedule,
|
123 |
+
trained_betas,
|
124 |
+
clip_sample,
|
125 |
+
set_alpha_to_one,
|
126 |
+
steps_offset,
|
127 |
+
prediction_type,
|
128 |
+
thresholding,
|
129 |
+
dynamic_thresholding_ratio,
|
130 |
+
clip_sample_range,
|
131 |
+
sample_max_value,
|
132 |
+
timestep_spacing,
|
133 |
+
rescale_betas_zero_snr,
|
134 |
+
)
|
135 |
+
|
136 |
+
def step(
|
137 |
+
self,
|
138 |
+
model_output: torch.FloatTensor,
|
139 |
+
timestep: int,
|
140 |
+
sample: torch.FloatTensor,
|
141 |
+
eta: float = 0.0,
|
142 |
+
use_clipped_model_output: bool = False,
|
143 |
+
generator=None,
|
144 |
+
variance_noise: Optional[torch.FloatTensor] = None,
|
145 |
+
return_dict: bool = True,
|
146 |
+
w_ind_noise: float = 0.5,
|
147 |
+
noise_type: str = "random",
|
148 |
+
) -> Union[DDIMSchedulerOutput, Tuple]:
|
149 |
+
"""
|
150 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
151 |
+
process from the learned model outputs (most often the predicted noise).
|
152 |
+
|
153 |
+
Args:
|
154 |
+
model_output (`torch.FloatTensor`):
|
155 |
+
The direct output from learned diffusion model.
|
156 |
+
timestep (`float`):
|
157 |
+
The current discrete timestep in the diffusion chain.
|
158 |
+
sample (`torch.FloatTensor`):
|
159 |
+
A current instance of a sample created by the diffusion process.
|
160 |
+
eta (`float`):
|
161 |
+
The weight of noise for added noise in diffusion step.
|
162 |
+
use_clipped_model_output (`bool`, defaults to `False`):
|
163 |
+
If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
|
164 |
+
because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
|
165 |
+
clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
|
166 |
+
`use_clipped_model_output` has no effect.
|
167 |
+
generator (`torch.Generator`, *optional*):
|
168 |
+
A random number generator.
|
169 |
+
variance_noise (`torch.FloatTensor`):
|
170 |
+
Alternative to generating noise with `generator` by directly providing the noise for the variance
|
171 |
+
itself. Useful for methods such as [`CycleDiffusion`].
|
172 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
173 |
+
Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`.
|
174 |
+
|
175 |
+
Returns:
|
176 |
+
[`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
|
177 |
+
If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a
|
178 |
+
tuple is returned where the first element is the sample tensor.
|
179 |
+
|
180 |
+
"""
|
181 |
+
if self.num_inference_steps is None:
|
182 |
+
raise ValueError(
|
183 |
+
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
184 |
+
)
|
185 |
+
|
186 |
+
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
|
187 |
+
# Ideally, read DDIM paper in-detail understanding
|
188 |
+
|
189 |
+
# Notation (<variable name> -> <name in paper>
|
190 |
+
# - pred_noise_t -> e_theta(x_t, t)
|
191 |
+
# - pred_original_sample -> f_theta(x_t, t) or x_0
|
192 |
+
# - std_dev_t -> sigma_t
|
193 |
+
# - eta -> η
|
194 |
+
# - pred_sample_direction -> "direction pointing to x_t"
|
195 |
+
# - pred_prev_sample -> "x_t-1"
|
196 |
+
|
197 |
+
# 1. get previous step value (=t-1)
|
198 |
+
prev_timestep = (
|
199 |
+
timestep - self.config.num_train_timesteps // self.num_inference_steps
|
200 |
+
)
|
201 |
+
|
202 |
+
# 2. compute alphas, betas
|
203 |
+
alpha_prod_t = self.alphas_cumprod[timestep]
|
204 |
+
alpha_prod_t_prev = (
|
205 |
+
self.alphas_cumprod[prev_timestep]
|
206 |
+
if prev_timestep >= 0
|
207 |
+
else self.final_alpha_cumprod
|
208 |
+
)
|
209 |
+
|
210 |
+
beta_prod_t = 1 - alpha_prod_t
|
211 |
+
|
212 |
+
# 3. compute predicted original sample from predicted noise also called
|
213 |
+
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
214 |
+
if self.config.prediction_type == "epsilon":
|
215 |
+
pred_original_sample = (
|
216 |
+
sample - beta_prod_t ** (0.5) * model_output
|
217 |
+
) / alpha_prod_t ** (0.5)
|
218 |
+
pred_epsilon = model_output
|
219 |
+
elif self.config.prediction_type == "sample":
|
220 |
+
pred_original_sample = model_output
|
221 |
+
pred_epsilon = (
|
222 |
+
sample - alpha_prod_t ** (0.5) * pred_original_sample
|
223 |
+
) / beta_prod_t ** (0.5)
|
224 |
+
elif self.config.prediction_type == "v_prediction":
|
225 |
+
pred_original_sample = (alpha_prod_t**0.5) * sample - (
|
226 |
+
beta_prod_t**0.5
|
227 |
+
) * model_output
|
228 |
+
pred_epsilon = (alpha_prod_t**0.5) * model_output + (
|
229 |
+
beta_prod_t**0.5
|
230 |
+
) * sample
|
231 |
+
else:
|
232 |
+
raise ValueError(
|
233 |
+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
|
234 |
+
" `v_prediction`"
|
235 |
+
)
|
236 |
+
|
237 |
+
# 4. Clip or threshold "predicted x_0"
|
238 |
+
if self.config.thresholding:
|
239 |
+
pred_original_sample = self._threshold_sample(pred_original_sample)
|
240 |
+
elif self.config.clip_sample:
|
241 |
+
pred_original_sample = pred_original_sample.clamp(
|
242 |
+
-self.config.clip_sample_range, self.config.clip_sample_range
|
243 |
+
)
|
244 |
+
|
245 |
+
# 5. compute variance: "sigma_t(η)" -> see formula (16)
|
246 |
+
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
|
247 |
+
variance = self._get_variance(timestep, prev_timestep)
|
248 |
+
std_dev_t = eta * variance ** (0.5)
|
249 |
+
|
250 |
+
if use_clipped_model_output:
|
251 |
+
# the pred_epsilon is always re-derived from the clipped x_0 in Glide
|
252 |
+
pred_epsilon = (
|
253 |
+
sample - alpha_prod_t ** (0.5) * pred_original_sample
|
254 |
+
) / beta_prod_t ** (0.5)
|
255 |
+
|
256 |
+
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
257 |
+
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (
|
258 |
+
0.5
|
259 |
+
) * pred_epsilon
|
260 |
+
|
261 |
+
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
262 |
+
prev_sample = (
|
263 |
+
alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
|
264 |
+
)
|
265 |
+
|
266 |
+
if eta > 0:
|
267 |
+
if variance_noise is not None and generator is not None:
|
268 |
+
raise ValueError(
|
269 |
+
"Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
|
270 |
+
" `variance_noise` stays `None`."
|
271 |
+
)
|
272 |
+
|
273 |
+
# if variance_noise is None:
|
274 |
+
# variance_noise = randn_tensor(
|
275 |
+
# model_output.shape,
|
276 |
+
# generator=generator,
|
277 |
+
# device=model_output.device,
|
278 |
+
# dtype=model_output.dtype,
|
279 |
+
# )
|
280 |
+
device = model_output.device
|
281 |
+
|
282 |
+
if noise_type == "random":
|
283 |
+
variance_noise = randn_tensor(
|
284 |
+
model_output.shape,
|
285 |
+
dtype=model_output.dtype,
|
286 |
+
device=device,
|
287 |
+
generator=generator,
|
288 |
+
)
|
289 |
+
elif noise_type == "video_fusion":
|
290 |
+
variance_noise = video_fusion_noise(
|
291 |
+
model_output, w_ind_noise=w_ind_noise, generator=generator
|
292 |
+
)
|
293 |
+
variance = std_dev_t * variance_noise
|
294 |
+
|
295 |
+
prev_sample = prev_sample + variance
|
296 |
+
|
297 |
+
if not return_dict:
|
298 |
+
return (prev_sample,)
|
299 |
+
|
300 |
+
return DDIMSchedulerOutput(
|
301 |
+
prev_sample=prev_sample, pred_original_sample=pred_original_sample
|
302 |
+
)
|
musev/schedulers/scheduling_ddpm.py
ADDED
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 UC Berkeley Team and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
|
16 |
+
|
17 |
+
from __future__ import annotations
|
18 |
+
|
19 |
+
import math
|
20 |
+
from dataclasses import dataclass
|
21 |
+
from typing import List, Optional, Tuple, Union
|
22 |
+
|
23 |
+
import numpy as np
|
24 |
+
from numpy import ndarray
|
25 |
+
import torch
|
26 |
+
|
27 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
28 |
+
from diffusers.utils import BaseOutput
|
29 |
+
from diffusers.utils.torch_utils import randn_tensor
|
30 |
+
from diffusers.schedulers.scheduling_utils import (
|
31 |
+
KarrasDiffusionSchedulers,
|
32 |
+
SchedulerMixin,
|
33 |
+
)
|
34 |
+
from diffusers.schedulers.scheduling_ddpm import (
|
35 |
+
DDPMSchedulerOutput,
|
36 |
+
betas_for_alpha_bar,
|
37 |
+
DDPMScheduler as DiffusersDDPMScheduler,
|
38 |
+
)
|
39 |
+
from ..utils.noise_util import video_fusion_noise
|
40 |
+
|
41 |
+
|
42 |
+
class DDPMScheduler(DiffusersDDPMScheduler):
|
43 |
+
"""
|
44 |
+
`DDPMScheduler` explores the connections between denoising score matching and Langevin dynamics sampling.
|
45 |
+
|
46 |
+
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
47 |
+
methods the library implements for all schedulers such as loading and saving.
|
48 |
+
|
49 |
+
Args:
|
50 |
+
num_train_timesteps (`int`, defaults to 1000):
|
51 |
+
The number of diffusion steps to train the model.
|
52 |
+
beta_start (`float`, defaults to 0.0001):
|
53 |
+
The starting `beta` value of inference.
|
54 |
+
beta_end (`float`, defaults to 0.02):
|
55 |
+
The final `beta` value.
|
56 |
+
beta_schedule (`str`, defaults to `"linear"`):
|
57 |
+
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
58 |
+
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
|
59 |
+
variance_type (`str`, defaults to `"fixed_small"`):
|
60 |
+
Clip the variance when adding noise to the denoised sample. Choose from `fixed_small`, `fixed_small_log`,
|
61 |
+
`fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
|
62 |
+
clip_sample (`bool`, defaults to `True`):
|
63 |
+
Clip the predicted sample for numerical stability.
|
64 |
+
clip_sample_range (`float`, defaults to 1.0):
|
65 |
+
The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
|
66 |
+
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
67 |
+
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
68 |
+
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
|
69 |
+
Video](https://imagen.research.google/video/paper.pdf) paper).
|
70 |
+
thresholding (`bool`, defaults to `False`):
|
71 |
+
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
|
72 |
+
as Stable Diffusion.
|
73 |
+
dynamic_thresholding_ratio (`float`, defaults to 0.995):
|
74 |
+
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
|
75 |
+
sample_max_value (`float`, defaults to 1.0):
|
76 |
+
The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
|
77 |
+
timestep_spacing (`str`, defaults to `"leading"`):
|
78 |
+
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
79 |
+
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
80 |
+
steps_offset (`int`, defaults to 0):
|
81 |
+
An offset added to the inference steps. You can use a combination of `offset=1` and
|
82 |
+
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
|
83 |
+
Diffusion.
|
84 |
+
"""
|
85 |
+
|
86 |
+
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
87 |
+
order = 1
|
88 |
+
|
89 |
+
@register_to_config
|
90 |
+
def __init__(
|
91 |
+
self,
|
92 |
+
num_train_timesteps: int = 1000,
|
93 |
+
beta_start: float = 0.0001,
|
94 |
+
beta_end: float = 0.02,
|
95 |
+
beta_schedule: str = "linear",
|
96 |
+
trained_betas: ndarray | List[float] | None = None,
|
97 |
+
variance_type: str = "fixed_small",
|
98 |
+
clip_sample: bool = True,
|
99 |
+
prediction_type: str = "epsilon",
|
100 |
+
thresholding: bool = False,
|
101 |
+
dynamic_thresholding_ratio: float = 0.995,
|
102 |
+
clip_sample_range: float = 1,
|
103 |
+
sample_max_value: float = 1,
|
104 |
+
timestep_spacing: str = "leading",
|
105 |
+
steps_offset: int = 0,
|
106 |
+
):
|
107 |
+
super().__init__(
|
108 |
+
num_train_timesteps,
|
109 |
+
beta_start,
|
110 |
+
beta_end,
|
111 |
+
beta_schedule,
|
112 |
+
trained_betas,
|
113 |
+
variance_type,
|
114 |
+
clip_sample,
|
115 |
+
prediction_type,
|
116 |
+
thresholding,
|
117 |
+
dynamic_thresholding_ratio,
|
118 |
+
clip_sample_range,
|
119 |
+
sample_max_value,
|
120 |
+
timestep_spacing,
|
121 |
+
steps_offset,
|
122 |
+
)
|
123 |
+
|
124 |
+
def step(
|
125 |
+
self,
|
126 |
+
model_output: torch.FloatTensor,
|
127 |
+
timestep: int,
|
128 |
+
sample: torch.FloatTensor,
|
129 |
+
generator=None,
|
130 |
+
return_dict: bool = True,
|
131 |
+
w_ind_noise: float = 0.5,
|
132 |
+
noise_type: str = "random",
|
133 |
+
) -> Union[DDPMSchedulerOutput, Tuple]:
|
134 |
+
"""
|
135 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
136 |
+
process from the learned model outputs (most often the predicted noise).
|
137 |
+
|
138 |
+
Args:
|
139 |
+
model_output (`torch.FloatTensor`):
|
140 |
+
The direct output from learned diffusion model.
|
141 |
+
timestep (`float`):
|
142 |
+
The current discrete timestep in the diffusion chain.
|
143 |
+
sample (`torch.FloatTensor`):
|
144 |
+
A current instance of a sample created by the diffusion process.
|
145 |
+
generator (`torch.Generator`, *optional*):
|
146 |
+
A random number generator.
|
147 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
148 |
+
Whether or not to return a [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`.
|
149 |
+
|
150 |
+
Returns:
|
151 |
+
[`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`:
|
152 |
+
If return_dict is `True`, [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] is returned, otherwise a
|
153 |
+
tuple is returned where the first element is the sample tensor.
|
154 |
+
|
155 |
+
"""
|
156 |
+
t = timestep
|
157 |
+
|
158 |
+
prev_t = self.previous_timestep(t)
|
159 |
+
|
160 |
+
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in [
|
161 |
+
"learned",
|
162 |
+
"learned_range",
|
163 |
+
]:
|
164 |
+
model_output, predicted_variance = torch.split(
|
165 |
+
model_output, sample.shape[1], dim=1
|
166 |
+
)
|
167 |
+
else:
|
168 |
+
predicted_variance = None
|
169 |
+
|
170 |
+
# 1. compute alphas, betas
|
171 |
+
alpha_prod_t = self.alphas_cumprod[t]
|
172 |
+
alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
|
173 |
+
beta_prod_t = 1 - alpha_prod_t
|
174 |
+
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
175 |
+
current_alpha_t = alpha_prod_t / alpha_prod_t_prev
|
176 |
+
current_beta_t = 1 - current_alpha_t
|
177 |
+
|
178 |
+
# 2. compute predicted original sample from predicted noise also called
|
179 |
+
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
|
180 |
+
if self.config.prediction_type == "epsilon":
|
181 |
+
pred_original_sample = (
|
182 |
+
sample - beta_prod_t ** (0.5) * model_output
|
183 |
+
) / alpha_prod_t ** (0.5)
|
184 |
+
elif self.config.prediction_type == "sample":
|
185 |
+
pred_original_sample = model_output
|
186 |
+
elif self.config.prediction_type == "v_prediction":
|
187 |
+
pred_original_sample = (alpha_prod_t**0.5) * sample - (
|
188 |
+
beta_prod_t**0.5
|
189 |
+
) * model_output
|
190 |
+
else:
|
191 |
+
raise ValueError(
|
192 |
+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
|
193 |
+
" `v_prediction` for the DDPMScheduler."
|
194 |
+
)
|
195 |
+
|
196 |
+
# 3. Clip or threshold "predicted x_0"
|
197 |
+
if self.config.thresholding:
|
198 |
+
pred_original_sample = self._threshold_sample(pred_original_sample)
|
199 |
+
elif self.config.clip_sample:
|
200 |
+
pred_original_sample = pred_original_sample.clamp(
|
201 |
+
-self.config.clip_sample_range, self.config.clip_sample_range
|
202 |
+
)
|
203 |
+
|
204 |
+
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
|
205 |
+
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
|
206 |
+
pred_original_sample_coeff = (
|
207 |
+
alpha_prod_t_prev ** (0.5) * current_beta_t
|
208 |
+
) / beta_prod_t
|
209 |
+
current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
|
210 |
+
|
211 |
+
# 5. Compute predicted previous sample µ_t
|
212 |
+
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
|
213 |
+
pred_prev_sample = (
|
214 |
+
pred_original_sample_coeff * pred_original_sample
|
215 |
+
+ current_sample_coeff * sample
|
216 |
+
)
|
217 |
+
|
218 |
+
# 6. Add noise
|
219 |
+
variance = 0
|
220 |
+
if t > 0:
|
221 |
+
device = model_output.device
|
222 |
+
# if variance_noise is None:
|
223 |
+
# variance_noise = randn_tensor(
|
224 |
+
# model_output.shape,
|
225 |
+
# generator=generator,
|
226 |
+
# device=model_output.device,
|
227 |
+
# dtype=model_output.dtype,
|
228 |
+
# )
|
229 |
+
device = model_output.device
|
230 |
+
|
231 |
+
if noise_type == "random":
|
232 |
+
variance_noise = randn_tensor(
|
233 |
+
model_output.shape,
|
234 |
+
dtype=model_output.dtype,
|
235 |
+
device=device,
|
236 |
+
generator=generator,
|
237 |
+
)
|
238 |
+
elif noise_type == "video_fusion":
|
239 |
+
variance_noise = video_fusion_noise(
|
240 |
+
model_output, w_ind_noise=w_ind_noise, generator=generator
|
241 |
+
)
|
242 |
+
if self.variance_type == "fixed_small_log":
|
243 |
+
variance = (
|
244 |
+
self._get_variance(t, predicted_variance=predicted_variance)
|
245 |
+
* variance_noise
|
246 |
+
)
|
247 |
+
elif self.variance_type == "learned_range":
|
248 |
+
variance = self._get_variance(t, predicted_variance=predicted_variance)
|
249 |
+
variance = torch.exp(0.5 * variance) * variance_noise
|
250 |
+
else:
|
251 |
+
variance = (
|
252 |
+
self._get_variance(t, predicted_variance=predicted_variance) ** 0.5
|
253 |
+
) * variance_noise
|
254 |
+
|
255 |
+
pred_prev_sample = pred_prev_sample + variance
|
256 |
+
|
257 |
+
if not return_dict:
|
258 |
+
return (pred_prev_sample,)
|
259 |
+
|
260 |
+
return DDPMSchedulerOutput(
|
261 |
+
prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample
|
262 |
+
)
|
musev/schedulers/scheduling_dpmsolver_multistep.py
ADDED
@@ -0,0 +1,815 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 TSAIL Team and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
|
16 |
+
|
17 |
+
import math
|
18 |
+
from typing import List, Optional, Tuple, Union
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
import torch
|
22 |
+
|
23 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
24 |
+
|
25 |
+
try:
|
26 |
+
from diffusers.utils import randn_tensor
|
27 |
+
except:
|
28 |
+
from diffusers.utils.torch_utils import randn_tensor
|
29 |
+
from diffusers.schedulers.scheduling_utils import (
|
30 |
+
KarrasDiffusionSchedulers,
|
31 |
+
SchedulerMixin,
|
32 |
+
SchedulerOutput,
|
33 |
+
)
|
34 |
+
|
35 |
+
|
36 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
37 |
+
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
38 |
+
"""
|
39 |
+
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
40 |
+
(1-beta) over time from t = [0,1].
|
41 |
+
|
42 |
+
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
|
43 |
+
to that part of the diffusion process.
|
44 |
+
|
45 |
+
|
46 |
+
Args:
|
47 |
+
num_diffusion_timesteps (`int`): the number of betas to produce.
|
48 |
+
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
49 |
+
prevent singularities.
|
50 |
+
|
51 |
+
Returns:
|
52 |
+
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
53 |
+
"""
|
54 |
+
|
55 |
+
def alpha_bar(time_step):
|
56 |
+
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
|
57 |
+
|
58 |
+
betas = []
|
59 |
+
for i in range(num_diffusion_timesteps):
|
60 |
+
t1 = i / num_diffusion_timesteps
|
61 |
+
t2 = (i + 1) / num_diffusion_timesteps
|
62 |
+
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
63 |
+
return torch.tensor(betas, dtype=torch.float32)
|
64 |
+
|
65 |
+
|
66 |
+
class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
67 |
+
"""
|
68 |
+
DPM-Solver (and the improved version DPM-Solver++) is a fast dedicated high-order solver for diffusion ODEs with
|
69 |
+
the convergence order guarantee. Empirically, sampling by DPM-Solver with only 20 steps can generate high-quality
|
70 |
+
samples, and it can generate quite good samples even in only 10 steps.
|
71 |
+
|
72 |
+
For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095
|
73 |
+
|
74 |
+
Currently, we support the multistep DPM-Solver for both noise prediction models and data prediction models. We
|
75 |
+
recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling.
|
76 |
+
|
77 |
+
We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space
|
78 |
+
diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic
|
79 |
+
thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as
|
80 |
+
stable-diffusion).
|
81 |
+
|
82 |
+
We also support the SDE variant of DPM-Solver and DPM-Solver++, which is a fast SDE solver for the reverse
|
83 |
+
diffusion SDE. Currently we only support the first-order and second-order solvers. We recommend using the
|
84 |
+
second-order `sde-dpmsolver++`.
|
85 |
+
|
86 |
+
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
87 |
+
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
88 |
+
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
|
89 |
+
[`~SchedulerMixin.from_pretrained`] functions.
|
90 |
+
|
91 |
+
Args:
|
92 |
+
num_train_timesteps (`int`): number of diffusion steps used to train the model.
|
93 |
+
beta_start (`float`): the starting `beta` value of inference.
|
94 |
+
beta_end (`float`): the final `beta` value.
|
95 |
+
beta_schedule (`str`):
|
96 |
+
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
97 |
+
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
|
98 |
+
trained_betas (`np.ndarray`, optional):
|
99 |
+
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
|
100 |
+
solver_order (`int`, default `2`):
|
101 |
+
the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
|
102 |
+
sampling, and `solver_order=3` for unconditional sampling.
|
103 |
+
prediction_type (`str`, default `epsilon`, optional):
|
104 |
+
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
|
105 |
+
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
|
106 |
+
https://imagen.research.google/video/paper.pdf)
|
107 |
+
thresholding (`bool`, default `False`):
|
108 |
+
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
|
109 |
+
For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
|
110 |
+
use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion
|
111 |
+
models (such as stable-diffusion).
|
112 |
+
dynamic_thresholding_ratio (`float`, default `0.995`):
|
113 |
+
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
|
114 |
+
(https://arxiv.org/abs/2205.11487).
|
115 |
+
sample_max_value (`float`, default `1.0`):
|
116 |
+
the threshold value for dynamic thresholding. Valid only when `thresholding=True` and
|
117 |
+
`algorithm_type="dpmsolver++`.
|
118 |
+
algorithm_type (`str`, default `dpmsolver++`):
|
119 |
+
the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++` or `sde-dpmsolver` or
|
120 |
+
`sde-dpmsolver++`. The `dpmsolver` type implements the algorithms in https://arxiv.org/abs/2206.00927, and
|
121 |
+
the `dpmsolver++` type implements the algorithms in https://arxiv.org/abs/2211.01095. We recommend to use
|
122 |
+
`dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling (e.g. stable-diffusion).
|
123 |
+
solver_type (`str`, default `midpoint`):
|
124 |
+
the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects
|
125 |
+
the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are
|
126 |
+
slightly better, so we recommend to use the `midpoint` type.
|
127 |
+
lower_order_final (`bool`, default `True`):
|
128 |
+
whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
|
129 |
+
find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
|
130 |
+
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
|
131 |
+
This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
|
132 |
+
noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
|
133 |
+
of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
|
134 |
+
lambda_min_clipped (`float`, default `-inf`):
|
135 |
+
the clipping threshold for the minimum value of lambda(t) for numerical stability. This is critical for
|
136 |
+
cosine (squaredcos_cap_v2) noise schedule.
|
137 |
+
variance_type (`str`, *optional*):
|
138 |
+
Set to "learned" or "learned_range" for diffusion models that predict variance. For example, OpenAI's
|
139 |
+
guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the
|
140 |
+
Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on
|
141 |
+
diffusion ODEs. whether the model's output contains the predicted Gaussian variance. For example, OpenAI's
|
142 |
+
guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the
|
143 |
+
Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on
|
144 |
+
diffusion ODEs.
|
145 |
+
"""
|
146 |
+
|
147 |
+
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
148 |
+
order = 1
|
149 |
+
|
150 |
+
@register_to_config
|
151 |
+
def __init__(
|
152 |
+
self,
|
153 |
+
num_train_timesteps: int = 1000,
|
154 |
+
beta_start: float = 0.0001,
|
155 |
+
beta_end: float = 0.02,
|
156 |
+
beta_schedule: str = "linear",
|
157 |
+
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
158 |
+
solver_order: int = 2,
|
159 |
+
prediction_type: str = "epsilon",
|
160 |
+
thresholding: bool = False,
|
161 |
+
dynamic_thresholding_ratio: float = 0.995,
|
162 |
+
sample_max_value: float = 1.0,
|
163 |
+
algorithm_type: str = "dpmsolver++",
|
164 |
+
solver_type: str = "midpoint",
|
165 |
+
lower_order_final: bool = True,
|
166 |
+
use_karras_sigmas: Optional[bool] = True,
|
167 |
+
lambda_min_clipped: float = -float("inf"),
|
168 |
+
variance_type: Optional[str] = None,
|
169 |
+
):
|
170 |
+
if trained_betas is not None:
|
171 |
+
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
172 |
+
elif beta_schedule == "linear":
|
173 |
+
self.betas = torch.linspace(
|
174 |
+
beta_start, beta_end, num_train_timesteps, dtype=torch.float32
|
175 |
+
)
|
176 |
+
elif beta_schedule == "scaled_linear":
|
177 |
+
# this schedule is very specific to the latent diffusion model.
|
178 |
+
self.betas = (
|
179 |
+
torch.linspace(
|
180 |
+
beta_start**0.5,
|
181 |
+
beta_end**0.5,
|
182 |
+
num_train_timesteps,
|
183 |
+
dtype=torch.float32,
|
184 |
+
)
|
185 |
+
** 2
|
186 |
+
)
|
187 |
+
elif beta_schedule == "squaredcos_cap_v2":
|
188 |
+
# Glide cosine schedule
|
189 |
+
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
190 |
+
else:
|
191 |
+
raise NotImplementedError(
|
192 |
+
f"{beta_schedule} does is not implemented for {self.__class__}"
|
193 |
+
)
|
194 |
+
|
195 |
+
self.alphas = 1.0 - self.betas
|
196 |
+
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
197 |
+
# Currently we only support VP-type noise schedule
|
198 |
+
self.alpha_t = torch.sqrt(self.alphas_cumprod)
|
199 |
+
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
|
200 |
+
self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
|
201 |
+
|
202 |
+
# standard deviation of the initial noise distribution
|
203 |
+
self.init_noise_sigma = 1.0
|
204 |
+
|
205 |
+
# settings for DPM-Solver
|
206 |
+
if algorithm_type not in [
|
207 |
+
"dpmsolver",
|
208 |
+
"dpmsolver++",
|
209 |
+
"sde-dpmsolver",
|
210 |
+
"sde-dpmsolver++",
|
211 |
+
]:
|
212 |
+
if algorithm_type == "deis":
|
213 |
+
self.register_to_config(algorithm_type="dpmsolver++")
|
214 |
+
else:
|
215 |
+
raise NotImplementedError(
|
216 |
+
f"{algorithm_type} does is not implemented for {self.__class__}"
|
217 |
+
)
|
218 |
+
|
219 |
+
if solver_type not in ["midpoint", "heun"]:
|
220 |
+
if solver_type in ["logrho", "bh1", "bh2"]:
|
221 |
+
self.register_to_config(solver_type="midpoint")
|
222 |
+
else:
|
223 |
+
raise NotImplementedError(
|
224 |
+
f"{solver_type} does is not implemented for {self.__class__}"
|
225 |
+
)
|
226 |
+
|
227 |
+
# setable values
|
228 |
+
self.num_inference_steps = None
|
229 |
+
timesteps = np.linspace(
|
230 |
+
0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32
|
231 |
+
)[::-1].copy()
|
232 |
+
self.timesteps = torch.from_numpy(timesteps)
|
233 |
+
self.model_outputs = [None] * solver_order
|
234 |
+
self.lower_order_nums = 0
|
235 |
+
self.use_karras_sigmas = use_karras_sigmas
|
236 |
+
|
237 |
+
def set_timesteps(
|
238 |
+
self, num_inference_steps: int = None, device: Union[str, torch.device] = None
|
239 |
+
):
|
240 |
+
"""
|
241 |
+
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
|
242 |
+
|
243 |
+
Args:
|
244 |
+
num_inference_steps (`int`):
|
245 |
+
the number of diffusion steps used when generating samples with a pre-trained model.
|
246 |
+
device (`str` or `torch.device`, optional):
|
247 |
+
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
248 |
+
"""
|
249 |
+
# Clipping the minimum of all lambda(t) for numerical stability.
|
250 |
+
# This is critical for cosine (squaredcos_cap_v2) noise schedule.
|
251 |
+
clipped_idx = torch.searchsorted(
|
252 |
+
torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped
|
253 |
+
)
|
254 |
+
timesteps = (
|
255 |
+
np.linspace(
|
256 |
+
0,
|
257 |
+
self.config.num_train_timesteps - 1 - clipped_idx,
|
258 |
+
num_inference_steps + 1,
|
259 |
+
)
|
260 |
+
.round()[::-1][:-1]
|
261 |
+
.copy()
|
262 |
+
.astype(np.int64)
|
263 |
+
)
|
264 |
+
|
265 |
+
if self.use_karras_sigmas:
|
266 |
+
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
267 |
+
log_sigmas = np.log(sigmas)
|
268 |
+
sigmas = self._convert_to_karras(
|
269 |
+
in_sigmas=sigmas, num_inference_steps=num_inference_steps
|
270 |
+
)
|
271 |
+
timesteps = np.array(
|
272 |
+
[self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]
|
273 |
+
).round()
|
274 |
+
timesteps = np.flip(timesteps).copy().astype(np.int64)
|
275 |
+
|
276 |
+
# when num_inference_steps == num_train_timesteps, we can end up with
|
277 |
+
# duplicates in timesteps.
|
278 |
+
_, unique_indices = np.unique(timesteps, return_index=True)
|
279 |
+
timesteps = timesteps[np.sort(unique_indices)]
|
280 |
+
|
281 |
+
self.timesteps = torch.from_numpy(timesteps).to(device)
|
282 |
+
|
283 |
+
self.num_inference_steps = len(timesteps)
|
284 |
+
|
285 |
+
self.model_outputs = [
|
286 |
+
None,
|
287 |
+
] * self.config.solver_order
|
288 |
+
self.lower_order_nums = 0
|
289 |
+
|
290 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
291 |
+
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
292 |
+
"""
|
293 |
+
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
|
294 |
+
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
|
295 |
+
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
|
296 |
+
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
|
297 |
+
photorealism as well as better image-text alignment, especially when using very large guidance weights."
|
298 |
+
|
299 |
+
https://arxiv.org/abs/2205.11487
|
300 |
+
"""
|
301 |
+
dtype = sample.dtype
|
302 |
+
batch_size, channels, height, width = sample.shape
|
303 |
+
|
304 |
+
if dtype not in (torch.float32, torch.float64):
|
305 |
+
sample = (
|
306 |
+
sample.float()
|
307 |
+
) # upcast for quantile calculation, and clamp not implemented for cpu half
|
308 |
+
|
309 |
+
# Flatten sample for doing quantile calculation along each image
|
310 |
+
sample = sample.reshape(batch_size, channels * height * width)
|
311 |
+
|
312 |
+
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
|
313 |
+
|
314 |
+
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
|
315 |
+
s = torch.clamp(
|
316 |
+
s, min=1, max=self.config.sample_max_value
|
317 |
+
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
|
318 |
+
|
319 |
+
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
|
320 |
+
sample = (
|
321 |
+
torch.clamp(sample, -s, s) / s
|
322 |
+
) # "we threshold xt0 to the range [-s, s] and then divide by s"
|
323 |
+
|
324 |
+
sample = sample.reshape(batch_size, channels, height, width)
|
325 |
+
sample = sample.to(dtype)
|
326 |
+
|
327 |
+
return sample
|
328 |
+
|
329 |
+
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
330 |
+
def _sigma_to_t(self, sigma, log_sigmas):
|
331 |
+
# get log sigma
|
332 |
+
log_sigma = np.log(sigma)
|
333 |
+
|
334 |
+
# get distribution
|
335 |
+
dists = log_sigma - log_sigmas[:, np.newaxis]
|
336 |
+
|
337 |
+
# get sigmas range
|
338 |
+
low_idx = (
|
339 |
+
np.cumsum((dists >= 0), axis=0)
|
340 |
+
.argmax(axis=0)
|
341 |
+
.clip(max=log_sigmas.shape[0] - 2)
|
342 |
+
)
|
343 |
+
high_idx = low_idx + 1
|
344 |
+
|
345 |
+
low = log_sigmas[low_idx]
|
346 |
+
high = log_sigmas[high_idx]
|
347 |
+
|
348 |
+
# interpolate sigmas
|
349 |
+
w = (low - log_sigma) / (low - high)
|
350 |
+
w = np.clip(w, 0, 1)
|
351 |
+
|
352 |
+
# transform interpolation to time range
|
353 |
+
t = (1 - w) * low_idx + w * high_idx
|
354 |
+
t = t.reshape(sigma.shape)
|
355 |
+
return t
|
356 |
+
|
357 |
+
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
358 |
+
def _convert_to_karras(
|
359 |
+
self, in_sigmas: torch.FloatTensor, num_inference_steps
|
360 |
+
) -> torch.FloatTensor:
|
361 |
+
"""Constructs the noise schedule of Karras et al. (2022)."""
|
362 |
+
|
363 |
+
sigma_min: float = in_sigmas[-1].item()
|
364 |
+
sigma_max: float = in_sigmas[0].item()
|
365 |
+
|
366 |
+
rho = 7.0 # 7.0 is the value used in the paper
|
367 |
+
ramp = np.linspace(0, 1, num_inference_steps)
|
368 |
+
min_inv_rho = sigma_min ** (1 / rho)
|
369 |
+
max_inv_rho = sigma_max ** (1 / rho)
|
370 |
+
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
371 |
+
return sigmas
|
372 |
+
|
373 |
+
def convert_model_output(
|
374 |
+
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
375 |
+
) -> torch.FloatTensor:
|
376 |
+
"""
|
377 |
+
Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs.
|
378 |
+
|
379 |
+
DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to
|
380 |
+
discretize an integral of the data prediction model. So we need to first convert the model output to the
|
381 |
+
corresponding type to match the algorithm.
|
382 |
+
|
383 |
+
Note that the algorithm type and the model type is decoupled. That is to say, we can use either DPM-Solver or
|
384 |
+
DPM-Solver++ for both noise prediction model and data prediction model.
|
385 |
+
|
386 |
+
Args:
|
387 |
+
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
|
388 |
+
timestep (`int`): current discrete timestep in the diffusion chain.
|
389 |
+
sample (`torch.FloatTensor`):
|
390 |
+
current instance of sample being created by diffusion process.
|
391 |
+
|
392 |
+
Returns:
|
393 |
+
`torch.FloatTensor`: the converted model output.
|
394 |
+
"""
|
395 |
+
|
396 |
+
# DPM-Solver++ needs to solve an integral of the data prediction model.
|
397 |
+
if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
|
398 |
+
if self.config.prediction_type == "epsilon":
|
399 |
+
# DPM-Solver and DPM-Solver++ only need the "mean" output.
|
400 |
+
if self.config.variance_type in ["learned", "learned_range"]:
|
401 |
+
model_output = model_output[:, :3]
|
402 |
+
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
403 |
+
x0_pred = (sample - sigma_t * model_output) / alpha_t
|
404 |
+
elif self.config.prediction_type == "sample":
|
405 |
+
x0_pred = model_output
|
406 |
+
elif self.config.prediction_type == "v_prediction":
|
407 |
+
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
408 |
+
x0_pred = alpha_t * sample - sigma_t * model_output
|
409 |
+
else:
|
410 |
+
raise ValueError(
|
411 |
+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
|
412 |
+
" `v_prediction` for the DPMSolverMultistepScheduler."
|
413 |
+
)
|
414 |
+
|
415 |
+
if self.config.thresholding:
|
416 |
+
x0_pred = self._threshold_sample(x0_pred)
|
417 |
+
|
418 |
+
return x0_pred
|
419 |
+
|
420 |
+
# DPM-Solver needs to solve an integral of the noise prediction model.
|
421 |
+
elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
|
422 |
+
if self.config.prediction_type == "epsilon":
|
423 |
+
# DPM-Solver and DPM-Solver++ only need the "mean" output.
|
424 |
+
if self.config.variance_type in ["learned", "learned_range"]:
|
425 |
+
epsilon = model_output[:, :3]
|
426 |
+
else:
|
427 |
+
epsilon = model_output
|
428 |
+
elif self.config.prediction_type == "sample":
|
429 |
+
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
430 |
+
epsilon = (sample - alpha_t * model_output) / sigma_t
|
431 |
+
elif self.config.prediction_type == "v_prediction":
|
432 |
+
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
433 |
+
epsilon = alpha_t * model_output + sigma_t * sample
|
434 |
+
else:
|
435 |
+
raise ValueError(
|
436 |
+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
|
437 |
+
" `v_prediction` for the DPMSolverMultistepScheduler."
|
438 |
+
)
|
439 |
+
|
440 |
+
if self.config.thresholding:
|
441 |
+
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
442 |
+
x0_pred = (sample - sigma_t * epsilon) / alpha_t
|
443 |
+
x0_pred = self._threshold_sample(x0_pred)
|
444 |
+
epsilon = (sample - alpha_t * x0_pred) / sigma_t
|
445 |
+
|
446 |
+
return epsilon
|
447 |
+
|
448 |
+
def dpm_solver_first_order_update(
|
449 |
+
self,
|
450 |
+
model_output: torch.FloatTensor,
|
451 |
+
timestep: int,
|
452 |
+
prev_timestep: int,
|
453 |
+
sample: torch.FloatTensor,
|
454 |
+
noise: Optional[torch.FloatTensor] = None,
|
455 |
+
) -> torch.FloatTensor:
|
456 |
+
"""
|
457 |
+
One step for the first-order DPM-Solver (equivalent to DDIM).
|
458 |
+
|
459 |
+
See https://arxiv.org/abs/2206.00927 for the detailed derivation.
|
460 |
+
|
461 |
+
Args:
|
462 |
+
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
|
463 |
+
timestep (`int`): current discrete timestep in the diffusion chain.
|
464 |
+
prev_timestep (`int`): previous discrete timestep in the diffusion chain.
|
465 |
+
sample (`torch.FloatTensor`):
|
466 |
+
current instance of sample being created by diffusion process.
|
467 |
+
|
468 |
+
Returns:
|
469 |
+
`torch.FloatTensor`: the sample tensor at the previous timestep.
|
470 |
+
"""
|
471 |
+
lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep]
|
472 |
+
alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
|
473 |
+
sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep]
|
474 |
+
h = lambda_t - lambda_s
|
475 |
+
if self.config.algorithm_type == "dpmsolver++":
|
476 |
+
x_t = (sigma_t / sigma_s) * sample - (
|
477 |
+
alpha_t * (torch.exp(-h) - 1.0)
|
478 |
+
) * model_output
|
479 |
+
elif self.config.algorithm_type == "dpmsolver":
|
480 |
+
x_t = (alpha_t / alpha_s) * sample - (
|
481 |
+
sigma_t * (torch.exp(h) - 1.0)
|
482 |
+
) * model_output
|
483 |
+
elif self.config.algorithm_type == "sde-dpmsolver++":
|
484 |
+
assert noise is not None
|
485 |
+
x_t = (
|
486 |
+
(sigma_t / sigma_s * torch.exp(-h)) * sample
|
487 |
+
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
|
488 |
+
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
|
489 |
+
)
|
490 |
+
elif self.config.algorithm_type == "sde-dpmsolver":
|
491 |
+
assert noise is not None
|
492 |
+
x_t = (
|
493 |
+
(alpha_t / alpha_s) * sample
|
494 |
+
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
|
495 |
+
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
|
496 |
+
)
|
497 |
+
return x_t
|
498 |
+
|
499 |
+
def multistep_dpm_solver_second_order_update(
|
500 |
+
self,
|
501 |
+
model_output_list: List[torch.FloatTensor],
|
502 |
+
timestep_list: List[int],
|
503 |
+
prev_timestep: int,
|
504 |
+
sample: torch.FloatTensor,
|
505 |
+
noise: Optional[torch.FloatTensor] = None,
|
506 |
+
) -> torch.FloatTensor:
|
507 |
+
"""
|
508 |
+
One step for the second-order multistep DPM-Solver.
|
509 |
+
|
510 |
+
Args:
|
511 |
+
model_output_list (`List[torch.FloatTensor]`):
|
512 |
+
direct outputs from learned diffusion model at current and latter timesteps.
|
513 |
+
timestep (`int`): current and latter discrete timestep in the diffusion chain.
|
514 |
+
prev_timestep (`int`): previous discrete timestep in the diffusion chain.
|
515 |
+
sample (`torch.FloatTensor`):
|
516 |
+
current instance of sample being created by diffusion process.
|
517 |
+
|
518 |
+
Returns:
|
519 |
+
`torch.FloatTensor`: the sample tensor at the previous timestep.
|
520 |
+
"""
|
521 |
+
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
|
522 |
+
m0, m1 = model_output_list[-1], model_output_list[-2]
|
523 |
+
lambda_t, lambda_s0, lambda_s1 = (
|
524 |
+
self.lambda_t[t],
|
525 |
+
self.lambda_t[s0],
|
526 |
+
self.lambda_t[s1],
|
527 |
+
)
|
528 |
+
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
|
529 |
+
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
|
530 |
+
h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
|
531 |
+
r0 = h_0 / h
|
532 |
+
D0, D1 = m0, (1.0 / r0) * (m0 - m1)
|
533 |
+
if self.config.algorithm_type == "dpmsolver++":
|
534 |
+
# See https://arxiv.org/abs/2211.01095 for detailed derivations
|
535 |
+
if self.config.solver_type == "midpoint":
|
536 |
+
x_t = (
|
537 |
+
(sigma_t / sigma_s0) * sample
|
538 |
+
- (alpha_t * (torch.exp(-h) - 1.0)) * D0
|
539 |
+
- 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
|
540 |
+
)
|
541 |
+
elif self.config.solver_type == "heun":
|
542 |
+
x_t = (
|
543 |
+
(sigma_t / sigma_s0) * sample
|
544 |
+
- (alpha_t * (torch.exp(-h) - 1.0)) * D0
|
545 |
+
+ (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
|
546 |
+
)
|
547 |
+
elif self.config.algorithm_type == "dpmsolver":
|
548 |
+
# See https://arxiv.org/abs/2206.00927 for detailed derivations
|
549 |
+
if self.config.solver_type == "midpoint":
|
550 |
+
x_t = (
|
551 |
+
(alpha_t / alpha_s0) * sample
|
552 |
+
- (sigma_t * (torch.exp(h) - 1.0)) * D0
|
553 |
+
- 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1
|
554 |
+
)
|
555 |
+
elif self.config.solver_type == "heun":
|
556 |
+
x_t = (
|
557 |
+
(alpha_t / alpha_s0) * sample
|
558 |
+
- (sigma_t * (torch.exp(h) - 1.0)) * D0
|
559 |
+
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
|
560 |
+
)
|
561 |
+
elif self.config.algorithm_type == "sde-dpmsolver++":
|
562 |
+
assert noise is not None
|
563 |
+
if self.config.solver_type == "midpoint":
|
564 |
+
x_t = (
|
565 |
+
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
|
566 |
+
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
|
567 |
+
+ 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
|
568 |
+
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
|
569 |
+
)
|
570 |
+
elif self.config.solver_type == "heun":
|
571 |
+
x_t = (
|
572 |
+
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
|
573 |
+
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
|
574 |
+
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
|
575 |
+
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
|
576 |
+
)
|
577 |
+
elif self.config.algorithm_type == "sde-dpmsolver":
|
578 |
+
assert noise is not None
|
579 |
+
if self.config.solver_type == "midpoint":
|
580 |
+
x_t = (
|
581 |
+
(alpha_t / alpha_s0) * sample
|
582 |
+
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
|
583 |
+
- (sigma_t * (torch.exp(h) - 1.0)) * D1
|
584 |
+
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
|
585 |
+
)
|
586 |
+
elif self.config.solver_type == "heun":
|
587 |
+
x_t = (
|
588 |
+
(alpha_t / alpha_s0) * sample
|
589 |
+
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
|
590 |
+
- 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
|
591 |
+
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
|
592 |
+
)
|
593 |
+
return x_t
|
594 |
+
|
595 |
+
def multistep_dpm_solver_third_order_update(
|
596 |
+
self,
|
597 |
+
model_output_list: List[torch.FloatTensor],
|
598 |
+
timestep_list: List[int],
|
599 |
+
prev_timestep: int,
|
600 |
+
sample: torch.FloatTensor,
|
601 |
+
) -> torch.FloatTensor:
|
602 |
+
"""
|
603 |
+
One step for the third-order multistep DPM-Solver.
|
604 |
+
|
605 |
+
Args:
|
606 |
+
model_output_list (`List[torch.FloatTensor]`):
|
607 |
+
direct outputs from learned diffusion model at current and latter timesteps.
|
608 |
+
timestep (`int`): current and latter discrete timestep in the diffusion chain.
|
609 |
+
prev_timestep (`int`): previous discrete timestep in the diffusion chain.
|
610 |
+
sample (`torch.FloatTensor`):
|
611 |
+
current instance of sample being created by diffusion process.
|
612 |
+
|
613 |
+
Returns:
|
614 |
+
`torch.FloatTensor`: the sample tensor at the previous timestep.
|
615 |
+
"""
|
616 |
+
t, s0, s1, s2 = (
|
617 |
+
prev_timestep,
|
618 |
+
timestep_list[-1],
|
619 |
+
timestep_list[-2],
|
620 |
+
timestep_list[-3],
|
621 |
+
)
|
622 |
+
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
|
623 |
+
lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
|
624 |
+
self.lambda_t[t],
|
625 |
+
self.lambda_t[s0],
|
626 |
+
self.lambda_t[s1],
|
627 |
+
self.lambda_t[s2],
|
628 |
+
)
|
629 |
+
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
|
630 |
+
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
|
631 |
+
h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
|
632 |
+
r0, r1 = h_0 / h, h_1 / h
|
633 |
+
D0 = m0
|
634 |
+
D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
|
635 |
+
D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
|
636 |
+
D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
|
637 |
+
if self.config.algorithm_type == "dpmsolver++":
|
638 |
+
# See https://arxiv.org/abs/2206.00927 for detailed derivations
|
639 |
+
x_t = (
|
640 |
+
(sigma_t / sigma_s0) * sample
|
641 |
+
- (alpha_t * (torch.exp(-h) - 1.0)) * D0
|
642 |
+
+ (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
|
643 |
+
- (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
|
644 |
+
)
|
645 |
+
elif self.config.algorithm_type == "dpmsolver":
|
646 |
+
# See https://arxiv.org/abs/2206.00927 for detailed derivations
|
647 |
+
x_t = (
|
648 |
+
(alpha_t / alpha_s0) * sample
|
649 |
+
- (sigma_t * (torch.exp(h) - 1.0)) * D0
|
650 |
+
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
|
651 |
+
- (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
|
652 |
+
)
|
653 |
+
return x_t
|
654 |
+
|
655 |
+
def step(
|
656 |
+
self,
|
657 |
+
model_output: torch.FloatTensor,
|
658 |
+
timestep: int,
|
659 |
+
sample: torch.FloatTensor,
|
660 |
+
generator=None,
|
661 |
+
return_dict: bool = True,
|
662 |
+
w_ind_noise: float = 0.5,
|
663 |
+
) -> Union[SchedulerOutput, Tuple]:
|
664 |
+
"""
|
665 |
+
Step function propagating the sample with the multistep DPM-Solver.
|
666 |
+
|
667 |
+
Args:
|
668 |
+
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
|
669 |
+
timestep (`int`): current discrete timestep in the diffusion chain.
|
670 |
+
sample (`torch.FloatTensor`):
|
671 |
+
current instance of sample being created by diffusion process.
|
672 |
+
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
|
673 |
+
|
674 |
+
Returns:
|
675 |
+
[`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
|
676 |
+
True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
|
677 |
+
|
678 |
+
"""
|
679 |
+
if self.num_inference_steps is None:
|
680 |
+
raise ValueError(
|
681 |
+
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
682 |
+
)
|
683 |
+
|
684 |
+
if isinstance(timestep, torch.Tensor):
|
685 |
+
timestep = timestep.to(self.timesteps.device)
|
686 |
+
step_index = (self.timesteps == timestep).nonzero()
|
687 |
+
if len(step_index) == 0:
|
688 |
+
step_index = len(self.timesteps) - 1
|
689 |
+
else:
|
690 |
+
step_index = step_index.item()
|
691 |
+
prev_timestep = (
|
692 |
+
0
|
693 |
+
if step_index == len(self.timesteps) - 1
|
694 |
+
else self.timesteps[step_index + 1]
|
695 |
+
)
|
696 |
+
lower_order_final = (
|
697 |
+
(step_index == len(self.timesteps) - 1)
|
698 |
+
and self.config.lower_order_final
|
699 |
+
and len(self.timesteps) < 15
|
700 |
+
)
|
701 |
+
lower_order_second = (
|
702 |
+
(step_index == len(self.timesteps) - 2)
|
703 |
+
and self.config.lower_order_final
|
704 |
+
and len(self.timesteps) < 15
|
705 |
+
)
|
706 |
+
|
707 |
+
model_output = self.convert_model_output(model_output, timestep, sample)
|
708 |
+
for i in range(self.config.solver_order - 1):
|
709 |
+
self.model_outputs[i] = self.model_outputs[i + 1]
|
710 |
+
self.model_outputs[-1] = model_output
|
711 |
+
|
712 |
+
if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
|
713 |
+
# noise = randn_tensor(
|
714 |
+
# model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
|
715 |
+
# )
|
716 |
+
common_noise = torch.randn(
|
717 |
+
model_output.shape[:2] + (1,) + model_output.shape[3:],
|
718 |
+
generator=generator,
|
719 |
+
device=model_output.device,
|
720 |
+
dtype=model_output.dtype,
|
721 |
+
) # common noise
|
722 |
+
ind_noise = randn_tensor(
|
723 |
+
model_output.shape,
|
724 |
+
generator=generator,
|
725 |
+
device=model_output.device,
|
726 |
+
dtype=model_output.dtype,
|
727 |
+
)
|
728 |
+
s = torch.tensor(
|
729 |
+
w_ind_noise, device=model_output.device, dtype=model_output.dtype
|
730 |
+
).to(device)
|
731 |
+
noise = torch.sqrt(1 - s) * common_noise + torch.sqrt(s) * ind_noise
|
732 |
+
|
733 |
+
else:
|
734 |
+
noise = None
|
735 |
+
|
736 |
+
if (
|
737 |
+
self.config.solver_order == 1
|
738 |
+
or self.lower_order_nums < 1
|
739 |
+
or lower_order_final
|
740 |
+
):
|
741 |
+
prev_sample = self.dpm_solver_first_order_update(
|
742 |
+
model_output, timestep, prev_timestep, sample, noise=noise
|
743 |
+
)
|
744 |
+
elif (
|
745 |
+
self.config.solver_order == 2
|
746 |
+
or self.lower_order_nums < 2
|
747 |
+
or lower_order_second
|
748 |
+
):
|
749 |
+
timestep_list = [self.timesteps[step_index - 1], timestep]
|
750 |
+
prev_sample = self.multistep_dpm_solver_second_order_update(
|
751 |
+
self.model_outputs, timestep_list, prev_timestep, sample, noise=noise
|
752 |
+
)
|
753 |
+
else:
|
754 |
+
timestep_list = [
|
755 |
+
self.timesteps[step_index - 2],
|
756 |
+
self.timesteps[step_index - 1],
|
757 |
+
timestep,
|
758 |
+
]
|
759 |
+
prev_sample = self.multistep_dpm_solver_third_order_update(
|
760 |
+
self.model_outputs, timestep_list, prev_timestep, sample
|
761 |
+
)
|
762 |
+
|
763 |
+
if self.lower_order_nums < self.config.solver_order:
|
764 |
+
self.lower_order_nums += 1
|
765 |
+
|
766 |
+
if not return_dict:
|
767 |
+
return (prev_sample,)
|
768 |
+
|
769 |
+
return SchedulerOutput(prev_sample=prev_sample)
|
770 |
+
|
771 |
+
def scale_model_input(
|
772 |
+
self, sample: torch.FloatTensor, *args, **kwargs
|
773 |
+
) -> torch.FloatTensor:
|
774 |
+
"""
|
775 |
+
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
776 |
+
current timestep.
|
777 |
+
|
778 |
+
Args:
|
779 |
+
sample (`torch.FloatTensor`): input sample
|
780 |
+
|
781 |
+
Returns:
|
782 |
+
`torch.FloatTensor`: scaled input sample
|
783 |
+
"""
|
784 |
+
return sample
|
785 |
+
|
786 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
|
787 |
+
def add_noise(
|
788 |
+
self,
|
789 |
+
original_samples: torch.FloatTensor,
|
790 |
+
noise: torch.FloatTensor,
|
791 |
+
timesteps: torch.IntTensor,
|
792 |
+
) -> torch.FloatTensor:
|
793 |
+
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
794 |
+
alphas_cumprod = self.alphas_cumprod.to(
|
795 |
+
device=original_samples.device, dtype=original_samples.dtype
|
796 |
+
)
|
797 |
+
timesteps = timesteps.to(original_samples.device)
|
798 |
+
|
799 |
+
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
800 |
+
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
801 |
+
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
802 |
+
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
803 |
+
|
804 |
+
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
805 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
806 |
+
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
807 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
808 |
+
|
809 |
+
noisy_samples = (
|
810 |
+
sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
811 |
+
)
|
812 |
+
return noisy_samples
|
813 |
+
|
814 |
+
def __len__(self):
|
815 |
+
return self.config.num_train_timesteps
|
musev/schedulers/scheduling_euler_ancestral_discrete.py
ADDED
@@ -0,0 +1,356 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import math
|
16 |
+
from dataclasses import dataclass
|
17 |
+
from typing import List, Optional, Tuple, Union
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
import torch
|
21 |
+
|
22 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
23 |
+
|
24 |
+
from diffusers.utils import BaseOutput, logging
|
25 |
+
|
26 |
+
try:
|
27 |
+
from diffusers.utils import randn_tensor
|
28 |
+
except:
|
29 |
+
from diffusers.utils.torch_utils import randn_tensor
|
30 |
+
from diffusers.schedulers.scheduling_utils import (
|
31 |
+
KarrasDiffusionSchedulers,
|
32 |
+
SchedulerMixin,
|
33 |
+
)
|
34 |
+
|
35 |
+
from ..utils.noise_util import video_fusion_noise
|
36 |
+
|
37 |
+
|
38 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
39 |
+
|
40 |
+
|
41 |
+
@dataclass
|
42 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerAncestralDiscrete
|
43 |
+
class EulerAncestralDiscreteSchedulerOutput(BaseOutput):
|
44 |
+
"""
|
45 |
+
Output class for the scheduler's step function output.
|
46 |
+
|
47 |
+
Args:
|
48 |
+
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
49 |
+
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
|
50 |
+
denoising loop.
|
51 |
+
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
52 |
+
The predicted denoised sample (x_{0}) based on the model output from the current timestep.
|
53 |
+
`pred_original_sample` can be used to preview progress or for guidance.
|
54 |
+
"""
|
55 |
+
|
56 |
+
prev_sample: torch.FloatTensor
|
57 |
+
pred_original_sample: Optional[torch.FloatTensor] = None
|
58 |
+
|
59 |
+
|
60 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
61 |
+
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
|
62 |
+
"""
|
63 |
+
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
64 |
+
(1-beta) over time from t = [0,1].
|
65 |
+
|
66 |
+
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
|
67 |
+
to that part of the diffusion process.
|
68 |
+
|
69 |
+
|
70 |
+
Args:
|
71 |
+
num_diffusion_timesteps (`int`): the number of betas to produce.
|
72 |
+
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
73 |
+
prevent singularities.
|
74 |
+
|
75 |
+
Returns:
|
76 |
+
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
77 |
+
"""
|
78 |
+
|
79 |
+
def alpha_bar(time_step):
|
80 |
+
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
|
81 |
+
|
82 |
+
betas = []
|
83 |
+
for i in range(num_diffusion_timesteps):
|
84 |
+
t1 = i / num_diffusion_timesteps
|
85 |
+
t2 = (i + 1) / num_diffusion_timesteps
|
86 |
+
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
87 |
+
return torch.tensor(betas, dtype=torch.float32)
|
88 |
+
|
89 |
+
|
90 |
+
class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
91 |
+
"""
|
92 |
+
Ancestral sampling with Euler method steps. Based on the original k-diffusion implementation by Katherine Crowson:
|
93 |
+
https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72
|
94 |
+
|
95 |
+
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
96 |
+
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
97 |
+
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
|
98 |
+
[`~SchedulerMixin.from_pretrained`] functions.
|
99 |
+
|
100 |
+
Args:
|
101 |
+
num_train_timesteps (`int`): number of diffusion steps used to train the model.
|
102 |
+
beta_start (`float`): the starting `beta` value of inference.
|
103 |
+
beta_end (`float`): the final `beta` value.
|
104 |
+
beta_schedule (`str`):
|
105 |
+
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
106 |
+
`linear` or `scaled_linear`.
|
107 |
+
trained_betas (`np.ndarray`, optional):
|
108 |
+
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
|
109 |
+
prediction_type (`str`, default `epsilon`, optional):
|
110 |
+
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
|
111 |
+
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
|
112 |
+
https://imagen.research.google/video/paper.pdf)
|
113 |
+
|
114 |
+
"""
|
115 |
+
|
116 |
+
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
117 |
+
order = 1
|
118 |
+
|
119 |
+
@register_to_config
|
120 |
+
def __init__(
|
121 |
+
self,
|
122 |
+
num_train_timesteps: int = 1000,
|
123 |
+
beta_start: float = 0.0001,
|
124 |
+
beta_end: float = 0.02,
|
125 |
+
beta_schedule: str = "linear",
|
126 |
+
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
127 |
+
prediction_type: str = "epsilon",
|
128 |
+
):
|
129 |
+
if trained_betas is not None:
|
130 |
+
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
131 |
+
elif beta_schedule == "linear":
|
132 |
+
self.betas = torch.linspace(
|
133 |
+
beta_start, beta_end, num_train_timesteps, dtype=torch.float32
|
134 |
+
)
|
135 |
+
elif beta_schedule == "scaled_linear":
|
136 |
+
# this schedule is very specific to the latent diffusion model.
|
137 |
+
self.betas = (
|
138 |
+
torch.linspace(
|
139 |
+
beta_start**0.5,
|
140 |
+
beta_end**0.5,
|
141 |
+
num_train_timesteps,
|
142 |
+
dtype=torch.float32,
|
143 |
+
)
|
144 |
+
** 2
|
145 |
+
)
|
146 |
+
elif beta_schedule == "squaredcos_cap_v2":
|
147 |
+
# Glide cosine schedule
|
148 |
+
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
149 |
+
else:
|
150 |
+
raise NotImplementedError(
|
151 |
+
f"{beta_schedule} does is not implemented for {self.__class__}"
|
152 |
+
)
|
153 |
+
|
154 |
+
self.alphas = 1.0 - self.betas
|
155 |
+
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
156 |
+
|
157 |
+
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
158 |
+
sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
|
159 |
+
self.sigmas = torch.from_numpy(sigmas)
|
160 |
+
|
161 |
+
# standard deviation of the initial noise distribution
|
162 |
+
self.init_noise_sigma = self.sigmas.max()
|
163 |
+
|
164 |
+
# setable values
|
165 |
+
self.num_inference_steps = None
|
166 |
+
timesteps = np.linspace(
|
167 |
+
0, num_train_timesteps - 1, num_train_timesteps, dtype=float
|
168 |
+
)[::-1].copy()
|
169 |
+
self.timesteps = torch.from_numpy(timesteps)
|
170 |
+
self.is_scale_input_called = False
|
171 |
+
|
172 |
+
def scale_model_input(
|
173 |
+
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
|
174 |
+
) -> torch.FloatTensor:
|
175 |
+
"""
|
176 |
+
Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
|
177 |
+
|
178 |
+
Args:
|
179 |
+
sample (`torch.FloatTensor`): input sample
|
180 |
+
timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain
|
181 |
+
|
182 |
+
Returns:
|
183 |
+
`torch.FloatTensor`: scaled input sample
|
184 |
+
"""
|
185 |
+
if isinstance(timestep, torch.Tensor):
|
186 |
+
timestep = timestep.to(self.timesteps.device)
|
187 |
+
step_index = (self.timesteps == timestep).nonzero().item()
|
188 |
+
sigma = self.sigmas[step_index]
|
189 |
+
sample = sample / ((sigma**2 + 1) ** 0.5)
|
190 |
+
self.is_scale_input_called = True
|
191 |
+
return sample
|
192 |
+
|
193 |
+
def set_timesteps(
|
194 |
+
self, num_inference_steps: int, device: Union[str, torch.device] = None
|
195 |
+
):
|
196 |
+
"""
|
197 |
+
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
|
198 |
+
|
199 |
+
Args:
|
200 |
+
num_inference_steps (`int`):
|
201 |
+
the number of diffusion steps used when generating samples with a pre-trained model.
|
202 |
+
device (`str` or `torch.device`, optional):
|
203 |
+
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
204 |
+
"""
|
205 |
+
self.num_inference_steps = num_inference_steps
|
206 |
+
|
207 |
+
timesteps = np.linspace(
|
208 |
+
0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float
|
209 |
+
)[::-1].copy()
|
210 |
+
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
211 |
+
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
212 |
+
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
|
213 |
+
self.sigmas = torch.from_numpy(sigmas).to(device=device)
|
214 |
+
if str(device).startswith("mps"):
|
215 |
+
# mps does not support float64
|
216 |
+
self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
|
217 |
+
else:
|
218 |
+
self.timesteps = torch.from_numpy(timesteps).to(device=device)
|
219 |
+
|
220 |
+
def step(
|
221 |
+
self,
|
222 |
+
model_output: torch.FloatTensor,
|
223 |
+
timestep: Union[float, torch.FloatTensor],
|
224 |
+
sample: torch.FloatTensor,
|
225 |
+
generator: Optional[torch.Generator] = None,
|
226 |
+
return_dict: bool = True,
|
227 |
+
w_ind_noise: float = 0.5,
|
228 |
+
noise_type: str = "random",
|
229 |
+
) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
|
230 |
+
"""
|
231 |
+
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
|
232 |
+
process from the learned model outputs (most often the predicted noise).
|
233 |
+
|
234 |
+
Args:
|
235 |
+
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
|
236 |
+
timestep (`float`): current timestep in the diffusion chain.
|
237 |
+
sample (`torch.FloatTensor`):
|
238 |
+
current instance of sample being created by diffusion process.
|
239 |
+
generator (`torch.Generator`, optional): Random number generator.
|
240 |
+
return_dict (`bool`): option for returning tuple rather than EulerAncestralDiscreteSchedulerOutput class
|
241 |
+
|
242 |
+
Returns:
|
243 |
+
[`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
|
244 |
+
[`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] if `return_dict` is True, otherwise
|
245 |
+
a `tuple`. When returning a tuple, the first element is the sample tensor.
|
246 |
+
|
247 |
+
"""
|
248 |
+
|
249 |
+
if (
|
250 |
+
isinstance(timestep, int)
|
251 |
+
or isinstance(timestep, torch.IntTensor)
|
252 |
+
or isinstance(timestep, torch.LongTensor)
|
253 |
+
):
|
254 |
+
raise ValueError(
|
255 |
+
(
|
256 |
+
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
257 |
+
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
|
258 |
+
" one of the `scheduler.timesteps` as a timestep."
|
259 |
+
),
|
260 |
+
)
|
261 |
+
|
262 |
+
if not self.is_scale_input_called:
|
263 |
+
logger.warning(
|
264 |
+
"The `scale_model_input` function should be called before `step` to ensure correct denoising. "
|
265 |
+
"See `StableDiffusionPipeline` for a usage example."
|
266 |
+
)
|
267 |
+
|
268 |
+
if isinstance(timestep, torch.Tensor):
|
269 |
+
timestep = timestep.to(self.timesteps.device)
|
270 |
+
|
271 |
+
step_index = (self.timesteps == timestep).nonzero().item()
|
272 |
+
sigma = self.sigmas[step_index]
|
273 |
+
|
274 |
+
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
275 |
+
if self.config.prediction_type == "epsilon":
|
276 |
+
pred_original_sample = sample - sigma * model_output
|
277 |
+
elif self.config.prediction_type == "v_prediction":
|
278 |
+
# * c_out + input * c_skip
|
279 |
+
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (
|
280 |
+
sample / (sigma**2 + 1)
|
281 |
+
)
|
282 |
+
elif self.config.prediction_type == "sample":
|
283 |
+
raise NotImplementedError("prediction_type not implemented yet: sample")
|
284 |
+
else:
|
285 |
+
raise ValueError(
|
286 |
+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
|
287 |
+
)
|
288 |
+
|
289 |
+
sigma_from = self.sigmas[step_index]
|
290 |
+
sigma_to = self.sigmas[step_index + 1]
|
291 |
+
sigma_up = (
|
292 |
+
sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2
|
293 |
+
) ** 0.5
|
294 |
+
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
|
295 |
+
|
296 |
+
# 2. Convert to an ODE derivative
|
297 |
+
derivative = (sample - pred_original_sample) / sigma
|
298 |
+
|
299 |
+
dt = sigma_down - sigma
|
300 |
+
|
301 |
+
prev_sample = sample + derivative * dt
|
302 |
+
|
303 |
+
device = model_output.device
|
304 |
+
if noise_type == "random":
|
305 |
+
noise = randn_tensor(
|
306 |
+
model_output.shape,
|
307 |
+
dtype=model_output.dtype,
|
308 |
+
device=device,
|
309 |
+
generator=generator,
|
310 |
+
)
|
311 |
+
elif noise_type == "video_fusion":
|
312 |
+
noise = video_fusion_noise(
|
313 |
+
model_output, w_ind_noise=w_ind_noise, generator=generator
|
314 |
+
)
|
315 |
+
|
316 |
+
prev_sample = prev_sample + noise * sigma_up
|
317 |
+
|
318 |
+
if not return_dict:
|
319 |
+
return (prev_sample,)
|
320 |
+
|
321 |
+
return EulerAncestralDiscreteSchedulerOutput(
|
322 |
+
prev_sample=prev_sample, pred_original_sample=pred_original_sample
|
323 |
+
)
|
324 |
+
|
325 |
+
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
|
326 |
+
def add_noise(
|
327 |
+
self,
|
328 |
+
original_samples: torch.FloatTensor,
|
329 |
+
noise: torch.FloatTensor,
|
330 |
+
timesteps: torch.FloatTensor,
|
331 |
+
) -> torch.FloatTensor:
|
332 |
+
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
333 |
+
sigmas = self.sigmas.to(
|
334 |
+
device=original_samples.device, dtype=original_samples.dtype
|
335 |
+
)
|
336 |
+
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
337 |
+
# mps does not support float64
|
338 |
+
schedule_timesteps = self.timesteps.to(
|
339 |
+
original_samples.device, dtype=torch.float32
|
340 |
+
)
|
341 |
+
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
|
342 |
+
else:
|
343 |
+
schedule_timesteps = self.timesteps.to(original_samples.device)
|
344 |
+
timesteps = timesteps.to(original_samples.device)
|
345 |
+
|
346 |
+
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
347 |
+
|
348 |
+
sigma = sigmas[step_indices].flatten()
|
349 |
+
while len(sigma.shape) < len(original_samples.shape):
|
350 |
+
sigma = sigma.unsqueeze(-1)
|
351 |
+
|
352 |
+
noisy_samples = original_samples + noise * sigma
|
353 |
+
return noisy_samples
|
354 |
+
|
355 |
+
def __len__(self):
|
356 |
+
return self.config.num_train_timesteps
|
musev/schedulers/scheduling_euler_discrete.py
ADDED
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
import logging
|
3 |
+
|
4 |
+
from typing import List, Optional, Tuple, Union
|
5 |
+
import numpy as np
|
6 |
+
from numpy import ndarray
|
7 |
+
import torch
|
8 |
+
from torch import Generator, FloatTensor
|
9 |
+
from diffusers.schedulers.scheduling_euler_discrete import (
|
10 |
+
EulerDiscreteScheduler as DiffusersEulerDiscreteScheduler,
|
11 |
+
EulerDiscreteSchedulerOutput,
|
12 |
+
)
|
13 |
+
from diffusers.utils.torch_utils import randn_tensor
|
14 |
+
|
15 |
+
from ..utils.noise_util import video_fusion_noise
|
16 |
+
|
17 |
+
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
18 |
+
|
19 |
+
|
20 |
+
class EulerDiscreteScheduler(DiffusersEulerDiscreteScheduler):
|
21 |
+
def __init__(
|
22 |
+
self,
|
23 |
+
num_train_timesteps: int = 1000,
|
24 |
+
beta_start: float = 0.0001,
|
25 |
+
beta_end: float = 0.02,
|
26 |
+
beta_schedule: str = "linear",
|
27 |
+
trained_betas: ndarray | List[float] | None = None,
|
28 |
+
prediction_type: str = "epsilon",
|
29 |
+
interpolation_type: str = "linear",
|
30 |
+
use_karras_sigmas: bool | None = False,
|
31 |
+
timestep_spacing: str = "linspace",
|
32 |
+
steps_offset: int = 0,
|
33 |
+
):
|
34 |
+
super().__init__(
|
35 |
+
num_train_timesteps,
|
36 |
+
beta_start,
|
37 |
+
beta_end,
|
38 |
+
beta_schedule,
|
39 |
+
trained_betas,
|
40 |
+
prediction_type,
|
41 |
+
interpolation_type,
|
42 |
+
use_karras_sigmas,
|
43 |
+
timestep_spacing,
|
44 |
+
steps_offset,
|
45 |
+
)
|
46 |
+
|
47 |
+
def step(
|
48 |
+
self,
|
49 |
+
model_output: torch.FloatTensor,
|
50 |
+
timestep: Union[float, torch.FloatTensor],
|
51 |
+
sample: torch.FloatTensor,
|
52 |
+
s_churn: float = 0.0,
|
53 |
+
s_tmin: float = 0.0,
|
54 |
+
s_tmax: float = float("inf"),
|
55 |
+
s_noise: float = 1.0,
|
56 |
+
generator: Optional[torch.Generator] = None,
|
57 |
+
return_dict: bool = True,
|
58 |
+
w_ind_noise: float = 0.5,
|
59 |
+
noise_type: str = "random",
|
60 |
+
) -> Union[EulerDiscreteSchedulerOutput, Tuple]:
|
61 |
+
"""
|
62 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
63 |
+
process from the learned model outputs (most often the predicted noise).
|
64 |
+
|
65 |
+
Args:
|
66 |
+
model_output (`torch.FloatTensor`):
|
67 |
+
The direct output from learned diffusion model.
|
68 |
+
timestep (`float`):
|
69 |
+
The current discrete timestep in the diffusion chain.
|
70 |
+
sample (`torch.FloatTensor`):
|
71 |
+
A current instance of a sample created by the diffusion process.
|
72 |
+
s_churn (`float`):
|
73 |
+
s_tmin (`float`):
|
74 |
+
s_tmax (`float`):
|
75 |
+
s_noise (`float`, defaults to 1.0):
|
76 |
+
Scaling factor for noise added to the sample.
|
77 |
+
generator (`torch.Generator`, *optional*):
|
78 |
+
A random number generator.
|
79 |
+
return_dict (`bool`):
|
80 |
+
Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
|
81 |
+
tuple.
|
82 |
+
|
83 |
+
Returns:
|
84 |
+
[`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
|
85 |
+
If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
|
86 |
+
returned, otherwise a tuple is returned where the first element is the sample tensor.
|
87 |
+
"""
|
88 |
+
|
89 |
+
if (
|
90 |
+
isinstance(timestep, int)
|
91 |
+
or isinstance(timestep, torch.IntTensor)
|
92 |
+
or isinstance(timestep, torch.LongTensor)
|
93 |
+
):
|
94 |
+
raise ValueError(
|
95 |
+
(
|
96 |
+
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
97 |
+
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
|
98 |
+
" one of the `scheduler.timesteps` as a timestep."
|
99 |
+
),
|
100 |
+
)
|
101 |
+
|
102 |
+
if not self.is_scale_input_called:
|
103 |
+
logger.warning(
|
104 |
+
"The `scale_model_input` function should be called before `step` to ensure correct denoising. "
|
105 |
+
"See `StableDiffusionPipeline` for a usage example."
|
106 |
+
)
|
107 |
+
|
108 |
+
if self.step_index is None:
|
109 |
+
self._init_step_index(timestep)
|
110 |
+
|
111 |
+
sigma = self.sigmas[self.step_index]
|
112 |
+
|
113 |
+
gamma = (
|
114 |
+
min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1)
|
115 |
+
if s_tmin <= sigma <= s_tmax
|
116 |
+
else 0.0
|
117 |
+
)
|
118 |
+
device = model_output.device
|
119 |
+
|
120 |
+
if noise_type == "random":
|
121 |
+
noise = randn_tensor(
|
122 |
+
model_output.shape,
|
123 |
+
dtype=model_output.dtype,
|
124 |
+
device=device,
|
125 |
+
generator=generator,
|
126 |
+
)
|
127 |
+
elif noise_type == "video_fusion":
|
128 |
+
noise = video_fusion_noise(
|
129 |
+
model_output, w_ind_noise=w_ind_noise, generator=generator
|
130 |
+
)
|
131 |
+
|
132 |
+
eps = noise * s_noise
|
133 |
+
sigma_hat = sigma * (gamma + 1)
|
134 |
+
|
135 |
+
if gamma > 0:
|
136 |
+
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
|
137 |
+
|
138 |
+
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
139 |
+
# NOTE: "original_sample" should not be an expected prediction_type but is left in for
|
140 |
+
# backwards compatibility
|
141 |
+
if (
|
142 |
+
self.config.prediction_type == "original_sample"
|
143 |
+
or self.config.prediction_type == "sample"
|
144 |
+
):
|
145 |
+
pred_original_sample = model_output
|
146 |
+
elif self.config.prediction_type == "epsilon":
|
147 |
+
pred_original_sample = sample - sigma_hat * model_output
|
148 |
+
elif self.config.prediction_type == "v_prediction":
|
149 |
+
# * c_out + input * c_skip
|
150 |
+
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (
|
151 |
+
sample / (sigma**2 + 1)
|
152 |
+
)
|
153 |
+
else:
|
154 |
+
raise ValueError(
|
155 |
+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
|
156 |
+
)
|
157 |
+
|
158 |
+
# 2. Convert to an ODE derivative
|
159 |
+
derivative = (sample - pred_original_sample) / sigma_hat
|
160 |
+
|
161 |
+
dt = self.sigmas[self.step_index + 1] - sigma_hat
|
162 |
+
|
163 |
+
prev_sample = sample + derivative * dt
|
164 |
+
|
165 |
+
# upon completion increase step index by one
|
166 |
+
self._step_index += 1
|
167 |
+
|
168 |
+
if not return_dict:
|
169 |
+
return (prev_sample,)
|
170 |
+
|
171 |
+
return EulerDiscreteSchedulerOutput(
|
172 |
+
prev_sample=prev_sample, pred_original_sample=pred_original_sample
|
173 |
+
)
|
174 |
+
|
175 |
+
def step_bk(
|
176 |
+
self,
|
177 |
+
model_output: FloatTensor,
|
178 |
+
timestep: float | FloatTensor,
|
179 |
+
sample: FloatTensor,
|
180 |
+
s_churn: float = 0,
|
181 |
+
s_tmin: float = 0,
|
182 |
+
s_tmax: float = float("inf"),
|
183 |
+
s_noise: float = 1,
|
184 |
+
generator: Generator | None = None,
|
185 |
+
return_dict: bool = True,
|
186 |
+
w_ind_noise: float = 0.5,
|
187 |
+
noise_type: str = "random",
|
188 |
+
) -> EulerDiscreteSchedulerOutput | Tuple:
|
189 |
+
"""
|
190 |
+
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
|
191 |
+
process from the learned model outputs (most often the predicted noise).
|
192 |
+
|
193 |
+
Args:
|
194 |
+
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
|
195 |
+
timestep (`float`): current timestep in the diffusion chain.
|
196 |
+
sample (`torch.FloatTensor`):
|
197 |
+
current instance of sample being created by diffusion process.
|
198 |
+
s_churn (`float`)
|
199 |
+
s_tmin (`float`)
|
200 |
+
s_tmax (`float`)
|
201 |
+
s_noise (`float`)
|
202 |
+
generator (`torch.Generator`, optional): Random number generator.
|
203 |
+
return_dict (`bool`): option for returning tuple rather than EulerDiscreteSchedulerOutput class
|
204 |
+
|
205 |
+
Returns:
|
206 |
+
[`~schedulers.scheduling_utils.EulerDiscreteSchedulerOutput`] or `tuple`:
|
207 |
+
[`~schedulers.scheduling_utils.EulerDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a
|
208 |
+
`tuple`. When returning a tuple, the first element is the sample tensor.
|
209 |
+
|
210 |
+
"""
|
211 |
+
|
212 |
+
if (
|
213 |
+
isinstance(timestep, int)
|
214 |
+
or isinstance(timestep, torch.IntTensor)
|
215 |
+
or isinstance(timestep, torch.LongTensor)
|
216 |
+
):
|
217 |
+
raise ValueError(
|
218 |
+
(
|
219 |
+
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
220 |
+
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
|
221 |
+
" one of the `scheduler.timesteps` as a timestep."
|
222 |
+
),
|
223 |
+
)
|
224 |
+
|
225 |
+
if not self.is_scale_input_called:
|
226 |
+
logger.warning(
|
227 |
+
"The `scale_model_input` function should be called before `step` to ensure correct denoising. "
|
228 |
+
"See `StableDiffusionPipeline` for a usage example."
|
229 |
+
)
|
230 |
+
|
231 |
+
if isinstance(timestep, torch.Tensor):
|
232 |
+
timestep = timestep.to(self.timesteps.device)
|
233 |
+
|
234 |
+
step_index = (self.timesteps == timestep).nonzero().item()
|
235 |
+
sigma = self.sigmas[step_index]
|
236 |
+
|
237 |
+
gamma = (
|
238 |
+
min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1)
|
239 |
+
if s_tmin <= sigma <= s_tmax
|
240 |
+
else 0.0
|
241 |
+
)
|
242 |
+
|
243 |
+
device = model_output.device
|
244 |
+
if noise_type == "random":
|
245 |
+
noise = randn_tensor(
|
246 |
+
model_output.shape,
|
247 |
+
dtype=model_output.dtype,
|
248 |
+
device=device,
|
249 |
+
generator=generator,
|
250 |
+
)
|
251 |
+
elif noise_type == "video_fusion":
|
252 |
+
noise = video_fusion_noise(
|
253 |
+
model_output, w_ind_noise=w_ind_noise, generator=generator
|
254 |
+
)
|
255 |
+
eps = noise * s_noise
|
256 |
+
sigma_hat = sigma * (gamma + 1)
|
257 |
+
|
258 |
+
if gamma > 0:
|
259 |
+
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
|
260 |
+
|
261 |
+
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
262 |
+
# NOTE: "original_sample" should not be an expected prediction_type but is left in for
|
263 |
+
# backwards compatibility
|
264 |
+
if (
|
265 |
+
self.config.prediction_type == "original_sample"
|
266 |
+
or self.config.prediction_type == "sample"
|
267 |
+
):
|
268 |
+
pred_original_sample = model_output
|
269 |
+
elif self.config.prediction_type == "epsilon":
|
270 |
+
pred_original_sample = sample - sigma_hat * model_output
|
271 |
+
elif self.config.prediction_type == "v_prediction":
|
272 |
+
# * c_out + input * c_skip
|
273 |
+
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (
|
274 |
+
sample / (sigma**2 + 1)
|
275 |
+
)
|
276 |
+
else:
|
277 |
+
raise ValueError(
|
278 |
+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
|
279 |
+
)
|
280 |
+
|
281 |
+
# 2. Convert to an ODE derivative
|
282 |
+
derivative = (sample - pred_original_sample) / sigma_hat
|
283 |
+
|
284 |
+
dt = self.sigmas[step_index + 1] - sigma_hat
|
285 |
+
|
286 |
+
prev_sample = sample + derivative * dt
|
287 |
+
|
288 |
+
if not return_dict:
|
289 |
+
return (prev_sample,)
|
290 |
+
|
291 |
+
return EulerDiscreteSchedulerOutput(
|
292 |
+
prev_sample=prev_sample, pred_original_sample=pred_original_sample
|
293 |
+
)
|
musev/schedulers/scheduling_lcm.py
ADDED
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
|
16 |
+
# and https://github.com/hojonathanho/diffusion
|
17 |
+
from __future__ import annotations
|
18 |
+
|
19 |
+
import math
|
20 |
+
from dataclasses import dataclass
|
21 |
+
from typing import List, Optional, Tuple, Union
|
22 |
+
|
23 |
+
import numpy as np
|
24 |
+
import torch
|
25 |
+
from numpy import ndarray
|
26 |
+
|
27 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
28 |
+
from diffusers.utils import BaseOutput, logging
|
29 |
+
from diffusers.utils.torch_utils import randn_tensor
|
30 |
+
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
31 |
+
from diffusers.schedulers.scheduling_lcm import (
|
32 |
+
LCMSchedulerOutput,
|
33 |
+
betas_for_alpha_bar,
|
34 |
+
rescale_zero_terminal_snr,
|
35 |
+
LCMScheduler as DiffusersLCMScheduler,
|
36 |
+
)
|
37 |
+
from ..utils.noise_util import video_fusion_noise
|
38 |
+
|
39 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
40 |
+
|
41 |
+
|
42 |
+
class LCMScheduler(DiffusersLCMScheduler):
|
43 |
+
def __init__(
|
44 |
+
self,
|
45 |
+
num_train_timesteps: int = 1000,
|
46 |
+
beta_start: float = 0.00085,
|
47 |
+
beta_end: float = 0.012,
|
48 |
+
beta_schedule: str = "scaled_linear",
|
49 |
+
trained_betas: ndarray | List[float] | None = None,
|
50 |
+
original_inference_steps: int = 50,
|
51 |
+
clip_sample: bool = False,
|
52 |
+
clip_sample_range: float = 1,
|
53 |
+
set_alpha_to_one: bool = True,
|
54 |
+
steps_offset: int = 0,
|
55 |
+
prediction_type: str = "epsilon",
|
56 |
+
thresholding: bool = False,
|
57 |
+
dynamic_thresholding_ratio: float = 0.995,
|
58 |
+
sample_max_value: float = 1,
|
59 |
+
timestep_spacing: str = "leading",
|
60 |
+
timestep_scaling: float = 10,
|
61 |
+
rescale_betas_zero_snr: bool = False,
|
62 |
+
):
|
63 |
+
super().__init__(
|
64 |
+
num_train_timesteps,
|
65 |
+
beta_start,
|
66 |
+
beta_end,
|
67 |
+
beta_schedule,
|
68 |
+
trained_betas,
|
69 |
+
original_inference_steps,
|
70 |
+
clip_sample,
|
71 |
+
clip_sample_range,
|
72 |
+
set_alpha_to_one,
|
73 |
+
steps_offset,
|
74 |
+
prediction_type,
|
75 |
+
thresholding,
|
76 |
+
dynamic_thresholding_ratio,
|
77 |
+
sample_max_value,
|
78 |
+
timestep_spacing,
|
79 |
+
timestep_scaling,
|
80 |
+
rescale_betas_zero_snr,
|
81 |
+
)
|
82 |
+
|
83 |
+
def step(
|
84 |
+
self,
|
85 |
+
model_output: torch.FloatTensor,
|
86 |
+
timestep: int,
|
87 |
+
sample: torch.FloatTensor,
|
88 |
+
generator: Optional[torch.Generator] = None,
|
89 |
+
return_dict: bool = True,
|
90 |
+
w_ind_noise: float = 0.5,
|
91 |
+
noise_type: str = "random",
|
92 |
+
) -> Union[LCMSchedulerOutput, Tuple]:
|
93 |
+
"""
|
94 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
95 |
+
process from the learned model outputs (most often the predicted noise).
|
96 |
+
|
97 |
+
Args:
|
98 |
+
model_output (`torch.FloatTensor`):
|
99 |
+
The direct output from learned diffusion model.
|
100 |
+
timestep (`float`):
|
101 |
+
The current discrete timestep in the diffusion chain.
|
102 |
+
sample (`torch.FloatTensor`):
|
103 |
+
A current instance of a sample created by the diffusion process.
|
104 |
+
generator (`torch.Generator`, *optional*):
|
105 |
+
A random number generator.
|
106 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
107 |
+
Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
|
108 |
+
Returns:
|
109 |
+
[`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
|
110 |
+
If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
|
111 |
+
tuple is returned where the first element is the sample tensor.
|
112 |
+
"""
|
113 |
+
if self.num_inference_steps is None:
|
114 |
+
raise ValueError(
|
115 |
+
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
116 |
+
)
|
117 |
+
|
118 |
+
if self.step_index is None:
|
119 |
+
self._init_step_index(timestep)
|
120 |
+
|
121 |
+
# 1. get previous step value
|
122 |
+
prev_step_index = self.step_index + 1
|
123 |
+
if prev_step_index < len(self.timesteps):
|
124 |
+
prev_timestep = self.timesteps[prev_step_index]
|
125 |
+
else:
|
126 |
+
prev_timestep = timestep
|
127 |
+
|
128 |
+
# 2. compute alphas, betas
|
129 |
+
alpha_prod_t = self.alphas_cumprod[timestep]
|
130 |
+
alpha_prod_t_prev = (
|
131 |
+
self.alphas_cumprod[prev_timestep]
|
132 |
+
if prev_timestep >= 0
|
133 |
+
else self.final_alpha_cumprod
|
134 |
+
)
|
135 |
+
|
136 |
+
beta_prod_t = 1 - alpha_prod_t
|
137 |
+
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
138 |
+
|
139 |
+
# 3. Get scalings for boundary conditions
|
140 |
+
c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)
|
141 |
+
|
142 |
+
# 4. Compute the predicted original sample x_0 based on the model parameterization
|
143 |
+
if self.config.prediction_type == "epsilon": # noise-prediction
|
144 |
+
predicted_original_sample = (
|
145 |
+
sample - beta_prod_t.sqrt() * model_output
|
146 |
+
) / alpha_prod_t.sqrt()
|
147 |
+
elif self.config.prediction_type == "sample": # x-prediction
|
148 |
+
predicted_original_sample = model_output
|
149 |
+
elif self.config.prediction_type == "v_prediction": # v-prediction
|
150 |
+
predicted_original_sample = (
|
151 |
+
alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
|
152 |
+
)
|
153 |
+
else:
|
154 |
+
raise ValueError(
|
155 |
+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
|
156 |
+
" `v_prediction` for `LCMScheduler`."
|
157 |
+
)
|
158 |
+
|
159 |
+
# 5. Clip or threshold "predicted x_0"
|
160 |
+
if self.config.thresholding:
|
161 |
+
predicted_original_sample = self._threshold_sample(
|
162 |
+
predicted_original_sample
|
163 |
+
)
|
164 |
+
elif self.config.clip_sample:
|
165 |
+
predicted_original_sample = predicted_original_sample.clamp(
|
166 |
+
-self.config.clip_sample_range, self.config.clip_sample_range
|
167 |
+
)
|
168 |
+
|
169 |
+
# 6. Denoise model output using boundary conditions
|
170 |
+
denoised = c_out * predicted_original_sample + c_skip * sample
|
171 |
+
|
172 |
+
# 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference
|
173 |
+
# Noise is not used on the final timestep of the timestep schedule.
|
174 |
+
# This also means that noise is not used for one-step sampling.
|
175 |
+
device = model_output.device
|
176 |
+
|
177 |
+
if self.step_index != self.num_inference_steps - 1:
|
178 |
+
if noise_type == "random":
|
179 |
+
noise = randn_tensor(
|
180 |
+
model_output.shape,
|
181 |
+
dtype=model_output.dtype,
|
182 |
+
device=device,
|
183 |
+
generator=generator,
|
184 |
+
)
|
185 |
+
elif noise_type == "video_fusion":
|
186 |
+
noise = video_fusion_noise(
|
187 |
+
model_output, w_ind_noise=w_ind_noise, generator=generator
|
188 |
+
)
|
189 |
+
prev_sample = (
|
190 |
+
alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
|
191 |
+
)
|
192 |
+
else:
|
193 |
+
prev_sample = denoised
|
194 |
+
|
195 |
+
# upon completion increase step index by one
|
196 |
+
self._step_index += 1
|
197 |
+
|
198 |
+
if not return_dict:
|
199 |
+
return (prev_sample, denoised)
|
200 |
+
|
201 |
+
return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)
|
202 |
+
|
203 |
+
def step_bk(
|
204 |
+
self,
|
205 |
+
model_output: torch.FloatTensor,
|
206 |
+
timestep: int,
|
207 |
+
sample: torch.FloatTensor,
|
208 |
+
generator: Optional[torch.Generator] = None,
|
209 |
+
return_dict: bool = True,
|
210 |
+
) -> Union[LCMSchedulerOutput, Tuple]:
|
211 |
+
"""
|
212 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
213 |
+
process from the learned model outputs (most often the predicted noise).
|
214 |
+
|
215 |
+
Args:
|
216 |
+
model_output (`torch.FloatTensor`):
|
217 |
+
The direct output from learned diffusion model.
|
218 |
+
timestep (`float`):
|
219 |
+
The current discrete timestep in the diffusion chain.
|
220 |
+
sample (`torch.FloatTensor`):
|
221 |
+
A current instance of a sample created by the diffusion process.
|
222 |
+
generator (`torch.Generator`, *optional*):
|
223 |
+
A random number generator.
|
224 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
225 |
+
Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
|
226 |
+
Returns:
|
227 |
+
[`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
|
228 |
+
If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
|
229 |
+
tuple is returned where the first element is the sample tensor.
|
230 |
+
"""
|
231 |
+
if self.num_inference_steps is None:
|
232 |
+
raise ValueError(
|
233 |
+
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
234 |
+
)
|
235 |
+
|
236 |
+
if self.step_index is None:
|
237 |
+
self._init_step_index(timestep)
|
238 |
+
|
239 |
+
# 1. get previous step value
|
240 |
+
prev_step_index = self.step_index + 1
|
241 |
+
if prev_step_index < len(self.timesteps):
|
242 |
+
prev_timestep = self.timesteps[prev_step_index]
|
243 |
+
else:
|
244 |
+
prev_timestep = timestep
|
245 |
+
|
246 |
+
# 2. compute alphas, betas
|
247 |
+
alpha_prod_t = self.alphas_cumprod[timestep]
|
248 |
+
alpha_prod_t_prev = (
|
249 |
+
self.alphas_cumprod[prev_timestep]
|
250 |
+
if prev_timestep >= 0
|
251 |
+
else self.final_alpha_cumprod
|
252 |
+
)
|
253 |
+
|
254 |
+
beta_prod_t = 1 - alpha_prod_t
|
255 |
+
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
256 |
+
|
257 |
+
# 3. Get scalings for boundary conditions
|
258 |
+
c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)
|
259 |
+
|
260 |
+
# 4. Compute the predicted original sample x_0 based on the model parameterization
|
261 |
+
if self.config.prediction_type == "epsilon": # noise-prediction
|
262 |
+
predicted_original_sample = (
|
263 |
+
sample - beta_prod_t.sqrt() * model_output
|
264 |
+
) / alpha_prod_t.sqrt()
|
265 |
+
elif self.config.prediction_type == "sample": # x-prediction
|
266 |
+
predicted_original_sample = model_output
|
267 |
+
elif self.config.prediction_type == "v_prediction": # v-prediction
|
268 |
+
predicted_original_sample = (
|
269 |
+
alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
|
270 |
+
)
|
271 |
+
else:
|
272 |
+
raise ValueError(
|
273 |
+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
|
274 |
+
" `v_prediction` for `LCMScheduler`."
|
275 |
+
)
|
276 |
+
|
277 |
+
# 5. Clip or threshold "predicted x_0"
|
278 |
+
if self.config.thresholding:
|
279 |
+
predicted_original_sample = self._threshold_sample(
|
280 |
+
predicted_original_sample
|
281 |
+
)
|
282 |
+
elif self.config.clip_sample:
|
283 |
+
predicted_original_sample = predicted_original_sample.clamp(
|
284 |
+
-self.config.clip_sample_range, self.config.clip_sample_range
|
285 |
+
)
|
286 |
+
|
287 |
+
# 6. Denoise model output using boundary conditions
|
288 |
+
denoised = c_out * predicted_original_sample + c_skip * sample
|
289 |
+
|
290 |
+
# 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference
|
291 |
+
# Noise is not used on the final timestep of the timestep schedule.
|
292 |
+
# This also means that noise is not used for one-step sampling.
|
293 |
+
if self.step_index != self.num_inference_steps - 1:
|
294 |
+
noise = randn_tensor(
|
295 |
+
model_output.shape,
|
296 |
+
generator=generator,
|
297 |
+
device=model_output.device,
|
298 |
+
dtype=denoised.dtype,
|
299 |
+
)
|
300 |
+
prev_sample = (
|
301 |
+
alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
|
302 |
+
)
|
303 |
+
else:
|
304 |
+
prev_sample = denoised
|
305 |
+
|
306 |
+
# upon completion increase step index by one
|
307 |
+
self._step_index += 1
|
308 |
+
|
309 |
+
if not return_dict:
|
310 |
+
return (prev_sample, denoised)
|
311 |
+
|
312 |
+
return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)
|
musev/utils/__init__.py
ADDED
File without changes
|
musev/utils/attention_util.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Tuple, Union, Literal
|
2 |
+
|
3 |
+
from einops import repeat
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
|
8 |
+
def get_diags_indices(
|
9 |
+
shape: Union[int, Tuple[int, int]], k_min: int = 0, k_max: int = 0
|
10 |
+
):
|
11 |
+
if isinstance(shape, int):
|
12 |
+
shape = (shape, shape)
|
13 |
+
rows, cols = np.indices(shape)
|
14 |
+
diag = cols - rows
|
15 |
+
return np.where((diag >= k_min) & (diag <= k_max))
|
16 |
+
|
17 |
+
|
18 |
+
def generate_mask_from_indices(
|
19 |
+
shape: Tuple[int, int],
|
20 |
+
indices: Tuple[np.ndarray, np.ndarray],
|
21 |
+
big_value: float = 0,
|
22 |
+
small_value: float = -1e9,
|
23 |
+
):
|
24 |
+
matrix = np.ones(shape) * small_value
|
25 |
+
matrix[indices] = big_value
|
26 |
+
return matrix
|
27 |
+
|
28 |
+
|
29 |
+
def generate_sparse_causcal_attn_mask(
|
30 |
+
batch_size: int,
|
31 |
+
n: int,
|
32 |
+
n_near: int = 1,
|
33 |
+
big_value: float = 0,
|
34 |
+
small_value: float = -1e9,
|
35 |
+
out_type: Literal["torch", "numpy"] = "numpy",
|
36 |
+
expand: int = 1,
|
37 |
+
) -> np.ndarray:
|
38 |
+
"""generate b (n expand) (n expand) mask,
|
39 |
+
where value of diag (0<=<=n_near) and first column of shape mat (n n) is set as big_value, others as small value
|
40 |
+
expand的概念:
|
41 |
+
attn 是 b n d 时,mask 是 b n n, 当 attn 是 b (expand n) d 时, mask 是 b (n expand) (n expand)
|
42 |
+
Args:
|
43 |
+
batch_size (int): _description_
|
44 |
+
n (int): _description_
|
45 |
+
n_near (int, optional): _description_. Defaults to 1.
|
46 |
+
big_value (float, optional): _description_. Defaults to 0.
|
47 |
+
small_value (float, optional): _description_. Defaults to -1e9.
|
48 |
+
out_type (Literal["torch", "numpy"], optional): _description_. Defaults to "numpy".
|
49 |
+
expand (int, optional): _description_. Defaults to 1.
|
50 |
+
|
51 |
+
Returns:
|
52 |
+
np.ndarray: _description_
|
53 |
+
"""
|
54 |
+
shape = (n, n)
|
55 |
+
diag_indices = get_diags_indices(n, k_min=-n_near, k_max=0)
|
56 |
+
first_column = (np.arange(n), np.zeros(n).astype(np.int))
|
57 |
+
indices = (
|
58 |
+
np.concatenate([diag_indices[0], first_column[0]]),
|
59 |
+
np.concatenate([diag_indices[1], first_column[1]]),
|
60 |
+
)
|
61 |
+
mask = generate_mask_from_indices(
|
62 |
+
shape=shape, indices=indices, big_value=big_value, small_value=small_value
|
63 |
+
)
|
64 |
+
mask = repeat(mask, "m n-> b m n", b=batch_size)
|
65 |
+
if expand > 1:
|
66 |
+
mask = repeat(
|
67 |
+
mask,
|
68 |
+
"b m n -> b (m d1) (n d2)",
|
69 |
+
d1=expand,
|
70 |
+
d2=expand,
|
71 |
+
)
|
72 |
+
if out_type == "torch":
|
73 |
+
mask = torch.from_numpy(mask)
|
74 |
+
return mask
|
musev/utils/convert_from_ckpt.py
ADDED
@@ -0,0 +1,963 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" Conversion script for the Stable Diffusion checkpoints."""
|
16 |
+
|
17 |
+
import re
|
18 |
+
from io import BytesIO
|
19 |
+
from typing import Optional
|
20 |
+
|
21 |
+
import requests
|
22 |
+
import torch
|
23 |
+
from transformers import (
|
24 |
+
AutoFeatureExtractor,
|
25 |
+
BertTokenizerFast,
|
26 |
+
CLIPImageProcessor,
|
27 |
+
CLIPTextModel,
|
28 |
+
CLIPTextModelWithProjection,
|
29 |
+
CLIPTokenizer,
|
30 |
+
CLIPVisionConfig,
|
31 |
+
CLIPVisionModelWithProjection,
|
32 |
+
)
|
33 |
+
|
34 |
+
from diffusers.models import (
|
35 |
+
AutoencoderKL,
|
36 |
+
PriorTransformer,
|
37 |
+
UNet2DConditionModel,
|
38 |
+
)
|
39 |
+
from diffusers.schedulers import (
|
40 |
+
DDIMScheduler,
|
41 |
+
DDPMScheduler,
|
42 |
+
DPMSolverMultistepScheduler,
|
43 |
+
EulerAncestralDiscreteScheduler,
|
44 |
+
EulerDiscreteScheduler,
|
45 |
+
HeunDiscreteScheduler,
|
46 |
+
LMSDiscreteScheduler,
|
47 |
+
PNDMScheduler,
|
48 |
+
UnCLIPScheduler,
|
49 |
+
)
|
50 |
+
from diffusers.utils.import_utils import BACKENDS_MAPPING
|
51 |
+
|
52 |
+
|
53 |
+
def shave_segments(path, n_shave_prefix_segments=1):
|
54 |
+
"""
|
55 |
+
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
56 |
+
"""
|
57 |
+
if n_shave_prefix_segments >= 0:
|
58 |
+
return ".".join(path.split(".")[n_shave_prefix_segments:])
|
59 |
+
else:
|
60 |
+
return ".".join(path.split(".")[:n_shave_prefix_segments])
|
61 |
+
|
62 |
+
|
63 |
+
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
|
64 |
+
"""
|
65 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
66 |
+
"""
|
67 |
+
mapping = []
|
68 |
+
for old_item in old_list:
|
69 |
+
new_item = old_item.replace("in_layers.0", "norm1")
|
70 |
+
new_item = new_item.replace("in_layers.2", "conv1")
|
71 |
+
|
72 |
+
new_item = new_item.replace("out_layers.0", "norm2")
|
73 |
+
new_item = new_item.replace("out_layers.3", "conv2")
|
74 |
+
|
75 |
+
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
|
76 |
+
new_item = new_item.replace("skip_connection", "conv_shortcut")
|
77 |
+
|
78 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
79 |
+
|
80 |
+
mapping.append({"old": old_item, "new": new_item})
|
81 |
+
|
82 |
+
return mapping
|
83 |
+
|
84 |
+
|
85 |
+
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
|
86 |
+
"""
|
87 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
88 |
+
"""
|
89 |
+
mapping = []
|
90 |
+
for old_item in old_list:
|
91 |
+
new_item = old_item
|
92 |
+
|
93 |
+
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
|
94 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
95 |
+
|
96 |
+
mapping.append({"old": old_item, "new": new_item})
|
97 |
+
|
98 |
+
return mapping
|
99 |
+
|
100 |
+
|
101 |
+
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
|
102 |
+
"""
|
103 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
104 |
+
"""
|
105 |
+
mapping = []
|
106 |
+
for old_item in old_list:
|
107 |
+
new_item = old_item
|
108 |
+
|
109 |
+
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
|
110 |
+
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
|
111 |
+
|
112 |
+
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
|
113 |
+
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
|
114 |
+
|
115 |
+
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
116 |
+
|
117 |
+
mapping.append({"old": old_item, "new": new_item})
|
118 |
+
|
119 |
+
return mapping
|
120 |
+
|
121 |
+
|
122 |
+
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
123 |
+
"""
|
124 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
125 |
+
"""
|
126 |
+
mapping = []
|
127 |
+
for old_item in old_list:
|
128 |
+
new_item = old_item
|
129 |
+
|
130 |
+
new_item = new_item.replace("norm.weight", "group_norm.weight")
|
131 |
+
new_item = new_item.replace("norm.bias", "group_norm.bias")
|
132 |
+
|
133 |
+
new_item = new_item.replace("q.weight", "to_q.weight")
|
134 |
+
new_item = new_item.replace("q.bias", "to_q.bias")
|
135 |
+
|
136 |
+
new_item = new_item.replace("k.weight", "to_k.weight")
|
137 |
+
new_item = new_item.replace("k.bias", "to_k.bias")
|
138 |
+
|
139 |
+
new_item = new_item.replace("v.weight", "to_v.weight")
|
140 |
+
new_item = new_item.replace("v.bias", "to_v.bias")
|
141 |
+
|
142 |
+
new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
|
143 |
+
new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
|
144 |
+
|
145 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
146 |
+
|
147 |
+
mapping.append({"old": old_item, "new": new_item})
|
148 |
+
|
149 |
+
return mapping
|
150 |
+
|
151 |
+
|
152 |
+
def assign_to_checkpoint(
|
153 |
+
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
|
154 |
+
):
|
155 |
+
"""
|
156 |
+
This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
|
157 |
+
attention layers, and takes into account additional replacements that may arise.
|
158 |
+
|
159 |
+
Assigns the weights to the new checkpoint.
|
160 |
+
"""
|
161 |
+
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
162 |
+
|
163 |
+
# Splits the attention layers into three variables.
|
164 |
+
if attention_paths_to_split is not None:
|
165 |
+
for path, path_map in attention_paths_to_split.items():
|
166 |
+
old_tensor = old_checkpoint[path]
|
167 |
+
channels = old_tensor.shape[0] // 3
|
168 |
+
|
169 |
+
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
170 |
+
|
171 |
+
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
172 |
+
|
173 |
+
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
|
174 |
+
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
175 |
+
|
176 |
+
checkpoint[path_map["query"]] = query.reshape(target_shape)
|
177 |
+
checkpoint[path_map["key"]] = key.reshape(target_shape)
|
178 |
+
checkpoint[path_map["value"]] = value.reshape(target_shape)
|
179 |
+
|
180 |
+
for path in paths:
|
181 |
+
new_path = path["new"]
|
182 |
+
|
183 |
+
# These have already been assigned
|
184 |
+
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
|
185 |
+
continue
|
186 |
+
|
187 |
+
# Global renaming happens here
|
188 |
+
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
|
189 |
+
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
|
190 |
+
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
|
191 |
+
|
192 |
+
if additional_replacements is not None:
|
193 |
+
for replacement in additional_replacements:
|
194 |
+
new_path = new_path.replace(replacement["old"], replacement["new"])
|
195 |
+
|
196 |
+
# proj_attn.weight has to be converted from conv 1D to linear
|
197 |
+
is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path)
|
198 |
+
shape = old_checkpoint[path["old"]].shape
|
199 |
+
if is_attn_weight and len(shape) == 3:
|
200 |
+
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
|
201 |
+
elif is_attn_weight and len(shape) == 4:
|
202 |
+
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
|
203 |
+
else:
|
204 |
+
checkpoint[new_path] = old_checkpoint[path["old"]]
|
205 |
+
|
206 |
+
|
207 |
+
def conv_attn_to_linear(checkpoint):
|
208 |
+
keys = list(checkpoint.keys())
|
209 |
+
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
210 |
+
for key in keys:
|
211 |
+
if ".".join(key.split(".")[-2:]) in attn_keys:
|
212 |
+
if checkpoint[key].ndim > 2:
|
213 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
214 |
+
elif "proj_attn.weight" in key:
|
215 |
+
if checkpoint[key].ndim > 2:
|
216 |
+
checkpoint[key] = checkpoint[key][:, :, 0]
|
217 |
+
|
218 |
+
|
219 |
+
def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
|
220 |
+
"""
|
221 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
222 |
+
"""
|
223 |
+
if controlnet:
|
224 |
+
unet_params = original_config.model.params.control_stage_config.params
|
225 |
+
else:
|
226 |
+
unet_params = original_config.model.params.unet_config.params
|
227 |
+
|
228 |
+
vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
229 |
+
|
230 |
+
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
|
231 |
+
|
232 |
+
down_block_types = []
|
233 |
+
resolution = 1
|
234 |
+
for i in range(len(block_out_channels)):
|
235 |
+
block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
|
236 |
+
down_block_types.append(block_type)
|
237 |
+
if i != len(block_out_channels) - 1:
|
238 |
+
resolution *= 2
|
239 |
+
|
240 |
+
up_block_types = []
|
241 |
+
for i in range(len(block_out_channels)):
|
242 |
+
block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
|
243 |
+
up_block_types.append(block_type)
|
244 |
+
resolution //= 2
|
245 |
+
|
246 |
+
vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
|
247 |
+
|
248 |
+
head_dim = unet_params.num_heads if "num_heads" in unet_params else None
|
249 |
+
use_linear_projection = (
|
250 |
+
unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
|
251 |
+
)
|
252 |
+
if use_linear_projection:
|
253 |
+
# stable diffusion 2-base-512 and 2-768
|
254 |
+
if head_dim is None:
|
255 |
+
head_dim = [5, 10, 20, 20]
|
256 |
+
|
257 |
+
class_embed_type = None
|
258 |
+
projection_class_embeddings_input_dim = None
|
259 |
+
|
260 |
+
if "num_classes" in unet_params:
|
261 |
+
if unet_params.num_classes == "sequential":
|
262 |
+
class_embed_type = "projection"
|
263 |
+
assert "adm_in_channels" in unet_params
|
264 |
+
projection_class_embeddings_input_dim = unet_params.adm_in_channels
|
265 |
+
else:
|
266 |
+
raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}")
|
267 |
+
|
268 |
+
config = {
|
269 |
+
"sample_size": image_size // vae_scale_factor,
|
270 |
+
"in_channels": unet_params.in_channels,
|
271 |
+
"down_block_types": tuple(down_block_types),
|
272 |
+
"block_out_channels": tuple(block_out_channels),
|
273 |
+
"layers_per_block": unet_params.num_res_blocks,
|
274 |
+
"cross_attention_dim": unet_params.context_dim,
|
275 |
+
"attention_head_dim": head_dim,
|
276 |
+
"use_linear_projection": use_linear_projection,
|
277 |
+
"class_embed_type": class_embed_type,
|
278 |
+
"projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
|
279 |
+
}
|
280 |
+
|
281 |
+
if not controlnet:
|
282 |
+
config["out_channels"] = unet_params.out_channels
|
283 |
+
config["up_block_types"] = tuple(up_block_types)
|
284 |
+
|
285 |
+
return config
|
286 |
+
|
287 |
+
|
288 |
+
def create_vae_diffusers_config(original_config, image_size: int):
|
289 |
+
"""
|
290 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
291 |
+
"""
|
292 |
+
vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
293 |
+
_ = original_config.model.params.first_stage_config.params.embed_dim
|
294 |
+
|
295 |
+
block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
|
296 |
+
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
|
297 |
+
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
|
298 |
+
|
299 |
+
config = {
|
300 |
+
"sample_size": image_size,
|
301 |
+
"in_channels": vae_params.in_channels,
|
302 |
+
"out_channels": vae_params.out_ch,
|
303 |
+
"down_block_types": tuple(down_block_types),
|
304 |
+
"up_block_types": tuple(up_block_types),
|
305 |
+
"block_out_channels": tuple(block_out_channels),
|
306 |
+
"latent_channels": vae_params.z_channels,
|
307 |
+
"layers_per_block": vae_params.num_res_blocks,
|
308 |
+
}
|
309 |
+
return config
|
310 |
+
|
311 |
+
|
312 |
+
def create_diffusers_schedular(original_config):
|
313 |
+
schedular = DDIMScheduler(
|
314 |
+
num_train_timesteps=original_config.model.params.timesteps,
|
315 |
+
beta_start=original_config.model.params.linear_start,
|
316 |
+
beta_end=original_config.model.params.linear_end,
|
317 |
+
beta_schedule="scaled_linear",
|
318 |
+
)
|
319 |
+
return schedular
|
320 |
+
|
321 |
+
|
322 |
+
def create_ldm_bert_config(original_config):
|
323 |
+
bert_params = original_config.model.parms.cond_stage_config.params
|
324 |
+
config = LDMBertConfig(
|
325 |
+
d_model=bert_params.n_embed,
|
326 |
+
encoder_layers=bert_params.n_layer,
|
327 |
+
encoder_ffn_dim=bert_params.n_embed * 4,
|
328 |
+
)
|
329 |
+
return config
|
330 |
+
|
331 |
+
|
332 |
+
def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False):
|
333 |
+
"""
|
334 |
+
Takes a state dict and a config, and returns a converted checkpoint.
|
335 |
+
"""
|
336 |
+
|
337 |
+
# extract state_dict for UNet
|
338 |
+
unet_state_dict = {}
|
339 |
+
keys = list(checkpoint.keys())
|
340 |
+
|
341 |
+
if controlnet:
|
342 |
+
unet_key = "control_model."
|
343 |
+
else:
|
344 |
+
unet_key = "model.diffusion_model."
|
345 |
+
|
346 |
+
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
|
347 |
+
if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
|
348 |
+
print(f"Checkpoint {path} has both EMA and non-EMA weights.")
|
349 |
+
print(
|
350 |
+
"In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
|
351 |
+
" weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
|
352 |
+
)
|
353 |
+
for key in keys:
|
354 |
+
if key.startswith("model.diffusion_model"):
|
355 |
+
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
|
356 |
+
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
|
357 |
+
else:
|
358 |
+
if sum(k.startswith("model_ema") for k in keys) > 100:
|
359 |
+
print(
|
360 |
+
"In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
|
361 |
+
" weights (usually better for inference), please make sure to add the `--extract_ema` flag."
|
362 |
+
)
|
363 |
+
|
364 |
+
for key in keys:
|
365 |
+
if key.startswith(unet_key):
|
366 |
+
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
367 |
+
|
368 |
+
new_checkpoint = {}
|
369 |
+
|
370 |
+
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
|
371 |
+
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
|
372 |
+
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
|
373 |
+
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
|
374 |
+
|
375 |
+
if config["class_embed_type"] is None:
|
376 |
+
# No parameters to port
|
377 |
+
...
|
378 |
+
elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
|
379 |
+
new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
|
380 |
+
new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
|
381 |
+
new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
|
382 |
+
new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
|
383 |
+
else:
|
384 |
+
raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
|
385 |
+
|
386 |
+
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
|
387 |
+
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
|
388 |
+
|
389 |
+
if not controlnet:
|
390 |
+
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
|
391 |
+
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
|
392 |
+
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
|
393 |
+
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
|
394 |
+
|
395 |
+
# Retrieves the keys for the input blocks only
|
396 |
+
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
|
397 |
+
input_blocks = {
|
398 |
+
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
|
399 |
+
for layer_id in range(num_input_blocks)
|
400 |
+
}
|
401 |
+
|
402 |
+
# Retrieves the keys for the middle blocks only
|
403 |
+
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
|
404 |
+
middle_blocks = {
|
405 |
+
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
|
406 |
+
for layer_id in range(num_middle_blocks)
|
407 |
+
}
|
408 |
+
|
409 |
+
# Retrieves the keys for the output blocks only
|
410 |
+
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
|
411 |
+
output_blocks = {
|
412 |
+
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
|
413 |
+
for layer_id in range(num_output_blocks)
|
414 |
+
}
|
415 |
+
|
416 |
+
for i in range(1, num_input_blocks):
|
417 |
+
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
418 |
+
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
|
419 |
+
|
420 |
+
resnets = [
|
421 |
+
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
|
422 |
+
]
|
423 |
+
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
424 |
+
|
425 |
+
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
|
426 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
|
427 |
+
f"input_blocks.{i}.0.op.weight"
|
428 |
+
)
|
429 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
|
430 |
+
f"input_blocks.{i}.0.op.bias"
|
431 |
+
)
|
432 |
+
|
433 |
+
paths = renew_resnet_paths(resnets)
|
434 |
+
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
435 |
+
assign_to_checkpoint(
|
436 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
437 |
+
)
|
438 |
+
|
439 |
+
if len(attentions):
|
440 |
+
paths = renew_attention_paths(attentions)
|
441 |
+
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
|
442 |
+
assign_to_checkpoint(
|
443 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
444 |
+
)
|
445 |
+
|
446 |
+
resnet_0 = middle_blocks[0]
|
447 |
+
attentions = middle_blocks[1]
|
448 |
+
resnet_1 = middle_blocks[2]
|
449 |
+
|
450 |
+
resnet_0_paths = renew_resnet_paths(resnet_0)
|
451 |
+
assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
|
452 |
+
|
453 |
+
resnet_1_paths = renew_resnet_paths(resnet_1)
|
454 |
+
assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
|
455 |
+
|
456 |
+
attentions_paths = renew_attention_paths(attentions)
|
457 |
+
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
|
458 |
+
assign_to_checkpoint(
|
459 |
+
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
460 |
+
)
|
461 |
+
|
462 |
+
for i in range(num_output_blocks):
|
463 |
+
block_id = i // (config["layers_per_block"] + 1)
|
464 |
+
layer_in_block_id = i % (config["layers_per_block"] + 1)
|
465 |
+
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
|
466 |
+
output_block_list = {}
|
467 |
+
|
468 |
+
for layer in output_block_layers:
|
469 |
+
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
|
470 |
+
if layer_id in output_block_list:
|
471 |
+
output_block_list[layer_id].append(layer_name)
|
472 |
+
else:
|
473 |
+
output_block_list[layer_id] = [layer_name]
|
474 |
+
|
475 |
+
if len(output_block_list) > 1:
|
476 |
+
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
|
477 |
+
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
|
478 |
+
|
479 |
+
resnet_0_paths = renew_resnet_paths(resnets)
|
480 |
+
paths = renew_resnet_paths(resnets)
|
481 |
+
|
482 |
+
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
483 |
+
assign_to_checkpoint(
|
484 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
485 |
+
)
|
486 |
+
|
487 |
+
output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
|
488 |
+
if ["conv.bias", "conv.weight"] in output_block_list.values():
|
489 |
+
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
|
490 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
491 |
+
f"output_blocks.{i}.{index}.conv.weight"
|
492 |
+
]
|
493 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
|
494 |
+
f"output_blocks.{i}.{index}.conv.bias"
|
495 |
+
]
|
496 |
+
|
497 |
+
# Clear attentions as they have been attributed above.
|
498 |
+
if len(attentions) == 2:
|
499 |
+
attentions = []
|
500 |
+
|
501 |
+
if len(attentions):
|
502 |
+
paths = renew_attention_paths(attentions)
|
503 |
+
meta_path = {
|
504 |
+
"old": f"output_blocks.{i}.1",
|
505 |
+
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
|
506 |
+
}
|
507 |
+
assign_to_checkpoint(
|
508 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
509 |
+
)
|
510 |
+
else:
|
511 |
+
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
512 |
+
for path in resnet_0_paths:
|
513 |
+
old_path = ".".join(["output_blocks", str(i), path["old"]])
|
514 |
+
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
|
515 |
+
|
516 |
+
new_checkpoint[new_path] = unet_state_dict[old_path]
|
517 |
+
|
518 |
+
if controlnet:
|
519 |
+
# conditioning embedding
|
520 |
+
|
521 |
+
orig_index = 0
|
522 |
+
|
523 |
+
new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
|
524 |
+
f"input_hint_block.{orig_index}.weight"
|
525 |
+
)
|
526 |
+
new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
|
527 |
+
f"input_hint_block.{orig_index}.bias"
|
528 |
+
)
|
529 |
+
|
530 |
+
orig_index += 2
|
531 |
+
|
532 |
+
diffusers_index = 0
|
533 |
+
|
534 |
+
while diffusers_index < 6:
|
535 |
+
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
|
536 |
+
f"input_hint_block.{orig_index}.weight"
|
537 |
+
)
|
538 |
+
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
|
539 |
+
f"input_hint_block.{orig_index}.bias"
|
540 |
+
)
|
541 |
+
diffusers_index += 1
|
542 |
+
orig_index += 2
|
543 |
+
|
544 |
+
new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
|
545 |
+
f"input_hint_block.{orig_index}.weight"
|
546 |
+
)
|
547 |
+
new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
|
548 |
+
f"input_hint_block.{orig_index}.bias"
|
549 |
+
)
|
550 |
+
|
551 |
+
# down blocks
|
552 |
+
for i in range(num_input_blocks):
|
553 |
+
new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
|
554 |
+
new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")
|
555 |
+
|
556 |
+
# mid block
|
557 |
+
new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
|
558 |
+
new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")
|
559 |
+
|
560 |
+
return new_checkpoint
|
561 |
+
|
562 |
+
|
563 |
+
def convert_ldm_vae_checkpoint(checkpoint, config):
|
564 |
+
# extract state dict for VAE
|
565 |
+
vae_state_dict = {}
|
566 |
+
vae_key = "first_stage_model."
|
567 |
+
keys = list(checkpoint.keys())
|
568 |
+
for key in keys:
|
569 |
+
if key.startswith(vae_key):
|
570 |
+
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
571 |
+
|
572 |
+
new_checkpoint = {}
|
573 |
+
|
574 |
+
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
|
575 |
+
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
|
576 |
+
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
|
577 |
+
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
|
578 |
+
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
|
579 |
+
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
|
580 |
+
|
581 |
+
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
|
582 |
+
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
|
583 |
+
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
|
584 |
+
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
|
585 |
+
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
|
586 |
+
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
|
587 |
+
|
588 |
+
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
|
589 |
+
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
|
590 |
+
new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
|
591 |
+
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
|
592 |
+
|
593 |
+
# Retrieves the keys for the encoder down blocks only
|
594 |
+
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
|
595 |
+
down_blocks = {
|
596 |
+
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
|
597 |
+
}
|
598 |
+
|
599 |
+
# Retrieves the keys for the decoder up blocks only
|
600 |
+
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
|
601 |
+
up_blocks = {
|
602 |
+
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
|
603 |
+
}
|
604 |
+
|
605 |
+
for i in range(num_down_blocks):
|
606 |
+
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
|
607 |
+
|
608 |
+
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
609 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
|
610 |
+
f"encoder.down.{i}.downsample.conv.weight"
|
611 |
+
)
|
612 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
|
613 |
+
f"encoder.down.{i}.downsample.conv.bias"
|
614 |
+
)
|
615 |
+
|
616 |
+
paths = renew_vae_resnet_paths(resnets)
|
617 |
+
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
|
618 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
619 |
+
|
620 |
+
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
621 |
+
num_mid_res_blocks = 2
|
622 |
+
for i in range(1, num_mid_res_blocks + 1):
|
623 |
+
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
|
624 |
+
|
625 |
+
paths = renew_vae_resnet_paths(resnets)
|
626 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
627 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
628 |
+
|
629 |
+
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
|
630 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
631 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
632 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
633 |
+
conv_attn_to_linear(new_checkpoint)
|
634 |
+
|
635 |
+
for i in range(num_up_blocks):
|
636 |
+
block_id = num_up_blocks - 1 - i
|
637 |
+
resnets = [
|
638 |
+
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
|
639 |
+
]
|
640 |
+
|
641 |
+
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
642 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
643 |
+
f"decoder.up.{block_id}.upsample.conv.weight"
|
644 |
+
]
|
645 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
|
646 |
+
f"decoder.up.{block_id}.upsample.conv.bias"
|
647 |
+
]
|
648 |
+
|
649 |
+
paths = renew_vae_resnet_paths(resnets)
|
650 |
+
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
|
651 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
652 |
+
|
653 |
+
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
654 |
+
num_mid_res_blocks = 2
|
655 |
+
for i in range(1, num_mid_res_blocks + 1):
|
656 |
+
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
|
657 |
+
|
658 |
+
paths = renew_vae_resnet_paths(resnets)
|
659 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
660 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
661 |
+
|
662 |
+
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
|
663 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
664 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
665 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
666 |
+
conv_attn_to_linear(new_checkpoint)
|
667 |
+
return new_checkpoint
|
668 |
+
|
669 |
+
|
670 |
+
def convert_ldm_bert_checkpoint(checkpoint, config):
|
671 |
+
def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
|
672 |
+
hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
|
673 |
+
hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
|
674 |
+
hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
|
675 |
+
|
676 |
+
hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
|
677 |
+
hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
|
678 |
+
|
679 |
+
def _copy_linear(hf_linear, pt_linear):
|
680 |
+
hf_linear.weight = pt_linear.weight
|
681 |
+
hf_linear.bias = pt_linear.bias
|
682 |
+
|
683 |
+
def _copy_layer(hf_layer, pt_layer):
|
684 |
+
# copy layer norms
|
685 |
+
_copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
|
686 |
+
_copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
|
687 |
+
|
688 |
+
# copy attn
|
689 |
+
_copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
|
690 |
+
|
691 |
+
# copy MLP
|
692 |
+
pt_mlp = pt_layer[1][1]
|
693 |
+
_copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
|
694 |
+
_copy_linear(hf_layer.fc2, pt_mlp.net[2])
|
695 |
+
|
696 |
+
def _copy_layers(hf_layers, pt_layers):
|
697 |
+
for i, hf_layer in enumerate(hf_layers):
|
698 |
+
if i != 0:
|
699 |
+
i += i
|
700 |
+
pt_layer = pt_layers[i : i + 2]
|
701 |
+
_copy_layer(hf_layer, pt_layer)
|
702 |
+
|
703 |
+
hf_model = LDMBertModel(config).eval()
|
704 |
+
|
705 |
+
# copy embeds
|
706 |
+
hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
|
707 |
+
hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
|
708 |
+
|
709 |
+
# copy layer norm
|
710 |
+
_copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
|
711 |
+
|
712 |
+
# copy hidden layers
|
713 |
+
_copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
|
714 |
+
|
715 |
+
_copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
|
716 |
+
|
717 |
+
return hf_model
|
718 |
+
|
719 |
+
|
720 |
+
def convert_ldm_clip_checkpoint(checkpoint, pretrained_model_path):
|
721 |
+
text_model = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
|
722 |
+
keys = list(checkpoint.keys())
|
723 |
+
|
724 |
+
text_model_dict = {}
|
725 |
+
|
726 |
+
for key in keys:
|
727 |
+
if key.startswith("cond_stage_model.transformer"):
|
728 |
+
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
729 |
+
|
730 |
+
text_model.load_state_dict(text_model_dict)
|
731 |
+
|
732 |
+
return text_model
|
733 |
+
|
734 |
+
|
735 |
+
textenc_conversion_lst = [
|
736 |
+
("cond_stage_model.model.positional_embedding", "text_model.embeddings.position_embedding.weight"),
|
737 |
+
("cond_stage_model.model.token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
|
738 |
+
("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"),
|
739 |
+
("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"),
|
740 |
+
]
|
741 |
+
textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}
|
742 |
+
|
743 |
+
textenc_transformer_conversion_lst = [
|
744 |
+
# (stable-diffusion, HF Diffusers)
|
745 |
+
("resblocks.", "text_model.encoder.layers."),
|
746 |
+
("ln_1", "layer_norm1"),
|
747 |
+
("ln_2", "layer_norm2"),
|
748 |
+
(".c_fc.", ".fc1."),
|
749 |
+
(".c_proj.", ".fc2."),
|
750 |
+
(".attn", ".self_attn"),
|
751 |
+
("ln_final.", "transformer.text_model.final_layer_norm."),
|
752 |
+
("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
|
753 |
+
("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
|
754 |
+
]
|
755 |
+
protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
|
756 |
+
textenc_pattern = re.compile("|".join(protected.keys()))
|
757 |
+
|
758 |
+
|
759 |
+
def convert_paint_by_example_checkpoint(checkpoint):
|
760 |
+
config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
|
761 |
+
model = PaintByExampleImageEncoder(config)
|
762 |
+
|
763 |
+
keys = list(checkpoint.keys())
|
764 |
+
|
765 |
+
text_model_dict = {}
|
766 |
+
|
767 |
+
for key in keys:
|
768 |
+
if key.startswith("cond_stage_model.transformer"):
|
769 |
+
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
770 |
+
|
771 |
+
# load clip vision
|
772 |
+
model.model.load_state_dict(text_model_dict)
|
773 |
+
|
774 |
+
# load mapper
|
775 |
+
keys_mapper = {
|
776 |
+
k[len("cond_stage_model.mapper.res") :]: v
|
777 |
+
for k, v in checkpoint.items()
|
778 |
+
if k.startswith("cond_stage_model.mapper")
|
779 |
+
}
|
780 |
+
|
781 |
+
MAPPING = {
|
782 |
+
"attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
|
783 |
+
"attn.c_proj": ["attn1.to_out.0"],
|
784 |
+
"ln_1": ["norm1"],
|
785 |
+
"ln_2": ["norm3"],
|
786 |
+
"mlp.c_fc": ["ff.net.0.proj"],
|
787 |
+
"mlp.c_proj": ["ff.net.2"],
|
788 |
+
}
|
789 |
+
|
790 |
+
mapped_weights = {}
|
791 |
+
for key, value in keys_mapper.items():
|
792 |
+
prefix = key[: len("blocks.i")]
|
793 |
+
suffix = key.split(prefix)[-1].split(".")[-1]
|
794 |
+
name = key.split(prefix)[-1].split(suffix)[0][1:-1]
|
795 |
+
mapped_names = MAPPING[name]
|
796 |
+
|
797 |
+
num_splits = len(mapped_names)
|
798 |
+
for i, mapped_name in enumerate(mapped_names):
|
799 |
+
new_name = ".".join([prefix, mapped_name, suffix])
|
800 |
+
shape = value.shape[0] // num_splits
|
801 |
+
mapped_weights[new_name] = value[i * shape : (i + 1) * shape]
|
802 |
+
|
803 |
+
model.mapper.load_state_dict(mapped_weights)
|
804 |
+
|
805 |
+
# load final layer norm
|
806 |
+
model.final_layer_norm.load_state_dict(
|
807 |
+
{
|
808 |
+
"bias": checkpoint["cond_stage_model.final_ln.bias"],
|
809 |
+
"weight": checkpoint["cond_stage_model.final_ln.weight"],
|
810 |
+
}
|
811 |
+
)
|
812 |
+
|
813 |
+
# load final proj
|
814 |
+
model.proj_out.load_state_dict(
|
815 |
+
{
|
816 |
+
"bias": checkpoint["proj_out.bias"],
|
817 |
+
"weight": checkpoint["proj_out.weight"],
|
818 |
+
}
|
819 |
+
)
|
820 |
+
|
821 |
+
# load uncond vector
|
822 |
+
model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
|
823 |
+
return model
|
824 |
+
|
825 |
+
|
826 |
+
def convert_open_clip_checkpoint(checkpoint):
|
827 |
+
text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
|
828 |
+
|
829 |
+
keys = list(checkpoint.keys())
|
830 |
+
|
831 |
+
text_model_dict = {}
|
832 |
+
|
833 |
+
if "cond_stage_model.model.text_projection" in checkpoint:
|
834 |
+
d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
|
835 |
+
else:
|
836 |
+
d_model = 1024
|
837 |
+
|
838 |
+
text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
|
839 |
+
|
840 |
+
for key in keys:
|
841 |
+
if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer
|
842 |
+
continue
|
843 |
+
if key in textenc_conversion_map:
|
844 |
+
text_model_dict[textenc_conversion_map[key]] = checkpoint[key]
|
845 |
+
if key.startswith("cond_stage_model.model.transformer."):
|
846 |
+
new_key = key[len("cond_stage_model.model.transformer.") :]
|
847 |
+
if new_key.endswith(".in_proj_weight"):
|
848 |
+
new_key = new_key[: -len(".in_proj_weight")]
|
849 |
+
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
|
850 |
+
text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
|
851 |
+
text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :]
|
852 |
+
text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :]
|
853 |
+
elif new_key.endswith(".in_proj_bias"):
|
854 |
+
new_key = new_key[: -len(".in_proj_bias")]
|
855 |
+
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
|
856 |
+
text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model]
|
857 |
+
text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2]
|
858 |
+
text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :]
|
859 |
+
else:
|
860 |
+
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
|
861 |
+
|
862 |
+
text_model_dict[new_key] = checkpoint[key]
|
863 |
+
|
864 |
+
text_model.load_state_dict(text_model_dict)
|
865 |
+
|
866 |
+
return text_model
|
867 |
+
|
868 |
+
|
869 |
+
def stable_unclip_image_encoder(original_config):
|
870 |
+
"""
|
871 |
+
Returns the image processor and clip image encoder for the img2img unclip pipeline.
|
872 |
+
|
873 |
+
We currently know of two types of stable unclip models which separately use the clip and the openclip image
|
874 |
+
encoders.
|
875 |
+
"""
|
876 |
+
|
877 |
+
image_embedder_config = original_config.model.params.embedder_config
|
878 |
+
|
879 |
+
sd_clip_image_embedder_class = image_embedder_config.target
|
880 |
+
sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1]
|
881 |
+
|
882 |
+
if sd_clip_image_embedder_class == "ClipImageEmbedder":
|
883 |
+
clip_model_name = image_embedder_config.params.model
|
884 |
+
|
885 |
+
if clip_model_name == "ViT-L/14":
|
886 |
+
feature_extractor = CLIPImageProcessor()
|
887 |
+
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
|
888 |
+
else:
|
889 |
+
raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}")
|
890 |
+
|
891 |
+
elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder":
|
892 |
+
feature_extractor = CLIPImageProcessor()
|
893 |
+
image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
|
894 |
+
else:
|
895 |
+
raise NotImplementedError(
|
896 |
+
f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}"
|
897 |
+
)
|
898 |
+
|
899 |
+
return feature_extractor, image_encoder
|
900 |
+
|
901 |
+
|
902 |
+
def stable_unclip_image_noising_components(
|
903 |
+
original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None
|
904 |
+
):
|
905 |
+
"""
|
906 |
+
Returns the noising components for the img2img and txt2img unclip pipelines.
|
907 |
+
|
908 |
+
Converts the stability noise augmentor into
|
909 |
+
1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats
|
910 |
+
2. a `DDPMScheduler` for holding the noise schedule
|
911 |
+
|
912 |
+
If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided.
|
913 |
+
"""
|
914 |
+
noise_aug_config = original_config.model.params.noise_aug_config
|
915 |
+
noise_aug_class = noise_aug_config.target
|
916 |
+
noise_aug_class = noise_aug_class.split(".")[-1]
|
917 |
+
|
918 |
+
if noise_aug_class == "CLIPEmbeddingNoiseAugmentation":
|
919 |
+
noise_aug_config = noise_aug_config.params
|
920 |
+
embedding_dim = noise_aug_config.timestep_dim
|
921 |
+
max_noise_level = noise_aug_config.noise_schedule_config.timesteps
|
922 |
+
beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule
|
923 |
+
|
924 |
+
image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim)
|
925 |
+
image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule)
|
926 |
+
|
927 |
+
if "clip_stats_path" in noise_aug_config:
|
928 |
+
if clip_stats_path is None:
|
929 |
+
raise ValueError("This stable unclip config requires a `clip_stats_path`")
|
930 |
+
|
931 |
+
clip_mean, clip_std = torch.load(clip_stats_path, map_location=device)
|
932 |
+
clip_mean = clip_mean[None, :]
|
933 |
+
clip_std = clip_std[None, :]
|
934 |
+
|
935 |
+
clip_stats_state_dict = {
|
936 |
+
"mean": clip_mean,
|
937 |
+
"std": clip_std,
|
938 |
+
}
|
939 |
+
|
940 |
+
image_normalizer.load_state_dict(clip_stats_state_dict)
|
941 |
+
else:
|
942 |
+
raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}")
|
943 |
+
|
944 |
+
return image_normalizer, image_noising_scheduler
|
945 |
+
|
946 |
+
|
947 |
+
def convert_controlnet_checkpoint(
|
948 |
+
checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
|
949 |
+
):
|
950 |
+
ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
|
951 |
+
ctrlnet_config["upcast_attention"] = upcast_attention
|
952 |
+
|
953 |
+
ctrlnet_config.pop("sample_size")
|
954 |
+
|
955 |
+
controlnet_model = ControlNetModel(**ctrlnet_config)
|
956 |
+
|
957 |
+
converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
|
958 |
+
checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True
|
959 |
+
)
|
960 |
+
|
961 |
+
controlnet_model.load_state_dict(converted_ctrl_checkpoint)
|
962 |
+
|
963 |
+
return controlnet_model
|
musev/utils/convert_lora_safetensor_to_diffusers.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
""" Conversion script for the LoRA's safetensors checkpoints. """
|
17 |
+
|
18 |
+
import argparse
|
19 |
+
|
20 |
+
import torch
|
21 |
+
from safetensors.torch import load_file
|
22 |
+
|
23 |
+
from diffusers import StableDiffusionPipeline
|
24 |
+
import pdb
|
25 |
+
|
26 |
+
|
27 |
+
|
28 |
+
def convert_motion_lora_ckpt_to_diffusers(pipeline, state_dict, alpha=1.0):
|
29 |
+
# directly update weight in diffusers model
|
30 |
+
for key in state_dict:
|
31 |
+
# only process lora down key
|
32 |
+
if "up." in key: continue
|
33 |
+
|
34 |
+
up_key = key.replace(".down.", ".up.")
|
35 |
+
model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "")
|
36 |
+
model_key = model_key.replace("to_out.", "to_out.0.")
|
37 |
+
layer_infos = model_key.split(".")[:-1]
|
38 |
+
|
39 |
+
curr_layer = pipeline.unet
|
40 |
+
while len(layer_infos) > 0:
|
41 |
+
temp_name = layer_infos.pop(0)
|
42 |
+
curr_layer = curr_layer.__getattr__(temp_name)
|
43 |
+
|
44 |
+
weight_down = state_dict[key]
|
45 |
+
weight_up = state_dict[up_key]
|
46 |
+
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
|
47 |
+
|
48 |
+
return pipeline
|
49 |
+
|
50 |
+
|
51 |
+
|
52 |
+
def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6):
|
53 |
+
# load base model
|
54 |
+
# pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
|
55 |
+
|
56 |
+
# load LoRA weight from .safetensors
|
57 |
+
# state_dict = load_file(checkpoint_path)
|
58 |
+
|
59 |
+
visited = []
|
60 |
+
|
61 |
+
# directly update weight in diffusers model
|
62 |
+
for key in state_dict:
|
63 |
+
# it is suggested to print out the key, it usually will be something like below
|
64 |
+
# "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
|
65 |
+
|
66 |
+
# as we have set the alpha beforehand, so just skip
|
67 |
+
if ".alpha" in key or key in visited:
|
68 |
+
continue
|
69 |
+
|
70 |
+
if "text" in key:
|
71 |
+
layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
|
72 |
+
curr_layer = pipeline.text_encoder
|
73 |
+
else:
|
74 |
+
layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
|
75 |
+
curr_layer = pipeline.unet
|
76 |
+
|
77 |
+
# find the target layer
|
78 |
+
temp_name = layer_infos.pop(0)
|
79 |
+
while len(layer_infos) > -1:
|
80 |
+
try:
|
81 |
+
curr_layer = curr_layer.__getattr__(temp_name)
|
82 |
+
if len(layer_infos) > 0:
|
83 |
+
temp_name = layer_infos.pop(0)
|
84 |
+
elif len(layer_infos) == 0:
|
85 |
+
break
|
86 |
+
except Exception:
|
87 |
+
if len(temp_name) > 0:
|
88 |
+
temp_name += "_" + layer_infos.pop(0)
|
89 |
+
else:
|
90 |
+
temp_name = layer_infos.pop(0)
|
91 |
+
|
92 |
+
pair_keys = []
|
93 |
+
if "lora_down" in key:
|
94 |
+
pair_keys.append(key.replace("lora_down", "lora_up"))
|
95 |
+
pair_keys.append(key)
|
96 |
+
else:
|
97 |
+
pair_keys.append(key)
|
98 |
+
pair_keys.append(key.replace("lora_up", "lora_down"))
|
99 |
+
|
100 |
+
# update weight
|
101 |
+
if len(state_dict[pair_keys[0]].shape) == 4:
|
102 |
+
weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
|
103 |
+
weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
|
104 |
+
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device)
|
105 |
+
else:
|
106 |
+
weight_up = state_dict[pair_keys[0]].to(torch.float32)
|
107 |
+
weight_down = state_dict[pair_keys[1]].to(torch.float32)
|
108 |
+
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
|
109 |
+
|
110 |
+
# update visited list
|
111 |
+
for item in pair_keys:
|
112 |
+
visited.append(item)
|
113 |
+
|
114 |
+
return pipeline
|
115 |
+
|
116 |
+
|
117 |
+
if __name__ == "__main__":
|
118 |
+
parser = argparse.ArgumentParser()
|
119 |
+
|
120 |
+
parser.add_argument(
|
121 |
+
"--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format."
|
122 |
+
)
|
123 |
+
parser.add_argument(
|
124 |
+
"--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
|
125 |
+
)
|
126 |
+
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
|
127 |
+
parser.add_argument(
|
128 |
+
"--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors"
|
129 |
+
)
|
130 |
+
parser.add_argument(
|
131 |
+
"--lora_prefix_text_encoder",
|
132 |
+
default="lora_te",
|
133 |
+
type=str,
|
134 |
+
help="The prefix of text encoder weight in safetensors",
|
135 |
+
)
|
136 |
+
parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW")
|
137 |
+
parser.add_argument(
|
138 |
+
"--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not."
|
139 |
+
)
|
140 |
+
parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
|
141 |
+
|
142 |
+
args = parser.parse_args()
|
143 |
+
|
144 |
+
base_model_path = args.base_model_path
|
145 |
+
checkpoint_path = args.checkpoint_path
|
146 |
+
dump_path = args.dump_path
|
147 |
+
lora_prefix_unet = args.lora_prefix_unet
|
148 |
+
lora_prefix_text_encoder = args.lora_prefix_text_encoder
|
149 |
+
alpha = args.alpha
|
150 |
+
|
151 |
+
pipe = convert(base_model_path, checkpoint_path, lora_prefix_unet, lora_prefix_text_encoder, alpha)
|
152 |
+
|
153 |
+
pipe = pipe.to(args.device)
|
154 |
+
pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
|