Baraaqasem commited on
Commit
5d32408
·
verified ·
1 Parent(s): 413d4d0

Upload 585 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. src/videogen_hub/pipelines/__init__.py +0 -0
  2. src/videogen_hub/pipelines/cogvideo/__init__.py +4 -0
  3. src/videogen_hub/pipelines/cogvideo/cogvideo_pipeline.py +612 -0
  4. src/videogen_hub/pipelines/cogvideo/cogvideo_src/LICENSE +201 -0
  5. src/videogen_hub/pipelines/cogvideo/cogvideo_src/Model_License +79 -0
  6. src/videogen_hub/pipelines/cogvideo/cogvideo_src/__init__.py +0 -0
  7. src/videogen_hub/pipelines/cogvideo/cogvideo_src/cluster_label2.npy +3 -0
  8. src/videogen_hub/pipelines/cogvideo/cogvideo_src/coglm_strategy.py +101 -0
  9. src/videogen_hub/pipelines/cogvideo/cogvideo_src/cogvideo_pipeline.py +1341 -0
  10. src/videogen_hub/pipelines/cogvideo/cogvideo_src/models/__init__.py +0 -0
  11. src/videogen_hub/pipelines/cogvideo/cogvideo_src/models/cogvideo_cache_model.py +695 -0
  12. src/videogen_hub/pipelines/cogvideo/cogvideo_src/models/cogvideo_model.py +543 -0
  13. src/videogen_hub/pipelines/cogvideo/cogvideo_src/pretrain_cogvideo.py +184 -0
  14. src/videogen_hub/pipelines/cogvideo/cogvideo_src/requirements.txt +4 -0
  15. src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/__init__.py +17 -0
  16. src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/cluster_label2.npy +3 -0
  17. src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/direct_sr.py +117 -0
  18. src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/dsr_model.py +225 -0
  19. src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/dsr_sampling.py +204 -0
  20. src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/iterative_sr.py +118 -0
  21. src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/itersr_model.py +232 -0
  22. src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/itersr_sampling.py +168 -0
  23. src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/sr_group.py +49 -0
  24. src/videogen_hub/pipelines/consisti2v/LICENSE +21 -0
  25. src/videogen_hub/pipelines/consisti2v/__init__.py +0 -0
  26. src/videogen_hub/pipelines/consisti2v/configs/__init__.py +0 -0
  27. src/videogen_hub/pipelines/consisti2v/configs/inference/__init__.py +0 -0
  28. src/videogen_hub/pipelines/consisti2v/configs/inference/inference.yaml +48 -0
  29. src/videogen_hub/pipelines/consisti2v/configs/inference/inference_autoregress.yaml +49 -0
  30. src/videogen_hub/pipelines/consisti2v/configs/prompts/__init__.py +0 -0
  31. src/videogen_hub/pipelines/consisti2v/configs/prompts/default.yaml +16 -0
  32. src/videogen_hub/pipelines/consisti2v/configs/training/__init__.py +0 -0
  33. src/videogen_hub/pipelines/consisti2v/configs/training/training.yaml +92 -0
  34. src/videogen_hub/pipelines/consisti2v/consisti2v/__init__.py +0 -0
  35. src/videogen_hub/pipelines/consisti2v/consisti2v/data/__init__.py +0 -0
  36. src/videogen_hub/pipelines/consisti2v/consisti2v/data/dataset.py +315 -0
  37. src/videogen_hub/pipelines/consisti2v/consisti2v/models/__init__.py +0 -0
  38. src/videogen_hub/pipelines/consisti2v/consisti2v/models/rotary_embedding.py +280 -0
  39. src/videogen_hub/pipelines/consisti2v/consisti2v/models/videoldm_attention.py +809 -0
  40. src/videogen_hub/pipelines/consisti2v/consisti2v/models/videoldm_transformer_blocks.py +564 -0
  41. src/videogen_hub/pipelines/consisti2v/consisti2v/models/videoldm_unet.py +1371 -0
  42. src/videogen_hub/pipelines/consisti2v/consisti2v/models/videoldm_unet_blocks.py +1159 -0
  43. src/videogen_hub/pipelines/consisti2v/consisti2v/pipelines/__init__.py +0 -0
  44. src/videogen_hub/pipelines/consisti2v/consisti2v/pipelines/pipeline_autoregress_animation.py +615 -0
  45. src/videogen_hub/pipelines/consisti2v/consisti2v/pipelines/pipeline_conditional_animation.py +695 -0
  46. src/videogen_hub/pipelines/consisti2v/consisti2v/utils/__init__.py +0 -0
  47. src/videogen_hub/pipelines/consisti2v/consisti2v/utils/frameinit_utils.py +142 -0
  48. src/videogen_hub/pipelines/consisti2v/consisti2v/utils/util.py +165 -0
  49. src/videogen_hub/pipelines/consisti2v/scripts/__init__.py +0 -0
  50. src/videogen_hub/pipelines/consisti2v/scripts/animate.py +247 -0
src/videogen_hub/pipelines/__init__.py ADDED
File without changes
src/videogen_hub/pipelines/cogvideo/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ import sys
2
+
3
+ sys.path.insert(0, "./src/videogen_hub/pipelines/cogvideo/")
4
+ sys.path.insert(0, "./src/videogen_hub/pipelines/cogvideo/cogvideo_src")
src/videogen_hub/pipelines/cogvideo/cogvideo_pipeline.py ADDED
@@ -0,0 +1,612 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from videogen_hub.pipelines.cogvideo.cogvideo_src.cogvideo_pipeline import (
2
+ InferenceModel_Interpolate,
3
+ InferenceModel_Sequential,
4
+ my_filling_sequence,
5
+ get_masks_and_position_ids_stage1,
6
+ get_masks_and_position_ids_stage2,
7
+ my_save_multiple_images,
8
+ )
9
+ from videogen_hub.depend.icetk import icetk as tokenizer
10
+ from videogen_hub.pipelines.cogvideo.cogvideo_src.coglm_strategy import (
11
+ CoglmStrategy,
12
+ )
13
+ from videogen_hub.pipelines.cogvideo.cogvideo_src.sr_pipeline import (
14
+ DirectSuperResolution,
15
+ )
16
+ from SwissArmyTransformer.resources import auto_create
17
+ import time, logging, sys, os, torch
18
+ import torch.distributed as dist
19
+
20
+ # path = os.path.join(args.output_path, f"{now_qi}_{raw_text}")
21
+
22
+
23
+ def pipeline(args, raw_text, height, width, duration):
24
+ # model_stage1, args = InferenceModel_Sequential.from_pretrained(args, 'cogvideo-stage1')
25
+ # model_stage1.eval()
26
+ # parent_givan_tokens = process_stage1(model_stage1, raw_text, duration=4.0, video_raw_text=raw_text, video_guidance_text="视频",
27
+ # image_text_suffix=" 高清摄影",
28
+ # outputdir=None, batch_size=args.batch_size)
29
+
30
+ # process_stage2(model_stage2, raw_text, duration=2.0, video_raw_text=raw_text+" 视频",
31
+ # video_guidance_text="视频", parent_given_tokens=parent_given_tokens,
32
+ # outputdir=path,
33
+ # gpu_rank=0, gpu_parallel_size=1) # TODO: 修改
34
+
35
+ assert int(args.stage_1) + int(args.stage_2) + int(args.both_stages) == 1
36
+ rank_id = args.device % args.parallel_size
37
+ generate_frame_num = args.generate_frame_num
38
+
39
+ if args.stage_1 or args.both_stages:
40
+ model_stage1, args = InferenceModel_Sequential.from_pretrained(
41
+ args, "cogvideo-stage1"
42
+ )
43
+ model_stage1.eval()
44
+ if args.both_stages:
45
+ model_stage1 = model_stage1.cpu()
46
+
47
+ if args.stage_2 or args.both_stages:
48
+ model_stage2, args = InferenceModel_Interpolate.from_pretrained(
49
+ args, "cogvideo-stage2"
50
+ )
51
+ model_stage2.eval()
52
+ if args.both_stages:
53
+ model_stage2 = model_stage2.cpu()
54
+
55
+ invalid_slices = [slice(tokenizer.num_image_tokens, None)]
56
+ strategy_cogview2 = CoglmStrategy(invalid_slices, temperature=1.0, top_k=16)
57
+ strategy_cogvideo = CoglmStrategy(
58
+ invalid_slices,
59
+ temperature=args.temperature,
60
+ top_k=args.top_k,
61
+ temperature2=args.coglm_temperature2,
62
+ )
63
+ if not args.stage_1:
64
+ # from sr_pipeline import DirectSuperResolution
65
+ dsr_path = auto_create(
66
+ "cogview2-dsr", path=None
67
+ ) # path=os.getenv('SAT_HOME', '~/.sat_models')
68
+ dsr = DirectSuperResolution(args, dsr_path, max_bz=12, onCUDA=False)
69
+
70
+ def process_stage2(
71
+ model,
72
+ seq_text,
73
+ duration,
74
+ video_raw_text=None,
75
+ video_guidance_text="视频",
76
+ parent_given_tokens=None,
77
+ conddir=None,
78
+ outputdir=None,
79
+ gpu_rank=0,
80
+ gpu_parallel_size=1,
81
+ ):
82
+ stage2_starttime = time.time()
83
+ use_guidance = args.use_guidance_stage2
84
+ if args.both_stages:
85
+ move_start_time = time.time()
86
+ logging.debug("moving stage-2 model to cuda")
87
+ model = model.cuda()
88
+ logging.debug(
89
+ "moving in stage-2 model takes time: {:.2f}".format(
90
+ time.time() - move_start_time
91
+ )
92
+ )
93
+
94
+ try:
95
+ if parent_given_tokens is None:
96
+ assert conddir is not None
97
+ parent_given_tokens = torch.load(
98
+ os.path.join(conddir, "frame_tokens.pt"), map_location="cpu"
99
+ )
100
+ sample_num_allgpu = parent_given_tokens.shape[0]
101
+ sample_num = sample_num_allgpu // gpu_parallel_size
102
+ assert sample_num * gpu_parallel_size == sample_num_allgpu
103
+ parent_given_tokens = parent_given_tokens[
104
+ gpu_rank * sample_num : (gpu_rank + 1) * sample_num
105
+ ]
106
+ except:
107
+ logging.critical("No frame_tokens found in interpolation, skip")
108
+ return False
109
+
110
+ # CogVideo Stage2 Generation
111
+ while (
112
+ duration >= 0.5
113
+ ): # TODO: You can change the boundary to change the frame rate
114
+ parent_given_tokens_num = parent_given_tokens.shape[1]
115
+ generate_batchsize_persample = (parent_given_tokens_num - 1) // 2
116
+ generate_batchsize_total = generate_batchsize_persample * sample_num
117
+ total_frames = generate_frame_num
118
+ frame_len = 400
119
+ enc_text = tokenizer.encode(seq_text)
120
+ enc_duration = tokenizer.encode(str(float(duration)) + "秒")
121
+ seq = (
122
+ enc_duration
123
+ + [tokenizer["<n>"]]
124
+ + enc_text
125
+ + [tokenizer["<start_of_image>"]]
126
+ + [-1] * 400 * generate_frame_num
127
+ )
128
+ text_len = len(seq) - frame_len * generate_frame_num - 1
129
+
130
+ logging.info(
131
+ "[Stage2: Generating Frames, Frame Rate {:d}]\nraw text: {:s}".format(
132
+ int(4 / duration), tokenizer.decode(enc_text)
133
+ )
134
+ )
135
+
136
+ # generation
137
+ seq = (
138
+ torch.cuda.LongTensor(seq, device=args.device)
139
+ .unsqueeze(0)
140
+ .repeat(generate_batchsize_total, 1)
141
+ )
142
+ for sample_i in range(sample_num):
143
+ for i in range(generate_batchsize_persample):
144
+ seq[sample_i * generate_batchsize_persample + i][
145
+ text_len + 1 : text_len + 1 + 400
146
+ ] = parent_given_tokens[sample_i][2 * i]
147
+ seq[sample_i * generate_batchsize_persample + i][
148
+ text_len + 1 + 400 : text_len + 1 + 800
149
+ ] = parent_given_tokens[sample_i][2 * i + 1]
150
+ seq[sample_i * generate_batchsize_persample + i][
151
+ text_len + 1 + 800 : text_len + 1 + 1200
152
+ ] = parent_given_tokens[sample_i][2 * i + 2]
153
+
154
+ if use_guidance:
155
+ guider_seq = (
156
+ enc_duration
157
+ + [tokenizer["<n>"]]
158
+ + tokenizer.encode(video_guidance_text)
159
+ + [tokenizer["<start_of_image>"]]
160
+ + [-1] * 400 * generate_frame_num
161
+ )
162
+ guider_text_len = len(guider_seq) - frame_len * generate_frame_num - 1
163
+ guider_seq = (
164
+ torch.cuda.LongTensor(guider_seq, device=args.device)
165
+ .unsqueeze(0)
166
+ .repeat(generate_batchsize_total, 1)
167
+ )
168
+ for sample_i in range(sample_num):
169
+ for i in range(generate_batchsize_persample):
170
+ guider_seq[sample_i * generate_batchsize_persample + i][
171
+ text_len + 1 : text_len + 1 + 400
172
+ ] = parent_given_tokens[sample_i][2 * i]
173
+ guider_seq[sample_i * generate_batchsize_persample + i][
174
+ text_len + 1 + 400 : text_len + 1 + 800
175
+ ] = parent_given_tokens[sample_i][2 * i + 1]
176
+ guider_seq[sample_i * generate_batchsize_persample + i][
177
+ text_len + 1 + 800 : text_len + 1 + 1200
178
+ ] = parent_given_tokens[sample_i][2 * i + 2]
179
+ video_log_text_attention_weights = 0
180
+ else:
181
+ guider_seq = None
182
+ guider_text_len = 0
183
+ video_log_text_attention_weights = 1.4
184
+
185
+ mbz = args.max_inference_batch_size
186
+
187
+ assert generate_batchsize_total < mbz or generate_batchsize_total % mbz == 0
188
+ output_list = []
189
+ start_time = time.time()
190
+ for tim in range(max(generate_batchsize_total // mbz, 1)):
191
+ input_seq = (
192
+ seq[: min(generate_batchsize_total, mbz)].clone()
193
+ if tim == 0
194
+ else seq[mbz * tim : mbz * (tim + 1)].clone()
195
+ )
196
+ guider_seq2 = (
197
+ (
198
+ guider_seq[: min(generate_batchsize_total, mbz)].clone()
199
+ if tim == 0
200
+ else guider_seq[mbz * tim : mbz * (tim + 1)].clone()
201
+ )
202
+ if guider_seq is not None
203
+ else None
204
+ )
205
+ output_list.append(
206
+ my_filling_sequence(
207
+ model,
208
+ args,
209
+ input_seq,
210
+ batch_size=min(generate_batchsize_total, mbz),
211
+ get_masks_and_position_ids=get_masks_and_position_ids_stage2,
212
+ text_len=text_len,
213
+ frame_len=frame_len,
214
+ strategy=strategy_cogview2,
215
+ strategy2=strategy_cogvideo,
216
+ log_text_attention_weights=video_log_text_attention_weights,
217
+ mode_stage1=False,
218
+ guider_seq=guider_seq2,
219
+ guider_text_len=guider_text_len,
220
+ guidance_alpha=args.guidance_alpha,
221
+ limited_spatial_channel_mem=True,
222
+ )[0]
223
+ )
224
+ logging.info(
225
+ "Duration {:.2f}, Taken time {:.2f}\n".format(
226
+ duration, time.time() - start_time
227
+ )
228
+ )
229
+
230
+ output_tokens = torch.cat(output_list, dim=0)
231
+ output_tokens = output_tokens[
232
+ :, text_len + 1 : text_len + 1 + (total_frames) * 400
233
+ ].reshape(sample_num, -1, 400 * total_frames)
234
+ output_tokens_merge = torch.cat(
235
+ (
236
+ output_tokens[:, :, : 1 * 400],
237
+ output_tokens[:, :, 400 * 3 : 4 * 400],
238
+ output_tokens[:, :, 400 * 1 : 2 * 400],
239
+ output_tokens[:, :, 400 * 4 : (total_frames) * 400],
240
+ ),
241
+ dim=2,
242
+ ).reshape(sample_num, -1, 400)
243
+
244
+ output_tokens_merge = torch.cat(
245
+ (output_tokens_merge, output_tokens[:, -1:, 400 * 2 : 3 * 400]), dim=1
246
+ )
247
+ duration /= 2
248
+ parent_given_tokens = output_tokens_merge
249
+
250
+ if args.both_stages:
251
+ move_start_time = time.time()
252
+ logging.debug("moving stage 2 model to cpu")
253
+ model = model.cpu()
254
+ torch.cuda.empty_cache()
255
+ logging.debug(
256
+ "moving out model2 takes time: {:.2f}".format(
257
+ time.time() - move_start_time
258
+ )
259
+ )
260
+
261
+ logging.info(
262
+ "CogVideo Stage2 completed. Taken time {:.2f}\n".format(
263
+ time.time() - stage2_starttime
264
+ )
265
+ )
266
+
267
+ # decoding
268
+ # imgs = [torch.nn.functional.interpolate(tokenizer.decode(image_ids=seq.tolist()), size=(480, 480)) for seq in output_tokens_merge]
269
+ # os.makedirs(output_dir_full_path, exist_ok=True)
270
+ # my_save_multiple_images(imgs, output_dir_full_path,subdir="frames", debug=False)
271
+ # torch.save(output_tokens_merge.cpu(), os.path.join(output_dir_full_path, 'frame_token.pt'))
272
+ # os.system(f"gifmaker -i '{output_dir_full_path}'/frames/0*.jpg -o '{output_dir_full_path}/{str(float(duration))}_concat.gif' -d 0.2")
273
+
274
+ # direct super-resolution by CogView2
275
+ logging.info("[Direct super-resolution]")
276
+ dsr_starttime = time.time()
277
+ enc_text = tokenizer.encode(seq_text)
278
+ frame_num_per_sample = parent_given_tokens.shape[1]
279
+ parent_given_tokens_2d = parent_given_tokens.reshape(-1, 400)
280
+ text_seq = (
281
+ torch.cuda.LongTensor(enc_text, device=args.device)
282
+ .unsqueeze(0)
283
+ .repeat(parent_given_tokens_2d.shape[0], 1)
284
+ )
285
+ sred_tokens = dsr(text_seq, parent_given_tokens_2d)
286
+ decoded_sr_videos = []
287
+
288
+ for sample_i in range(sample_num):
289
+ decoded_sr_imgs = []
290
+ for frame_i in range(frame_num_per_sample):
291
+ decoded_sr_img = tokenizer.decode(
292
+ image_ids=sred_tokens[frame_i + sample_i * frame_num_per_sample][
293
+ -3600:
294
+ ]
295
+ )
296
+ decoded_sr_imgs.append(
297
+ torch.nn.functional.interpolate(
298
+ decoded_sr_img, size=(height, width)
299
+ )
300
+ )
301
+ decoded_sr_videos.append(decoded_sr_imgs)
302
+
303
+ return decoded_sr_videos
304
+ # for sample_i in range(sample_num):
305
+ # my_save_multiple_images(decoded_sr_videos[sample_i], outputdir,subdir=f"frames/{sample_i+sample_num*gpu_rank}", debug=False)
306
+ # os.system(f"gifmaker -i '{outputdir}'/frames/'{sample_i+sample_num*gpu_rank}'/0*.jpg -o '{outputdir}/{sample_i+sample_num*gpu_rank}.gif' -d 0.125")
307
+
308
+ # logging.info("Direct super-resolution completed. Taken time {:.2f}\n".format(time.time() - dsr_starttime))
309
+
310
+ # return True
311
+
312
+ def process_stage1(
313
+ model,
314
+ seq_text,
315
+ duration,
316
+ video_raw_text=None,
317
+ video_guidance_text="视频",
318
+ image_text_suffix="",
319
+ outputdir=None,
320
+ batch_size=1,
321
+ ):
322
+ process_start_time = time.time()
323
+ use_guide = args.use_guidance_stage1
324
+ if args.both_stages:
325
+ move_start_time = time.time()
326
+ logging.debug("moving stage 1 model to cuda")
327
+ model = model.cuda()
328
+ logging.debug(
329
+ "moving in model1 takes time: {:.2f}".format(
330
+ time.time() - move_start_time
331
+ )
332
+ )
333
+
334
+ if video_raw_text is None:
335
+ video_raw_text = seq_text
336
+ mbz = (
337
+ args.stage1_max_inference_batch_size
338
+ if args.stage1_max_inference_batch_size > 0
339
+ else args.max_inference_batch_size
340
+ )
341
+ assert batch_size < mbz or batch_size % mbz == 0
342
+ frame_len = 400
343
+
344
+ # generate the first frame:
345
+ enc_text = tokenizer.encode(seq_text + image_text_suffix)
346
+ seq_1st = (
347
+ enc_text + [tokenizer["<start_of_image>"]] + [-1] * 400
348
+ ) # IV!! # test local!!! # test randboi!!!
349
+ logging.info(
350
+ "[Generating First Frame with CogView2]Raw text: {:s}".format(
351
+ tokenizer.decode(enc_text)
352
+ )
353
+ )
354
+ text_len_1st = len(seq_1st) - frame_len * 1 - 1
355
+
356
+ seq_1st = torch.cuda.LongTensor(seq_1st, device=args.device).unsqueeze(0)
357
+ output_list_1st = []
358
+ for tim in range(max(batch_size // mbz, 1)):
359
+ start_time = time.time()
360
+ output_list_1st.append(
361
+ my_filling_sequence(
362
+ model,
363
+ args,
364
+ seq_1st.clone(),
365
+ batch_size=min(batch_size, mbz),
366
+ get_masks_and_position_ids=get_masks_and_position_ids_stage1,
367
+ text_len=text_len_1st,
368
+ frame_len=frame_len,
369
+ strategy=strategy_cogview2,
370
+ strategy2=strategy_cogvideo,
371
+ log_text_attention_weights=1.4,
372
+ enforce_no_swin=True,
373
+ mode_stage1=True,
374
+ )[0]
375
+ )
376
+ logging.info(
377
+ "[First Frame]Taken time {:.2f}\n".format(time.time() - start_time)
378
+ )
379
+ output_tokens_1st = torch.cat(output_list_1st, dim=0)
380
+ given_tokens = output_tokens_1st[
381
+ :, text_len_1st + 1 : text_len_1st + 401
382
+ ].unsqueeze(
383
+ 1
384
+ ) # given_tokens.shape: [bs, frame_num, 400]
385
+
386
+ # generate subsequent frames:
387
+ total_frames = generate_frame_num
388
+ enc_duration = tokenizer.encode(str(float(duration)) + "秒")
389
+ if use_guide:
390
+ video_raw_text = video_raw_text + " 视频"
391
+ enc_text_video = tokenizer.encode(video_raw_text)
392
+ seq = (
393
+ enc_duration
394
+ + [tokenizer["<n>"]]
395
+ + enc_text_video
396
+ + [tokenizer["<start_of_image>"]]
397
+ + [-1] * 400 * generate_frame_num
398
+ )
399
+ guider_seq = (
400
+ enc_duration
401
+ + [tokenizer["<n>"]]
402
+ + tokenizer.encode(video_guidance_text)
403
+ + [tokenizer["<start_of_image>"]]
404
+ + [-1] * 400 * generate_frame_num
405
+ )
406
+ logging.info(
407
+ "[Stage1: Generating Subsequent Frames, Frame Rate {:.1f}]\nraw text: {:s}".format(
408
+ 4 / duration, tokenizer.decode(enc_text_video)
409
+ )
410
+ )
411
+
412
+ text_len = len(seq) - frame_len * generate_frame_num - 1
413
+ guider_text_len = len(guider_seq) - frame_len * generate_frame_num - 1
414
+ seq = (
415
+ torch.cuda.LongTensor(seq, device=args.device)
416
+ .unsqueeze(0)
417
+ .repeat(batch_size, 1)
418
+ )
419
+ guider_seq = (
420
+ torch.cuda.LongTensor(guider_seq, device=args.device)
421
+ .unsqueeze(0)
422
+ .repeat(batch_size, 1)
423
+ )
424
+
425
+ for given_frame_id in range(given_tokens.shape[1]):
426
+ seq[
427
+ :,
428
+ text_len
429
+ + 1
430
+ + given_frame_id * 400 : text_len
431
+ + 1
432
+ + (given_frame_id + 1) * 400,
433
+ ] = given_tokens[:, given_frame_id]
434
+ guider_seq[
435
+ :,
436
+ guider_text_len
437
+ + 1
438
+ + given_frame_id * 400 : guider_text_len
439
+ + 1
440
+ + (given_frame_id + 1) * 400,
441
+ ] = given_tokens[:, given_frame_id]
442
+ output_list = []
443
+
444
+ if use_guide:
445
+ video_log_text_attention_weights = 0
446
+ else:
447
+ guider_seq = None
448
+ video_log_text_attention_weights = 1.4
449
+
450
+ for tim in range(max(batch_size // mbz, 1)):
451
+ start_time = time.time()
452
+ input_seq = (
453
+ seq[: min(batch_size, mbz)].clone()
454
+ if tim == 0
455
+ else seq[mbz * tim : mbz * (tim + 1)].clone()
456
+ )
457
+ guider_seq2 = (
458
+ (
459
+ guider_seq[: min(batch_size, mbz)].clone()
460
+ if tim == 0
461
+ else guider_seq[mbz * tim : mbz * (tim + 1)].clone()
462
+ )
463
+ if guider_seq is not None
464
+ else None
465
+ )
466
+ output_list.append(
467
+ my_filling_sequence(
468
+ model,
469
+ args,
470
+ input_seq,
471
+ batch_size=min(batch_size, mbz),
472
+ get_masks_and_position_ids=get_masks_and_position_ids_stage1,
473
+ text_len=text_len,
474
+ frame_len=frame_len,
475
+ strategy=strategy_cogview2,
476
+ strategy2=strategy_cogvideo,
477
+ log_text_attention_weights=video_log_text_attention_weights,
478
+ guider_seq=guider_seq2,
479
+ guider_text_len=guider_text_len,
480
+ guidance_alpha=args.guidance_alpha,
481
+ limited_spatial_channel_mem=True,
482
+ mode_stage1=True,
483
+ )[0]
484
+ )
485
+
486
+ output_tokens = torch.cat(output_list, dim=0)[:, 1 + text_len :]
487
+
488
+ if args.both_stages:
489
+ move_start_time = time.time()
490
+ logging.debug("moving stage 1 model to cpu")
491
+ model = model.cpu()
492
+ torch.cuda.empty_cache()
493
+ logging.debug(
494
+ "moving in model1 takes time: {:.2f}".format(
495
+ time.time() - move_start_time
496
+ )
497
+ )
498
+
499
+ # decoding
500
+ imgs, sred_imgs, txts = [], [], []
501
+ for seq in output_tokens:
502
+ decoded_imgs = [
503
+ torch.nn.functional.interpolate(
504
+ tokenizer.decode(image_ids=seq.tolist()[i * 400 : (i + 1) * 400]),
505
+ size=(height, width),
506
+ )
507
+ for i in range(total_frames)
508
+ ]
509
+ imgs.append(decoded_imgs) # only the last image (target)
510
+
511
+ assert len(imgs) == batch_size
512
+ return imgs
513
+ # save_tokens = output_tokens[:, :+total_frames*400].reshape(-1, total_frames, 400).cpu()
514
+ # if outputdir is not None:
515
+ # for clip_i in range(len(imgs)):
516
+ # # os.makedirs(output_dir_full_paths[clip_i], exist_ok=True)
517
+ # my_save_multiple_images(imgs[clip_i], outputdir, subdir=f"frames/{clip_i}", debug=False)
518
+ # os.system(f"gifmaker -i '{outputdir}'/frames/'{clip_i}'/0*.jpg -o '{outputdir}/{clip_i}.gif' -d 0.25")
519
+ # torch.save(save_tokens, os.path.join(outputdir, 'frame_tokens.pt'))
520
+
521
+ # logging.info("CogVideo Stage1 completed. Taken time {:.2f}\n".format(time.time() - process_start_time))
522
+
523
+ # return save_tokens
524
+
525
+ # ======================================================================================================
526
+
527
+ if args.stage_1 or args.both_stages:
528
+ if args.input_source != "interactive":
529
+ with open(args.input_source, "r") as fin:
530
+ promptlist = fin.readlines()
531
+ promptlist = [p.strip() for p in promptlist]
532
+ else:
533
+ promptlist = None
534
+
535
+ now_qi = -1
536
+ while True:
537
+ now_qi += 1
538
+
539
+ if promptlist is not None: # with input-source
540
+ if args.multi_gpu:
541
+ if now_qi % dist.get_world_size() != dist.get_rank():
542
+ continue
543
+ rk = dist.get_rank()
544
+ else:
545
+ rk = 0
546
+ raw_text = promptlist[now_qi]
547
+ raw_text = raw_text.strip()
548
+ print(f"Working on Line No. {now_qi} on {rk}... [{raw_text}]")
549
+ else: # interactive
550
+ raw_text = input("\nPlease Input Query (stop to exit) >>> ")
551
+ raw_text = raw_text.strip()
552
+ if not raw_text:
553
+ print("Query should not be empty!")
554
+ continue
555
+ if raw_text == "stop":
556
+ return
557
+
558
+ try:
559
+ path = os.path.join(args.output_path, f"{now_qi}_{raw_text}")
560
+ parent_given_tokens, imgs = process_stage1(
561
+ model_stage1,
562
+ raw_text,
563
+ duration=4.0,
564
+ video_raw_text=raw_text,
565
+ video_guidance_text="视频",
566
+ image_text_suffix=" 高清摄影",
567
+ outputdir=path if args.stage_1 else None,
568
+ batch_size=args.batch_size,
569
+ )
570
+ if args.stage_1 and not args.both_stages:
571
+ print("only stage 1")
572
+ return imgs
573
+
574
+ if args.both_stages:
575
+ videos = process_stage2(
576
+ model_stage2,
577
+ raw_text,
578
+ duration=duration,
579
+ video_raw_text=raw_text + " 视频",
580
+ video_guidance_text="视频",
581
+ parent_given_tokens=parent_given_tokens,
582
+ outputdir=path,
583
+ gpu_rank=0,
584
+ gpu_parallel_size=1,
585
+ ) # TODO: 修改
586
+ return videos
587
+ except (ValueError, FileNotFoundError) as e:
588
+ print(e)
589
+ continue
590
+
591
+ elif args.stage_2:
592
+ sample_dirs = os.listdir(args.output_path)
593
+ for sample in sample_dirs:
594
+ raw_text = sample.split("_")[-1]
595
+ path = os.path.join(args.output_path, sample, "Interp")
596
+ parent_given_tokens = torch.load(
597
+ os.path.join(args.output_path, sample, "frame_tokens.pt")
598
+ )
599
+
600
+ process_stage2(
601
+ raw_text,
602
+ duration=2.0,
603
+ video_raw_text=raw_text + " 视频",
604
+ video_guidance_text="视频",
605
+ parent_given_tokens=parent_given_tokens,
606
+ outputdir=path,
607
+ gpu_rank=0,
608
+ gpu_parallel_size=1,
609
+ ) # TODO: 修改
610
+
611
+ else:
612
+ assert False
src/videogen_hub/pipelines/cogvideo/cogvideo_src/LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
src/videogen_hub/pipelines/cogvideo/cogvideo_src/Model_License ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ The CogVideo License
2
+
3
+ Section I: PREAMBLE
4
+
5
+ Multimodal generative models are being widely adopted and used, and have the potential to transform the way artists, among other individuals, conceive and benefit from AI or ML technologies as a tool for content creation.
6
+
7
+ Notwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations.
8
+
9
+ In short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the Model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for art and content generation.
10
+
11
+ Even though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this License aims to strike a balance between both in order to enable responsible open-science in the field of AI.
12
+
13
+ This License governs the use of the model (and its derivatives) and is informed by the model card associated with the model.
14
+
15
+ NOW THEREFORE, You and Licensor agree as follows:
16
+
17
+ 1. Definitions
18
+
19
+ - "License" means the terms and conditions for use, reproduction, and Distribution as defined in this document.
20
+ - "Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License.
21
+ - "Output" means the results of operating a Model as embodied in informational content resulting therefrom.
22
+ - "Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material.
23
+ - "Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model.
24
+ - "Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any.
25
+ - "Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access.
26
+ - "Licensor" means the copyright owner or entity authorized by the copyright owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model.
27
+ - "You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, image generator.
28
+ - "Third Parties" means individuals or legal entities that are not under common control with Licensor or You.
29
+ - "Contribution" means any work of authorship, including the original version of the Model and any modifications or additions to that Model or Derivatives of the Model thereof, that is intentionally submitted to Licensor for inclusion in the Model by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Model, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
30
+ - "Contributor" means Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Model.
31
+
32
+ Section II: INTELLECTUAL PROPERTY RIGHTS
33
+
34
+ Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III.
35
+
36
+ 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model.
37
+ 3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Model to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution incorporated within the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or Work shall terminate as of the date such litigation is asserted or filed.
38
+
39
+ Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION
40
+
41
+ 4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions:
42
+ Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material.
43
+ You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License;
44
+ You must cause any modified files to carry prominent notices stating that You changed the files;
45
+ You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model.
46
+ You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. - for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License.
47
+ 5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5).
48
+ 6. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License.
49
+
50
+ Section IV: OTHER PROVISIONS
51
+
52
+ 7. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License, update the Model through electronic means, or modify the Output of the Model based on updates. You shall undertake reasonable efforts to use the latest version of the Model.
53
+ 8. Trademarks and related. Nothing in this License permits You to make use of Licensors’ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensors.
54
+ 9. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Model and the Complementary Material (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License.
55
+ 10. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
56
+ 11. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
57
+ 12. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.
58
+
59
+ END OF TERMS AND CONDITIONS
60
+
61
+
62
+
63
+
64
+ Attachment A
65
+
66
+ Use Restrictions
67
+
68
+ You agree not to use the Model or Derivatives of the Model:
69
+ - In any way that violates any applicable national, federal, state, local or international law or regulation;
70
+ - For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
71
+ - To generate or disseminate verifiably false information and/or content with the purpose of harming others;
72
+ - To generate or disseminate personal identifiable information that can be used to harm an individual;
73
+ - To defame, disparage or otherwise harass others;
74
+ - For fully automated decision making that adversely impacts an individual’s legal rights or otherwise creates or modifies a binding, enforceable obligation;
75
+ - For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics;
76
+ - To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
77
+ - For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories;
78
+ - To provide medical advice and medical results interpretation;
79
+ - To generate or disseminate information for the purpose to be used for administration of justice, law enforcement, immigration or asylum processes, such as predicting an individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal relationships between assertions made in documents, indiscriminate and arbitrarily-targeted use).
src/videogen_hub/pipelines/cogvideo/cogvideo_src/__init__.py ADDED
File without changes
src/videogen_hub/pipelines/cogvideo/cogvideo_src/cluster_label2.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b87880fdbe89670f12844377b9cf97a9733b1f54e3a9b73cbb9835084c4e02ec
3
+ size 160128
src/videogen_hub/pipelines/cogvideo/cogvideo_src/coglm_strategy.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- encoding: utf-8 -*-
2
+ '''
3
+ @File : coglm_strategy.py
4
+ @Time : 2021/10/08 22:22:42
5
+ @Author : Ming Ding
6
+ @Contact : [email protected]
7
+ '''
8
+
9
+ # here put the import lib
10
+ import os
11
+ import sys
12
+ import math
13
+ import random
14
+ import torch
15
+ import numpy as np
16
+ import torch.nn.functional as F
17
+
18
+
19
+ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-65504):
20
+ # This function has been mostly taken from huggingface conversational ai code at
21
+ # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
22
+
23
+ if top_k > 0:
24
+ # Remove all tokens with a probability less than the last token of the top-k
25
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
26
+ logits[indices_to_remove] = filter_value
27
+
28
+ if top_p > 0.0:
29
+ # convert to 1D
30
+ logits = logits.view(logits.size()[1]).contiguous()
31
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
32
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
33
+
34
+ # Remove tokens with cumulative probability above the threshold
35
+ sorted_indices_to_remove = cumulative_probs > top_p
36
+ # Shift the indices to the right to keep also the first token above the threshold
37
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
38
+ sorted_indices_to_remove[..., 0] = 0
39
+ indices_to_remove = sorted_indices[sorted_indices_to_remove]
40
+ logits[indices_to_remove] = filter_value
41
+ # going back to 2D
42
+ logits = logits.view(1, -1).contiguous()
43
+
44
+ return logits
45
+
46
+
47
+ class CoglmStrategy:
48
+ def __init__(self, invalid_slices=[], temperature=1., top_k=200, eps=1e-4, top_p=0.0, end_tokens=None, temperature2=0.89):
49
+ self.invalid_slices = invalid_slices
50
+ self.temperature = temperature
51
+ self.temperature2 = temperature2
52
+ self.topk = top_k
53
+ self.top_p = top_p
54
+ self.eps = eps
55
+ if end_tokens is None:
56
+ end_tokens = []
57
+ self.end_tokens = end_tokens
58
+ self._is_done = False
59
+ self.outlier_count_down = torch.zeros(16)
60
+ self.vis_list = [[]for i in range(16)]
61
+ self.cluster_labels = torch.tensor(np.load('cluster_label2.npy'), device='cuda', dtype=torch.long)
62
+ self.start_pos = -1
63
+ self.white_cluster = []
64
+ # self.fout = open('tmp.txt', 'w')
65
+
66
+ @property
67
+ def is_done(self) -> bool:
68
+ return self._is_done
69
+
70
+ def forward(self, logits, tokens, mems, temperature=None, temperature2=None):
71
+ if temperature is None:
72
+ temperature = self.temperature
73
+ if temperature2 is None:
74
+ temperature2 = self.temperature2
75
+ logits = logits / temperature
76
+ for invalid_slice in self.invalid_slices:
77
+ logits[..., invalid_slice] = -65504
78
+
79
+ rprobs = F.softmax(logits.float(), dim=-1)
80
+ c = self.cluster_labels.expand(*rprobs.shape)
81
+ cprobs = torch.zeros(logits.shape[0], 500, device=logits.device).scatter_add_(1, c, rprobs)
82
+ # self.fout.write(str(tokens.shape[-1])+ ' ' + str(cprobs.topk(10)) + '\n')
83
+ # self.fout.flush()
84
+ best_scores, best_clusters = cprobs.topk(self.topk)
85
+ bz = logits.shape[0]
86
+ for i in range(bz):
87
+ selected_cluster = best_clusters[i][torch.multinomial(best_scores[i] / best_scores[i].sum(), num_samples=1)]
88
+ logits[i, self.cluster_labels != selected_cluster] = -65504
89
+
90
+ # logits = top_k_logits(logits, self.topk, self.top_p)
91
+ probs = F.softmax(logits.float()/temperature2, dim=-1) # float is essential, due to a bug in PyTorch
92
+ pred = torch.multinomial(probs, num_samples=1)
93
+
94
+ if pred.numel() == 1 and pred.item() in self.end_tokens:
95
+ self._is_done = True
96
+ tokens = torch.cat((tokens, pred.view(tokens.shape[0], 1)), dim=1)
97
+ return tokens, mems
98
+
99
+ def finalize(self, tokens, mems):
100
+ self._is_done = False
101
+ return tokens, mems
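The top_k_logits helper above is the standard top-k / nucleus filter applied before multinomial sampling. Below is a minimal sketch of driving it on a toy batch; the tensor values are purely illustrative, the import path is an assumption based on this repository layout, and the snippet is not part of the pipeline itself.

import torch
import torch.nn.functional as F
# Assumed import path for this repository layout.
from videogen_hub.pipelines.cogvideo.cogvideo_src.coglm_strategy import top_k_logits

# Hypothetical 1 x 5 vocabulary of logits; values chosen only for illustration.
logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0]])

# Keep only the two highest-scoring tokens; the rest are filled with the
# fp16-safe value (-65504) used throughout this file.
filtered = top_k_logits(logits.clone(), top_k=2)

probs = F.softmax(filtered, dim=-1)                   # masked entries end up with ~0 probability
next_token = torch.multinomial(probs, num_samples=1)  # sample a single token id

CoglmStrategy.forward applies the same idea hierarchically: it first samples one of the 500 token clusters defined by cluster_label2.npy from the cluster-summed probabilities, then samples a token inside that cluster at temperature2.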
src/videogen_hub/pipelines/cogvideo/cogvideo_src/cogvideo_pipeline.py ADDED
@@ -0,0 +1,1341 @@
1
+ # -*- encoding: utf-8 -*-
2
+ """
3
+ @File : cogvideo_pipeline.py
4
+ @Time : 2022/07/15 11:24:56
5
+ @Author : Wenyi Hong
6
+ @Version : 1.0
7
+ @Contact : [email protected]
8
+ """
9
+
10
+ # here put the import lib
11
+
12
+ import os
13
+ import sys
14
+ import torch
15
+ import argparse
16
+ import time
17
+ from torchvision.utils import save_image
18
+ import stat
19
+ from videogen_hub.depend.icetk import icetk as tokenizer
20
+ import logging, sys
21
+
22
+ import torch.distributed as dist
23
+
24
+ tokenizer.add_special_tokens(
25
+ ["<start_of_image>", "<start_of_english>", "<start_of_chinese>"]
26
+ )
27
+
28
+
29
+ from SwissArmyTransformer import get_args
30
+ from SwissArmyTransformer.data_utils import BinaryDataset, make_loaders
31
+ from SwissArmyTransformer.generation.sampling_strategies import BaseStrategy
32
+ from SwissArmyTransformer.generation.utils import (
33
+ timed_name,
34
+ save_multiple_images,
35
+ generate_continually,
36
+ )
37
+ from SwissArmyTransformer.resources import auto_create
38
+
39
+ from .models.cogvideo_cache_model import CogVideoCacheModel
40
+ from .coglm_strategy import CoglmStrategy
41
+
42
+
43
+ def get_masks_and_position_ids_stage1(data, textlen, framelen):
44
+ # Extract batch size and sequence length.
45
+ tokens = data
46
+ seq_length = len(data[0])
47
+ # Attention mask (lower triangular).
48
+ attention_mask = torch.ones(
49
+ (1, textlen + framelen, textlen + framelen), device=data.device
50
+ )
51
+ attention_mask[:, :textlen, textlen:] = 0
52
+ attention_mask[:, textlen:, textlen:].tril_()
53
+ attention_mask.unsqueeze_(1)
54
+ # Unaligned version
55
+ position_ids = torch.zeros(seq_length, dtype=torch.long, device=data.device)
56
+ torch.arange(
57
+ textlen, out=position_ids[:textlen], dtype=torch.long, device=data.device
58
+ )
59
+ torch.arange(
60
+ 512,
61
+ 512 + seq_length - textlen,
62
+ out=position_ids[textlen:],
63
+ dtype=torch.long,
64
+ device=data.device,
65
+ )
66
+ position_ids = position_ids.unsqueeze(0)
67
+
68
+ return tokens, attention_mask, position_ids
69
+
70
+
71
+ def get_masks_and_position_ids_stage2(data, textlen, framelen):
72
+ # Extract batch size and sequence length.
73
+ tokens = data
74
+ seq_length = len(data[0])
75
+
76
+ # Attention mask (lower triangular).
77
+ attention_mask = torch.ones(
78
+ (1, textlen + framelen, textlen + framelen), device=data.device
79
+ )
80
+ attention_mask[:, :textlen, textlen:] = 0
81
+ attention_mask[:, textlen:, textlen:].tril_()
82
+ attention_mask.unsqueeze_(1)
83
+
84
+ # Unaligned version
85
+ position_ids = torch.zeros(seq_length, dtype=torch.long, device=data.device)
86
+ torch.arange(
87
+ textlen, out=position_ids[:textlen], dtype=torch.long, device=data.device
88
+ )
89
+ frame_num = (seq_length - textlen) // framelen
90
+ assert frame_num == 5
91
+ torch.arange(
92
+ 512,
93
+ 512 + framelen,
94
+ out=position_ids[textlen : textlen + framelen],
95
+ dtype=torch.long,
96
+ device=data.device,
97
+ )
98
+ torch.arange(
99
+ 512 + framelen * 2,
100
+ 512 + framelen * 3,
101
+ out=position_ids[textlen + framelen : textlen + framelen * 2],
102
+ dtype=torch.long,
103
+ device=data.device,
104
+ )
105
+ torch.arange(
106
+ 512 + framelen * (frame_num - 1),
107
+ 512 + framelen * frame_num,
108
+ out=position_ids[textlen + framelen * 2 : textlen + framelen * 3],
109
+ dtype=torch.long,
110
+ device=data.device,
111
+ )
112
+ torch.arange(
113
+ 512 + framelen * 1,
114
+ 512 + framelen * 2,
115
+ out=position_ids[textlen + framelen * 3 : textlen + framelen * 4],
116
+ dtype=torch.long,
117
+ device=data.device,
118
+ )
119
+ torch.arange(
120
+ 512 + framelen * 3,
121
+ 512 + framelen * 4,
122
+ out=position_ids[textlen + framelen * 4 : textlen + framelen * 5],
123
+ dtype=torch.long,
124
+ device=data.device,
125
+ )
126
+
127
+ position_ids = position_ids.unsqueeze(0)
128
+
129
+ return tokens, attention_mask, position_ids
130
+
131
+
132
+ def my_update_mems(
133
+ hiddens, mems_buffers, mems_indexs, limited_spatial_channel_mem, text_len, frame_len
134
+ ):
135
+ if hiddens is None:
136
+ return None, mems_indexs
137
+ mem_num = len(hiddens)
138
+ ret_mem = []
139
+ with torch.no_grad():
140
+ for id in range(mem_num):
141
+ if hiddens[id][0] is None:
142
+ ret_mem.append(None)
143
+ else:
144
+ if (
145
+ id == 0
146
+ and limited_spatial_channel_mem
147
+ and mems_indexs[id] + hiddens[0][0].shape[1] >= text_len + frame_len
148
+ ):
149
+ if mems_indexs[id] == 0:
150
+ for layer, hidden in enumerate(hiddens[id]):
151
+ mems_buffers[id][layer, :, :text_len] = hidden.expand(
152
+ mems_buffers[id].shape[1], -1, -1
153
+ )[:, :text_len]
154
+ new_mem_len_part2 = (
155
+ mems_indexs[id] + hiddens[0][0].shape[1] - text_len
156
+ ) % frame_len
157
+ if new_mem_len_part2 > 0:
158
+ for layer, hidden in enumerate(hiddens[id]):
159
+ mems_buffers[id][
160
+ layer, :, text_len : text_len + new_mem_len_part2
161
+ ] = hidden.expand(mems_buffers[id].shape[1], -1, -1)[
162
+ :, -new_mem_len_part2:
163
+ ]
164
+ mems_indexs[id] = text_len + new_mem_len_part2
165
+ else:
166
+ for layer, hidden in enumerate(hiddens[id]):
167
+ mems_buffers[id][
168
+ layer,
169
+ :,
170
+ mems_indexs[id] : mems_indexs[id] + hidden.shape[1],
171
+ ] = hidden.expand(mems_buffers[id].shape[1], -1, -1)
172
+ mems_indexs[id] += hidden.shape[1]
173
+ ret_mem.append(mems_buffers[id][:, :, : mems_indexs[id]])
174
+ return ret_mem, mems_indexs
175
+
176
+
177
+ def my_save_multiple_images(imgs, path, subdir, debug=True):
178
+ # imgs: list of tensor images
179
+ if debug:
180
+ imgs = torch.cat(imgs, dim=0)
181
+ print("\nSave to: ", path, flush=True)
182
+ save_image(imgs, path, normalize=True)
183
+ else:
184
+ print("\nSave to: ", path, flush=True)
185
+ single_frame_path = os.path.join(path, subdir)
186
+ os.makedirs(single_frame_path, exist_ok=True)
187
+ for i in range(len(imgs)):
188
+ save_image(
189
+ imgs[i],
190
+ os.path.join(single_frame_path, f'{str(i).rjust(4,"0")}.jpg'),
191
+ normalize=True,
192
+ )
193
+ os.chmod(
194
+ os.path.join(single_frame_path, f'{str(i).rjust(4,"0")}.jpg'),
195
+ stat.S_IRWXO + stat.S_IRWXG + stat.S_IRWXU,
196
+ )
197
+ save_image(
198
+ torch.cat(imgs, dim=0),
199
+ os.path.join(single_frame_path, f"frame_concat.jpg"),
200
+ normalize=True,
201
+ )
202
+ os.chmod(
203
+ os.path.join(single_frame_path, f"frame_concat.jpg"),
204
+ stat.S_IRWXO + stat.S_IRWXG + stat.S_IRWXU,
205
+ )
206
+
207
+
208
+ def calc_next_tokens_frame_begin_id(text_len, frame_len, total_len):
209
+ # The first token's position id of the frame that the next token belongs to;
210
+ if total_len < text_len:
211
+ return None
212
+ return (total_len - text_len) // frame_len * frame_len + text_len
213
+
214
+
215
+ def my_filling_sequence(
216
+ model,
217
+ args,
218
+ seq,
219
+ batch_size,
220
+ get_masks_and_position_ids,
221
+ text_len,
222
+ frame_len,
223
+ strategy=BaseStrategy(),
224
+ strategy2=BaseStrategy(),
225
+ mems=None,
226
+ log_text_attention_weights=0, # default to 0: no artificial change
227
+ mode_stage1=True,
228
+ enforce_no_swin=False,
229
+ guider_seq=None,
230
+ guider_text_len=0,
231
+ guidance_alpha=1,
232
+ limited_spatial_channel_mem=False, # restrict the spatial-channel memory to the current frame
233
+ **kw_args,
234
+ ):
235
+ """
236
+ seq: [2, 3, 5, ..., -1(to be generated), -1, ...]
237
+ mems: [num_layers, batch_size, len_mems(index), mem_hidden_size]
238
+ cache, should be first mems.shape[1] parts of context_tokens.
239
+ mems are first-class citizens here, but we don't assume what is memorized.
240
+ input mems are used for multi-phase generation.
241
+ """
242
+ if guider_seq is not None:
243
+ logging.debug("Using Guidance In Inference")
244
+ if limited_spatial_channel_mem:
245
+ logging.debug("Limit spatial-channel's mem to current frame")
246
+ assert len(seq.shape) == 2
247
+
248
+ # building the initial tokens, attention_mask, and position_ids
249
+ actual_context_length = 0
250
+
251
+ while seq[-1][actual_context_length] >= 0: # the last seq has the fewest given tokens
252
+ actual_context_length += 1 # [0, context_length-1] are given
253
+ assert actual_context_length > 0
254
+ current_frame_num = (actual_context_length - text_len) // frame_len
255
+ assert current_frame_num >= 0
256
+ context_length = text_len + current_frame_num * frame_len
257
+
258
+ tokens, attention_mask, position_ids = get_masks_and_position_ids(
259
+ seq, text_len, frame_len
260
+ )
261
+ tokens = tokens[..., :context_length]
262
+ input_tokens = tokens.clone()
263
+
264
+ if guider_seq is not None:
265
+ guider_index_delta = text_len - guider_text_len
266
+ guider_tokens, guider_attention_mask, guider_position_ids = (
267
+ get_masks_and_position_ids(guider_seq, guider_text_len, frame_len)
268
+ )
269
+ guider_tokens = guider_tokens[..., : context_length - guider_index_delta]
270
+ guider_input_tokens = guider_tokens.clone()
271
+
272
+ for fid in range(current_frame_num):
273
+ input_tokens[:, text_len + 400 * fid] = tokenizer["<start_of_image>"]
274
+ if guider_seq is not None:
275
+ guider_input_tokens[:, guider_text_len + 400 * fid] = tokenizer[
276
+ "<start_of_image>"
277
+ ]
278
+
279
+ attention_mask = attention_mask.type_as(next(model.parameters())) # if fp16
280
+ # initialize generation
281
+ counter = context_length - 1 # Last fixed index is ``counter''
282
+ index = 0 # Next forward starting index, also the length of cache.
283
+ mems_buffers_on_GPU = False
284
+ mems_indexs = [0, 0]
285
+ mems_len = [
286
+ (400 + 74) if limited_spatial_channel_mem else 5 * 400 + 74,
287
+ 5 * 400 + 74,
288
+ ]
289
+ mems_buffers = [
290
+ torch.zeros(
291
+ args.num_layers,
292
+ batch_size,
293
+ mem_len,
294
+ args.hidden_size * 2,
295
+ dtype=next(model.parameters()).dtype,
296
+ )
297
+ for mem_len in mems_len
298
+ ]
299
+
300
+ if guider_seq is not None:
301
+ guider_attention_mask = guider_attention_mask.type_as(
302
+ next(model.parameters())
303
+ ) # if fp16
304
+ guider_mems_buffers = [
305
+ torch.zeros(
306
+ args.num_layers,
307
+ batch_size,
308
+ mem_len,
309
+ args.hidden_size * 2,
310
+ dtype=next(model.parameters()).dtype,
311
+ )
312
+ for mem_len in mems_len
313
+ ]
314
+ guider_mems_indexs = [0, 0]
315
+ guider_mems = None
316
+
317
+ torch.cuda.empty_cache()
318
+ # step-by-step generation
319
+ while counter < len(seq[0]) - 1:
320
+ # we have generated counter+1 tokens
321
+ # Now, we want to generate seq[counter + 1],
322
+ # token[:, index: counter+1] needs forwarding.
323
+ if index == 0:
324
+ group_size = (
325
+ 2
326
+ if (input_tokens.shape[0] == batch_size and not mode_stage1)
327
+ else batch_size
328
+ )
329
+
330
+ logits_all = None
331
+ for batch_idx in range(0, input_tokens.shape[0], group_size):
332
+ logits, *output_per_layers = model(
333
+ input_tokens[batch_idx : batch_idx + group_size, index:],
334
+ position_ids[..., index : counter + 1],
335
+ attention_mask, # TODO memlen
336
+ mems=mems,
337
+ text_len=text_len,
338
+ frame_len=frame_len,
339
+ counter=counter,
340
+ log_text_attention_weights=log_text_attention_weights,
341
+ enforce_no_swin=enforce_no_swin,
342
+ **kw_args,
343
+ )
344
+ logits_all = (
345
+ torch.cat((logits_all, logits), dim=0)
346
+ if logits_all is not None
347
+ else logits
348
+ )
349
+ mem_kv01 = [
350
+ [o["mem_kv"][0] for o in output_per_layers],
351
+ [o["mem_kv"][1] for o in output_per_layers],
352
+ ]
353
+ next_tokens_frame_begin_id = calc_next_tokens_frame_begin_id(
354
+ text_len, frame_len, mem_kv01[0][0].shape[1]
355
+ )
356
+ for id, mem_kv in enumerate(mem_kv01):
357
+ for layer, mem_kv_perlayer in enumerate(mem_kv):
358
+ if limited_spatial_channel_mem and id == 0:
359
+ mems_buffers[id][
360
+ layer, batch_idx : batch_idx + group_size, :text_len
361
+ ] = mem_kv_perlayer.expand(
362
+ min(group_size, input_tokens.shape[0] - batch_idx),
363
+ -1,
364
+ -1,
365
+ )[
366
+ :, :text_len
367
+ ]
368
+ mems_buffers[id][
369
+ layer,
370
+ batch_idx : batch_idx + group_size,
371
+ text_len : text_len
372
+ + mem_kv_perlayer.shape[1]
373
+ - next_tokens_frame_begin_id,
374
+ ] = mem_kv_perlayer.expand(
375
+ min(group_size, input_tokens.shape[0] - batch_idx),
376
+ -1,
377
+ -1,
378
+ )[
379
+ :, next_tokens_frame_begin_id:
380
+ ]
381
+ else:
382
+ mems_buffers[id][
383
+ layer,
384
+ batch_idx : batch_idx + group_size,
385
+ : mem_kv_perlayer.shape[1],
386
+ ] = mem_kv_perlayer.expand(
387
+ min(group_size, input_tokens.shape[0] - batch_idx),
388
+ -1,
389
+ -1,
390
+ )
391
+ mems_indexs[0], mems_indexs[1] = (
392
+ mem_kv01[0][0].shape[1],
393
+ mem_kv01[1][0].shape[1],
394
+ )
395
+ if limited_spatial_channel_mem:
396
+ mems_indexs[0] -= next_tokens_frame_begin_id - text_len
397
+
398
+ mems = [mems_buffers[id][:, :, : mems_indexs[id]] for id in range(2)]
399
+ logits = logits_all
400
+
401
+ # Guider
402
+ if guider_seq is not None:
403
+ guider_logits_all = None
404
+ for batch_idx in range(0, guider_input_tokens.shape[0], group_size):
405
+ guider_logits, *guider_output_per_layers = model(
406
+ guider_input_tokens[
407
+ batch_idx : batch_idx + group_size,
408
+ max(index - guider_index_delta, 0) :,
409
+ ],
410
+ guider_position_ids[
411
+ ...,
412
+ max(index - guider_index_delta, 0) : counter
413
+ + 1
414
+ - guider_index_delta,
415
+ ],
416
+ guider_attention_mask,
417
+ mems=guider_mems,
418
+ text_len=guider_text_len,
419
+ frame_len=frame_len,
420
+ counter=counter - guider_index_delta,
421
+ log_text_attention_weights=log_text_attention_weights,
422
+ enforce_no_swin=enforce_no_swin,
423
+ **kw_args,
424
+ )
425
+ guider_logits_all = (
426
+ torch.cat((guider_logits_all, guider_logits), dim=0)
427
+ if guider_logits_all is not None
428
+ else guider_logits
429
+ )
430
+ guider_mem_kv01 = [
431
+ [o["mem_kv"][0] for o in guider_output_per_layers],
432
+ [o["mem_kv"][1] for o in guider_output_per_layers],
433
+ ]
434
+ for id, guider_mem_kv in enumerate(guider_mem_kv01):
435
+ for layer, guider_mem_kv_perlayer in enumerate(guider_mem_kv):
436
+ if limited_spatial_channel_mem and id == 0:
437
+ guider_mems_buffers[id][
438
+ layer,
439
+ batch_idx : batch_idx + group_size,
440
+ :guider_text_len,
441
+ ] = guider_mem_kv_perlayer.expand(
442
+ min(group_size, input_tokens.shape[0] - batch_idx),
443
+ -1,
444
+ -1,
445
+ )[
446
+ :, :guider_text_len
447
+ ]
448
+ guider_next_tokens_frame_begin_id = (
449
+ calc_next_tokens_frame_begin_id(
450
+ guider_text_len,
451
+ frame_len,
452
+ guider_mem_kv_perlayer.shape[1],
453
+ )
454
+ )
455
+ guider_mems_buffers[id][
456
+ layer,
457
+ batch_idx : batch_idx + group_size,
458
+ guider_text_len : guider_text_len
459
+ + guider_mem_kv_perlayer.shape[1]
460
+ - guider_next_tokens_frame_begin_id,
461
+ ] = guider_mem_kv_perlayer.expand(
462
+ min(group_size, input_tokens.shape[0] - batch_idx),
463
+ -1,
464
+ -1,
465
+ )[
466
+ :, guider_next_tokens_frame_begin_id:
467
+ ]
468
+ else:
469
+ guider_mems_buffers[id][
470
+ layer,
471
+ batch_idx : batch_idx + group_size,
472
+ : guider_mem_kv_perlayer.shape[1],
473
+ ] = guider_mem_kv_perlayer.expand(
474
+ min(group_size, input_tokens.shape[0] - batch_idx),
475
+ -1,
476
+ -1,
477
+ )
478
+ guider_mems_indexs[0], guider_mems_indexs[1] = (
479
+ guider_mem_kv01[0][0].shape[1],
480
+ guider_mem_kv01[1][0].shape[1],
481
+ )
482
+ if limited_spatial_channel_mem:
483
+ guider_mems_indexs[0] -= (
484
+ guider_next_tokens_frame_begin_id - guider_text_len
485
+ )
486
+ guider_mems = [
487
+ guider_mems_buffers[id][:, :, : guider_mems_indexs[id]]
488
+ for id in range(2)
489
+ ]
490
+ guider_logits = guider_logits_all
491
+ else:
492
+ if not mems_buffers_on_GPU:
493
+ if not mode_stage1:
494
+ torch.cuda.empty_cache()
495
+ for idx, mem in enumerate(mems):
496
+ mems[idx] = mem.to(next(model.parameters()).device)
497
+ if guider_seq is not None:
498
+ for idx, mem in enumerate(guider_mems):
499
+ guider_mems[idx] = mem.to(next(model.parameters()).device)
500
+ else:
501
+ torch.cuda.empty_cache()
502
+ for idx, mem_buffer in enumerate(mems_buffers):
503
+ mems_buffers[idx] = mem_buffer.to(
504
+ next(model.parameters()).device
505
+ )
506
+ mems = [
507
+ mems_buffers[id][:, :, : mems_indexs[id]] for id in range(2)
508
+ ]
509
+ if guider_seq is not None:
510
+ for idx, guider_mem_buffer in enumerate(guider_mems_buffers):
511
+ guider_mems_buffers[idx] = guider_mem_buffer.to(
512
+ next(model.parameters()).device
513
+ )
514
+ guider_mems = [
515
+ guider_mems_buffers[id][:, :, : guider_mems_indexs[id]]
516
+ for id in range(2)
517
+ ]
518
+ mems_buffers_on_GPU = True
519
+
520
+ logits, *output_per_layers = model(
521
+ input_tokens[:, index:],
522
+ position_ids[..., index : counter + 1],
523
+ attention_mask, # TODO memlen
524
+ mems=mems,
525
+ text_len=text_len,
526
+ frame_len=frame_len,
527
+ counter=counter,
528
+ log_text_attention_weights=log_text_attention_weights,
529
+ enforce_no_swin=enforce_no_swin,
530
+ limited_spatial_channel_mem=limited_spatial_channel_mem,
531
+ **kw_args,
532
+ )
533
+ mem_kv0, mem_kv1 = [o["mem_kv"][0] for o in output_per_layers], [
534
+ o["mem_kv"][1] for o in output_per_layers
535
+ ]
536
+
537
+ if guider_seq is not None:
538
+ guider_logits, *guider_output_per_layers = model(
539
+ guider_input_tokens[:, max(index - guider_index_delta, 0) :],
540
+ guider_position_ids[
541
+ ...,
542
+ max(index - guider_index_delta, 0) : counter
543
+ + 1
544
+ - guider_index_delta,
545
+ ],
546
+ guider_attention_mask,
547
+ mems=guider_mems,
548
+ text_len=guider_text_len,
549
+ frame_len=frame_len,
550
+ counter=counter - guider_index_delta,
551
+ log_text_attention_weights=0,
552
+ enforce_no_swin=enforce_no_swin,
553
+ limited_spatial_channel_mem=limited_spatial_channel_mem,
554
+ **kw_args,
555
+ )
556
+ guider_mem_kv0, guider_mem_kv1 = [
557
+ o["mem_kv"][0] for o in guider_output_per_layers
558
+ ], [o["mem_kv"][1] for o in guider_output_per_layers]
559
+
560
+ if not mems_buffers_on_GPU:
561
+ torch.cuda.empty_cache()
562
+ for idx, mem_buffer in enumerate(mems_buffers):
563
+ mems_buffers[idx] = mem_buffer.to(next(model.parameters()).device)
564
+ if guider_seq is not None:
565
+ for idx, guider_mem_buffer in enumerate(guider_mems_buffers):
566
+ guider_mems_buffers[idx] = guider_mem_buffer.to(
567
+ next(model.parameters()).device
568
+ )
569
+ mems_buffers_on_GPU = True
570
+
571
+ mems, mems_indexs = my_update_mems(
572
+ [mem_kv0, mem_kv1],
573
+ mems_buffers,
574
+ mems_indexs,
575
+ limited_spatial_channel_mem,
576
+ text_len,
577
+ frame_len,
578
+ )
579
+ if guider_seq is not None:
580
+ guider_mems, guider_mems_indexs = my_update_mems(
581
+ [guider_mem_kv0, guider_mem_kv1],
582
+ guider_mems_buffers,
583
+ guider_mems_indexs,
584
+ limited_spatial_channel_mem,
585
+ guider_text_len,
586
+ frame_len,
587
+ )
588
+
589
+ counter += 1
590
+ index = counter
591
+
592
+ logits = logits[:, -1].expand(batch_size, -1) # [batch size, vocab size]
593
+ tokens = tokens.expand(batch_size, -1)
594
+ if guider_seq is not None:
595
+ guider_logits = guider_logits[:, -1].expand(batch_size, -1)
596
+ guider_tokens = guider_tokens.expand(batch_size, -1)
597
+
598
+ if seq[-1][counter].item() < 0:
599
+ # sampling
600
+ guided_logits = (
601
+ guider_logits + (logits - guider_logits) * guidance_alpha
602
+ if guider_seq is not None
603
+ else logits
604
+ )
605
+ if mode_stage1 and counter < text_len + 400:
606
+ tokens, mems = strategy.forward(guided_logits, tokens, mems)
607
+ else:
608
+ tokens, mems = strategy2.forward(guided_logits, tokens, mems)
609
+ if guider_seq is not None:
610
+ guider_tokens = torch.cat((guider_tokens, tokens[:, -1:]), dim=1)
611
+
612
+ if seq[0][counter].item() >= 0:
613
+ for si in range(seq.shape[0]):
614
+ if seq[si][counter].item() >= 0:
615
+ tokens[si, -1] = seq[si, counter]
616
+ if guider_seq is not None:
617
+ guider_tokens[si, -1] = guider_seq[
618
+ si, counter - guider_index_delta
619
+ ]
620
+
621
+ else:
622
+ tokens = torch.cat(
623
+ (
624
+ tokens,
625
+ seq[:, counter : counter + 1]
626
+ .clone()
627
+ .expand(tokens.shape[0], 1)
628
+ .to(device=tokens.device, dtype=tokens.dtype),
629
+ ),
630
+ dim=1,
631
+ )
632
+ if guider_seq is not None:
633
+ guider_tokens = torch.cat(
634
+ (
635
+ guider_tokens,
636
+ guider_seq[
637
+ :,
638
+ counter
639
+ - guider_index_delta : counter
640
+ + 1
641
+ - guider_index_delta,
642
+ ]
643
+ .clone()
644
+ .expand(guider_tokens.shape[0], 1)
645
+ .to(device=guider_tokens.device, dtype=guider_tokens.dtype),
646
+ ),
647
+ dim=1,
648
+ )
649
+
650
+ input_tokens = tokens.clone()
651
+ if guider_seq is not None:
652
+ guider_input_tokens = guider_tokens.clone()
653
+ if (index - text_len - 1) // 400 < (
654
+ input_tokens.shape[-1] - text_len - 1
655
+ ) // 400:
656
+ boi_idx = ((index - text_len - 1) // 400 + 1) * 400 + text_len
657
+ while boi_idx < input_tokens.shape[-1]:
658
+ input_tokens[:, boi_idx] = tokenizer["<start_of_image>"]
659
+ if guider_seq is not None:
660
+ guider_input_tokens[:, boi_idx - guider_index_delta] = tokenizer[
661
+ "<start_of_image>"
662
+ ]
663
+ boi_idx += 400
664
+
665
+ if strategy.is_done:
666
+ break
667
+ return strategy.finalize(tokens, mems)
668
+
669
+
670
+ class InferenceModel_Sequential(CogVideoCacheModel):
671
+ def __init__(self, args, transformer=None, parallel_output=True):
672
+ super().__init__(
673
+ args,
674
+ transformer=transformer,
675
+ parallel_output=parallel_output,
676
+ window_size=-1,
677
+ cogvideo_stage=1,
678
+ )
679
+
680
+ # TODO: check it
681
+
682
+ def final_forward(self, logits, **kwargs):
683
+ logits_parallel = logits
684
+ logits_parallel = torch.nn.functional.linear(
685
+ logits_parallel.float(),
686
+ self.transformer.word_embeddings.weight[:20000].float(),
687
+ )
688
+ return logits_parallel
689
+
690
+
691
+ class InferenceModel_Interpolate(CogVideoCacheModel):
692
+ def __init__(self, args, transformer=None, parallel_output=True):
693
+ super().__init__(
694
+ args,
695
+ transformer=transformer,
696
+ parallel_output=parallel_output,
697
+ window_size=10,
698
+ cogvideo_stage=2,
699
+ )
700
+
701
+ # TODO: check it
702
+
703
+ def final_forward(self, logits, **kwargs):
704
+ logits_parallel = logits
705
+ logits_parallel = torch.nn.functional.linear(
706
+ logits_parallel.float(),
707
+ self.transformer.word_embeddings.weight[:20000].float(),
708
+ )
709
+ return logits_parallel
710
+
711
+
712
+ def main(args):
713
+ assert int(args.stage_1) + int(args.stage_2) + int(args.both_stages) == 1
714
+ rank_id = args.device % args.parallel_size
715
+ generate_frame_num = args.generate_frame_num
716
+
717
+ if args.stage_1 or args.both_stages:
718
+ model_stage1, args = InferenceModel_Sequential.from_pretrained(
719
+ args, "cogvideo-stage1"
720
+ )
721
+ model_stage1.eval()
722
+ if args.both_stages:
723
+ model_stage1 = model_stage1.cpu()
724
+
725
+ if args.stage_2 or args.both_stages:
726
+ model_stage2, args = InferenceModel_Interpolate.from_pretrained(
727
+ args, "cogvideo-stage2"
728
+ )
729
+ model_stage2.eval()
730
+ if args.both_stages:
731
+ model_stage2 = model_stage2.cpu()
732
+
733
+ invalid_slices = [slice(tokenizer.num_image_tokens, None)]
734
+ strategy_cogview2 = CoglmStrategy(invalid_slices, temperature=1.0, top_k=16)
735
+ strategy_cogvideo = CoglmStrategy(
736
+ invalid_slices,
737
+ temperature=args.temperature,
738
+ top_k=args.top_k,
739
+ temperature2=args.coglm_temperature2,
740
+ )
741
+ if not args.stage_1:
742
+ from sr_pipeline import DirectSuperResolution
743
+
744
+ dsr_path = auto_create(
745
+ "cogview2-dsr", path=None
746
+ ) # path=os.getenv('SAT_HOME', '~/.sat_models')
747
+ dsr = DirectSuperResolution(args, dsr_path, max_bz=12, onCUDA=False)
748
+
749
+ def process_stage2(
750
+ model,
751
+ seq_text,
752
+ duration,
753
+ video_raw_text=None,
754
+ video_guidance_text="视频",
755
+ parent_given_tokens=None,
756
+ conddir=None,
757
+ outputdir=None,
758
+ gpu_rank=0,
759
+ gpu_parallel_size=1,
760
+ ):
761
+ stage2_starttime = time.time()
762
+ use_guidance = args.use_guidance_stage2
763
+ if args.both_stages:
764
+ move_start_time = time.time()
765
+ logging.debug("moving stage-2 model to cuda")
766
+ model = model.cuda()
767
+ logging.debug(
768
+ "moving in stage-2 model takes time: {:.2f}".format(
769
+ time.time() - move_start_time
770
+ )
771
+ )
772
+
773
+ try:
774
+ if parent_given_tokens is None:
775
+ assert conddir is not None
776
+ parent_given_tokens = torch.load(
777
+ os.path.join(conddir, "frame_tokens.pt"), map_location="cpu"
778
+ )
779
+ sample_num_allgpu = parent_given_tokens.shape[0]
780
+ sample_num = sample_num_allgpu // gpu_parallel_size
781
+ assert sample_num * gpu_parallel_size == sample_num_allgpu
782
+ parent_given_tokens = parent_given_tokens[
783
+ gpu_rank * sample_num : (gpu_rank + 1) * sample_num
784
+ ]
785
+ except:
786
+ logging.critical("No frame_tokens found in interpolation, skip")
787
+ return False
788
+
789
+ # CogVideo Stage2 Generation
790
+ while (
791
+ duration >= 0.5
792
+ ): # TODO: You can change the boundary to change the frame rate
793
+ parent_given_tokens_num = parent_given_tokens.shape[1]
794
+ generate_batchsize_persample = (parent_given_tokens_num - 1) // 2
795
+ generate_batchsize_total = generate_batchsize_persample * sample_num
796
+ total_frames = generate_frame_num
797
+ frame_len = 400
798
+ enc_text = tokenizer.encode(seq_text)
799
+ enc_duration = tokenizer.encode(str(float(duration)) + "秒")
800
+ seq = (
801
+ enc_duration
802
+ + [tokenizer["<n>"]]
803
+ + enc_text
804
+ + [tokenizer["<start_of_image>"]]
805
+ + [-1] * 400 * generate_frame_num
806
+ )
807
+ text_len = len(seq) - frame_len * generate_frame_num - 1
808
+
809
+ logging.info(
810
+ "[Stage2: Generating Frames, Frame Rate {:d}]\nraw text: {:s}".format(
811
+ int(4 / duration), tokenizer.decode(enc_text)
812
+ )
813
+ )
814
+
815
+ # generation
816
+ seq = (
817
+ torch.cuda.LongTensor(seq, device=args.device)
818
+ .unsqueeze(0)
819
+ .repeat(generate_batchsize_total, 1)
820
+ )
821
+ for sample_i in range(sample_num):
822
+ for i in range(generate_batchsize_persample):
823
+ seq[sample_i * generate_batchsize_persample + i][
824
+ text_len + 1 : text_len + 1 + 400
825
+ ] = parent_given_tokens[sample_i][2 * i]
826
+ seq[sample_i * generate_batchsize_persample + i][
827
+ text_len + 1 + 400 : text_len + 1 + 800
828
+ ] = parent_given_tokens[sample_i][2 * i + 1]
829
+ seq[sample_i * generate_batchsize_persample + i][
830
+ text_len + 1 + 800 : text_len + 1 + 1200
831
+ ] = parent_given_tokens[sample_i][2 * i + 2]
832
+
833
+ if use_guidance:
834
+ guider_seq = (
835
+ enc_duration
836
+ + [tokenizer["<n>"]]
837
+ + tokenizer.encode(video_guidance_text)
838
+ + [tokenizer["<start_of_image>"]]
839
+ + [-1] * 400 * generate_frame_num
840
+ )
841
+ guider_text_len = len(guider_seq) - frame_len * generate_frame_num - 1
842
+ guider_seq = (
843
+ torch.cuda.LongTensor(guider_seq, device=args.device)
844
+ .unsqueeze(0)
845
+ .repeat(generate_batchsize_total, 1)
846
+ )
847
+ for sample_i in range(sample_num):
848
+ for i in range(generate_batchsize_persample):
849
+ guider_seq[sample_i * generate_batchsize_persample + i][
850
+ text_len + 1 : text_len + 1 + 400
851
+ ] = parent_given_tokens[sample_i][2 * i]
852
+ guider_seq[sample_i * generate_batchsize_persample + i][
853
+ text_len + 1 + 400 : text_len + 1 + 800
854
+ ] = parent_given_tokens[sample_i][2 * i + 1]
855
+ guider_seq[sample_i * generate_batchsize_persample + i][
856
+ text_len + 1 + 800 : text_len + 1 + 1200
857
+ ] = parent_given_tokens[sample_i][2 * i + 2]
858
+ video_log_text_attention_weights = 0
859
+ else:
860
+ guider_seq = None
861
+ guider_text_len = 0
862
+ video_log_text_attention_weights = 1.4
863
+
864
+ mbz = args.max_inference_batch_size
865
+
866
+ assert generate_batchsize_total < mbz or generate_batchsize_total % mbz == 0
867
+ output_list = []
868
+ start_time = time.time()
869
+ for tim in range(max(generate_batchsize_total // mbz, 1)):
870
+ input_seq = (
871
+ seq[: min(generate_batchsize_total, mbz)].clone()
872
+ if tim == 0
873
+ else seq[mbz * tim : mbz * (tim + 1)].clone()
874
+ )
875
+ guider_seq2 = (
876
+ (
877
+ guider_seq[: min(generate_batchsize_total, mbz)].clone()
878
+ if tim == 0
879
+ else guider_seq[mbz * tim : mbz * (tim + 1)].clone()
880
+ )
881
+ if guider_seq is not None
882
+ else None
883
+ )
884
+ output_list.append(
885
+ my_filling_sequence(
886
+ model,
887
+ args,
888
+ input_seq,
889
+ batch_size=min(generate_batchsize_total, mbz),
890
+ get_masks_and_position_ids=get_masks_and_position_ids_stage2,
891
+ text_len=text_len,
892
+ frame_len=frame_len,
893
+ strategy=strategy_cogview2,
894
+ strategy2=strategy_cogvideo,
895
+ log_text_attention_weights=video_log_text_attention_weights,
896
+ mode_stage1=False,
897
+ guider_seq=guider_seq2,
898
+ guider_text_len=guider_text_len,
899
+ guidance_alpha=args.guidance_alpha,
900
+ limited_spatial_channel_mem=True,
901
+ )[0]
902
+ )
903
+ logging.info(
904
+ "Duration {:.2f}, Taken time {:.2f}\n".format(
905
+ duration, time.time() - start_time
906
+ )
907
+ )
908
+
909
+ output_tokens = torch.cat(output_list, dim=0)
910
+ output_tokens = output_tokens[
911
+ :, text_len + 1 : text_len + 1 + (total_frames) * 400
912
+ ].reshape(sample_num, -1, 400 * total_frames)
913
+ output_tokens_merge = torch.cat(
914
+ (
915
+ output_tokens[:, :, : 1 * 400],
916
+ output_tokens[:, :, 400 * 3 : 4 * 400],
917
+ output_tokens[:, :, 400 * 1 : 2 * 400],
918
+ output_tokens[:, :, 400 * 4 : (total_frames) * 400],
919
+ ),
920
+ dim=2,
921
+ ).reshape(sample_num, -1, 400)
922
+
923
+ output_tokens_merge = torch.cat(
924
+ (output_tokens_merge, output_tokens[:, -1:, 400 * 2 : 3 * 400]), dim=1
925
+ )
926
+ duration /= 2
927
+ parent_given_tokens = output_tokens_merge
928
+
929
+ if args.both_stages:
930
+ move_start_time = time.time()
931
+ logging.debug("moving stage 2 model to cpu")
932
+ model = model.cpu()
933
+ torch.cuda.empty_cache()
934
+ logging.debug(
935
+ "moving out model2 takes time: {:.2f}".format(
936
+ time.time() - move_start_time
937
+ )
938
+ )
939
+
940
+ logging.info(
941
+ "CogVideo Stage2 completed. Taken time {:.2f}\n".format(
942
+ time.time() - stage2_starttime
943
+ )
944
+ )
945
+
946
+ # decoding
947
+ # imgs = [torch.nn.functional.interpolate(tokenizer.decode(image_ids=seq.tolist()), size=(480, 480)) for seq in output_tokens_merge]
948
+ # os.makedirs(output_dir_full_path, exist_ok=True)
949
+ # my_save_multiple_images(imgs, output_dir_full_path,subdir="frames", debug=False)
950
+ # torch.save(output_tokens_merge.cpu(), os.path.join(output_dir_full_path, 'frame_token.pt'))
951
+ # os.system(f"gifmaker -i '{output_dir_full_path}'/frames/0*.jpg -o '{output_dir_full_path}/{str(float(duration))}_concat.gif' -d 0.2")
952
+
953
+ # direct super-resolution by CogView2
954
+ logging.info("[Direct super-resolution]")
955
+ dsr_starttime = time.time()
956
+ enc_text = tokenizer.encode(seq_text)
957
+ frame_num_per_sample = parent_given_tokens.shape[1]
958
+ parent_given_tokens_2d = parent_given_tokens.reshape(-1, 400)
959
+ text_seq = (
960
+ torch.cuda.LongTensor(enc_text, device=args.device)
961
+ .unsqueeze(0)
962
+ .repeat(parent_given_tokens_2d.shape[0], 1)
963
+ )
964
+ sred_tokens = dsr(text_seq, parent_given_tokens_2d)
965
+ decoded_sr_videos = []
966
+
967
+ for sample_i in range(sample_num):
968
+ decoded_sr_imgs = []
969
+ for frame_i in range(frame_num_per_sample):
970
+ decoded_sr_img = tokenizer.decode(
971
+ image_ids=sred_tokens[frame_i + sample_i * frame_num_per_sample][
972
+ -3600:
973
+ ]
974
+ )
975
+ decoded_sr_imgs.append(
976
+ torch.nn.functional.interpolate(decoded_sr_img, size=(480, 480))
977
+ )
978
+ decoded_sr_videos.append(decoded_sr_imgs)
979
+
980
+ for sample_i in range(sample_num):
981
+ my_save_multiple_images(
982
+ decoded_sr_videos[sample_i],
983
+ outputdir,
984
+ subdir=f"frames/{sample_i+sample_num*gpu_rank}",
985
+ debug=False,
986
+ )
987
+ os.system(
988
+ f"gifmaker -i '{outputdir}'/frames/'{sample_i+sample_num*gpu_rank}'/0*.jpg -o '{outputdir}/{sample_i+sample_num*gpu_rank}.gif' -d 0.125"
989
+ )
990
+
991
+ logging.info(
992
+ "Direct super-resolution completed. Taken time {:.2f}\n".format(
993
+ time.time() - dsr_starttime
994
+ )
995
+ )
996
+
997
+ return True
998
+
999
+ def process_stage1(
1000
+ model,
1001
+ seq_text,
1002
+ duration,
1003
+ video_raw_text=None,
1004
+ video_guidance_text="视频",
1005
+ image_text_suffix="",
1006
+ outputdir=None,
1007
+ batch_size=1,
1008
+ ):
1009
+ process_start_time = time.time()
1010
+ use_guide = args.use_guidance_stage1
1011
+ if args.both_stages:
1012
+ move_start_time = time.time()
1013
+ logging.debug("moving stage 1 model to cuda")
1014
+ model = model.cuda()
1015
+ logging.debug(
1016
+ "moving in model1 takes time: {:.2f}".format(
1017
+ time.time() - move_start_time
1018
+ )
1019
+ )
1020
+
1021
+ if video_raw_text is None:
1022
+ video_raw_text = seq_text
1023
+ mbz = (
1024
+ args.stage1_max_inference_batch_size
1025
+ if args.stage1_max_inference_batch_size > 0
1026
+ else args.max_inference_batch_size
1027
+ )
1028
+ assert batch_size < mbz or batch_size % mbz == 0
1029
+ frame_len = 400
1030
+
1031
+ # generate the first frame:
1032
+ enc_text = tokenizer.encode(seq_text + image_text_suffix)
1033
+ seq_1st = (
1034
+ enc_text + [tokenizer["<start_of_image>"]] + [-1] * 400
1035
+ ) # IV!! # test local!!! # test randboi!!!
1036
+ logging.info(
1037
+ "[Generating First Frame with CogView2]Raw text: {:s}".format(
1038
+ tokenizer.decode(enc_text)
1039
+ )
1040
+ )
1041
+ text_len_1st = len(seq_1st) - frame_len * 1 - 1
1042
+
1043
+ seq_1st = torch.cuda.LongTensor(seq_1st, device=args.device).unsqueeze(0)
1044
+ output_list_1st = []
1045
+ for tim in range(max(batch_size // mbz, 1)):
1046
+ start_time = time.time()
1047
+ output_list_1st.append(
1048
+ my_filling_sequence(
1049
+ model,
1050
+ args,
1051
+ seq_1st.clone(),
1052
+ batch_size=min(batch_size, mbz),
1053
+ get_masks_and_position_ids=get_masks_and_position_ids_stage1,
1054
+ text_len=text_len_1st,
1055
+ frame_len=frame_len,
1056
+ strategy=strategy_cogview2,
1057
+ strategy2=strategy_cogvideo,
1058
+ log_text_attention_weights=1.4,
1059
+ enforce_no_swin=True,
1060
+ mode_stage1=True,
1061
+ )[0]
1062
+ )
1063
+ logging.info(
1064
+ "[First Frame]Taken time {:.2f}\n".format(time.time() - start_time)
1065
+ )
1066
+ output_tokens_1st = torch.cat(output_list_1st, dim=0)
1067
+ given_tokens = output_tokens_1st[
1068
+ :, text_len_1st + 1 : text_len_1st + 401
1069
+ ].unsqueeze(
1070
+ 1
1071
+ ) # given_tokens.shape: [bs, frame_num, 400]
1072
+
1073
+ # generate subsequent frames:
1074
+ total_frames = generate_frame_num
1075
+ enc_duration = tokenizer.encode(str(float(duration)) + "秒")
1076
+ if use_guide:
1077
+ video_raw_text = video_raw_text + " 视频"
1078
+ enc_text_video = tokenizer.encode(video_raw_text)
1079
+ seq = (
1080
+ enc_duration
1081
+ + [tokenizer["<n>"]]
1082
+ + enc_text_video
1083
+ + [tokenizer["<start_of_image>"]]
1084
+ + [-1] * 400 * generate_frame_num
1085
+ )
1086
+ guider_seq = (
1087
+ enc_duration
1088
+ + [tokenizer["<n>"]]
1089
+ + tokenizer.encode(video_guidance_text)
1090
+ + [tokenizer["<start_of_image>"]]
1091
+ + [-1] * 400 * generate_frame_num
1092
+ )
1093
+ logging.info(
1094
+ "[Stage1: Generating Subsequent Frames, Frame Rate {:.1f}]\nraw text: {:s}".format(
1095
+ 4 / duration, tokenizer.decode(enc_text_video)
1096
+ )
1097
+ )
1098
+
1099
+ text_len = len(seq) - frame_len * generate_frame_num - 1
1100
+ guider_text_len = len(guider_seq) - frame_len * generate_frame_num - 1
1101
+ seq = (
1102
+ torch.cuda.LongTensor(seq, device=args.device)
1103
+ .unsqueeze(0)
1104
+ .repeat(batch_size, 1)
1105
+ )
1106
+ guider_seq = (
1107
+ torch.cuda.LongTensor(guider_seq, device=args.device)
1108
+ .unsqueeze(0)
1109
+ .repeat(batch_size, 1)
1110
+ )
1111
+
1112
+ for given_frame_id in range(given_tokens.shape[1]):
1113
+ seq[
1114
+ :,
1115
+ text_len
1116
+ + 1
1117
+ + given_frame_id * 400 : text_len
1118
+ + 1
1119
+ + (given_frame_id + 1) * 400,
1120
+ ] = given_tokens[:, given_frame_id]
1121
+ guider_seq[
1122
+ :,
1123
+ guider_text_len
1124
+ + 1
1125
+ + given_frame_id * 400 : guider_text_len
1126
+ + 1
1127
+ + (given_frame_id + 1) * 400,
1128
+ ] = given_tokens[:, given_frame_id]
1129
+ output_list = []
1130
+
1131
+ if use_guide:
1132
+ video_log_text_attention_weights = 0
1133
+ else:
1134
+ guider_seq = None
1135
+ video_log_text_attention_weights = 1.4
1136
+
1137
+ for tim in range(max(batch_size // mbz, 1)):
1138
+ start_time = time.time()
1139
+ input_seq = (
1140
+ seq[: min(batch_size, mbz)].clone()
1141
+ if tim == 0
1142
+ else seq[mbz * tim : mbz * (tim + 1)].clone()
1143
+ )
1144
+ guider_seq2 = (
1145
+ (
1146
+ guider_seq[: min(batch_size, mbz)].clone()
1147
+ if tim == 0
1148
+ else guider_seq[mbz * tim : mbz * (tim + 1)].clone()
1149
+ )
1150
+ if guider_seq is not None
1151
+ else None
1152
+ )
1153
+ output_list.append(
1154
+ my_filling_sequence(
1155
+ model,
1156
+ args,
1157
+ input_seq,
1158
+ batch_size=min(batch_size, mbz),
1159
+ get_masks_and_position_ids=get_masks_and_position_ids_stage1,
1160
+ text_len=text_len,
1161
+ frame_len=frame_len,
1162
+ strategy=strategy_cogview2,
1163
+ strategy2=strategy_cogvideo,
1164
+ log_text_attention_weights=video_log_text_attention_weights,
1165
+ guider_seq=guider_seq2,
1166
+ guider_text_len=guider_text_len,
1167
+ guidance_alpha=args.guidance_alpha,
1168
+ limited_spatial_channel_mem=True,
1169
+ mode_stage1=True,
1170
+ )[0]
1171
+ )
1172
+
1173
+ output_tokens = torch.cat(output_list, dim=0)[:, 1 + text_len :]
1174
+
1175
+ if args.both_stages:
1176
+ move_start_time = time.time()
1177
+ logging.debug("moving stage 1 model to cpu")
1178
+ model = model.cpu()
1179
+ torch.cuda.empty_cache()
1180
+ logging.debug(
1181
+ "moving in model1 takes time: {:.2f}".format(
1182
+ time.time() - move_start_time
1183
+ )
1184
+ )
1185
+
1186
+ # decoding
1187
+ imgs, sred_imgs, txts = [], [], []
1188
+ for seq in output_tokens:
1189
+ decoded_imgs = [
1190
+ torch.nn.functional.interpolate(
1191
+ tokenizer.decode(image_ids=seq.tolist()[i * 400 : (i + 1) * 400]),
1192
+ size=(480, 480),
1193
+ )
1194
+ for i in range(total_frames)
1195
+ ]
1196
+ imgs.append(decoded_imgs) # all decoded frames for this sample
1197
+
1198
+ assert len(imgs) == batch_size
1199
+ save_tokens = (
1200
+ output_tokens[:, : +total_frames * 400].reshape(-1, total_frames, 400).cpu()
1201
+ )
1202
+ if outputdir is not None:
1203
+ for clip_i in range(len(imgs)):
1204
+ # os.makedirs(output_dir_full_paths[clip_i], exist_ok=True)
1205
+ my_save_multiple_images(
1206
+ imgs[clip_i], outputdir, subdir=f"frames/{clip_i}", debug=False
1207
+ )
1208
+ os.system(
1209
+ f"gifmaker -i '{outputdir}'/frames/'{clip_i}'/0*.jpg -o '{outputdir}/{clip_i}.gif' -d 0.25"
1210
+ )
1211
+ torch.save(save_tokens, os.path.join(outputdir, "frame_tokens.pt"))
1212
+
1213
+ logging.info(
1214
+ "CogVideo Stage1 completed. Taken time {:.2f}\n".format(
1215
+ time.time() - process_start_time
1216
+ )
1217
+ )
1218
+
1219
+ return save_tokens
1220
+
1221
+ # ======================================================================================================
1222
+
1223
+ if args.stage_1 or args.both_stages:
1224
+ if args.input_source != "interactive":
1225
+ with open(args.input_source, "r") as fin:
1226
+ promptlist = fin.readlines()
1227
+ promptlist = [p.strip() for p in promptlist]
1228
+ else:
1229
+ promptlist = None
1230
+
1231
+ now_qi = -1
1232
+ while True:
1233
+ now_qi += 1
1234
+
1235
+ if promptlist is not None: # with input-source
1236
+ if args.multi_gpu:
1237
+ if now_qi % dist.get_world_size() != dist.get_rank():
1238
+ continue
1239
+ rk = dist.get_rank()
1240
+ else:
1241
+ rk = 0
1242
+ raw_text = promptlist[now_qi]
1243
+ raw_text = raw_text.strip()
1244
+ print(f"Working on Line No. {now_qi} on {rk}... [{raw_text}]")
1245
+ else: # interactive
1246
+ raw_text = input("\nPlease Input Query (stop to exit) >>> ")
1247
+ raw_text = raw_text.strip()
1248
+ if not raw_text:
1249
+ print("Query should not be empty!")
1250
+ continue
1251
+ if raw_text == "stop":
1252
+ return
1253
+
1254
+ try:
1255
+ path = os.path.join(args.output_path, f"{now_qi}_{raw_text}")
1256
+ parent_given_tokens = process_stage1(
1257
+ model_stage1,
1258
+ raw_text,
1259
+ duration=4.0,
1260
+ video_raw_text=raw_text,
1261
+ video_guidance_text="视频",
1262
+ image_text_suffix=" 高清摄影",
1263
+ outputdir=path if args.stage_1 else None,
1264
+ batch_size=args.batch_size,
1265
+ )
1266
+ if args.both_stages:
1267
+ process_stage2(
1268
+ model_stage2,
1269
+ raw_text,
1270
+ duration=2.0,
1271
+ video_raw_text=raw_text + " 视频",
1272
+ video_guidance_text="视频",
1273
+ parent_given_tokens=parent_given_tokens,
1274
+ outputdir=path,
1275
+ gpu_rank=0,
1276
+ gpu_parallel_size=1,
1277
+ ) # TODO: modify
1278
+ except (ValueError, FileNotFoundError) as e:
1279
+ print(e)
1280
+ continue
1281
+
1282
+ elif args.stage_2:
1283
+ sample_dirs = os.listdir(args.output_path)
1284
+ for sample in sample_dirs:
1285
+ raw_text = sample.split("_")[-1]
1286
+ path = os.path.join(args.output_path, sample, "Interp")
1287
+ parent_given_tokens = torch.load(
1288
+ os.path.join(args.output_path, sample, "frame_tokens.pt")
1289
+ )
1290
+
1291
+ process_stage2(
1292
+ raw_text,
1293
+ duration=2.0,
1294
+ video_raw_text=raw_text + " 视频",
1295
+ video_guidance_text="视频",
1296
+ parent_given_tokens=parent_given_tokens,
1297
+ outputdir=path,
1298
+ gpu_rank=0,
1299
+ gpu_parallel_size=1,
1300
+ ) # TODO: modify
1301
+
1302
+ else:
1303
+ assert False
1304
+
1305
+
1306
+ if __name__ == "__main__":
1307
+ logging.basicConfig(stream=sys.stderr, level=logging.DEBUG)
1308
+
1309
+ py_parser = argparse.ArgumentParser(add_help=False)
1310
+ py_parser.add_argument("--generate-frame-num", type=int, default=5)
1311
+ py_parser.add_argument("--coglm-temperature2", type=float, default=0.89)
1312
+ # py_parser.add_argument("--interp-duration", type=float, default=-1) # -1是顺序生成,0是超分,0.5/1/2是插帧
1313
+ # py_parser.add_argument("--total-duration", type=float, default=4.0) # 整个的时间
1314
+ py_parser.add_argument("--use-guidance-stage1", action="store_true")
1315
+ py_parser.add_argument("--use-guidance-stage2", action="store_true")
1316
+ py_parser.add_argument("--guidance-alpha", type=float, default=3.0)
1317
+ py_parser.add_argument(
1318
+ "--stage-1", action="store_true"
1319
+ ) # stage 1: sequential generation
1320
+ py_parser.add_argument("--stage-2", action="store_true") # stage 2: interp + dsr
1321
+ py_parser.add_argument(
1322
+ "--both-stages", action="store_true"
1323
+ ) # stage 1&2: sequential generation; interp + dsr
1324
+ py_parser.add_argument("--parallel-size", type=int, default=1)
1325
+ py_parser.add_argument(
1326
+ "--stage1-max-inference-batch-size", type=int, default=-1
1327
+ ) # -1: use max-inference-batch-size
1328
+ py_parser.add_argument("--multi-gpu", action="store_true")
1329
+
1330
+ CogVideoCacheModel.add_model_specific_args(py_parser)
1331
+
1332
+ known, args_list = py_parser.parse_known_args()
1333
+ args = get_args(args_list)
1334
+ args = argparse.Namespace(**vars(args), **vars(known))
1335
+ args.layout = [int(x) for x in args.layout.split(",")]
1336
+ args.do_train = False
1337
+
1338
+ torch.cuda.set_device(args.device)
1339
+
1340
+ with torch.no_grad():
1341
+ main(args)
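As a quick sanity check of the stage-1 layout defined earlier in this file, the sketch below calls get_masks_and_position_ids_stage1 on a dummy sequence: text tokens receive positions 0..text_len-1, frame tokens receive positions starting at 512, and the mask blocks text-to-frame attention while keeping the frame block causal. The tiny text_len / frame_len values are illustrative only (the real pipeline uses frame_len = 400) and the import path is an assumption based on this repository layout.

import torch
# Assumed import path for this repository layout.
from videogen_hub.pipelines.cogvideo.cogvideo_src.cogvideo_pipeline import (
    get_masks_and_position_ids_stage1,
)

text_len, frame_len = 4, 6
seq = torch.zeros(1, text_len + frame_len, dtype=torch.long)

tokens, attn, pos = get_masks_and_position_ids_stage1(seq, text_len, frame_len)

assert pos[0, :text_len].tolist() == [0, 1, 2, 3]                       # text positions
assert pos[0, text_len:].tolist() == list(range(512, 512 + frame_len))  # frame positions start at 512
assert attn.shape == (1, 1, text_len + frame_len, text_len + frame_len)
assert attn[0, 0, 0, text_len] == 0         # text tokens never attend to frame tokens
assert attn[0, 0, text_len, text_len] == 1  # frame block is lower-triangular (causal)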
src/videogen_hub/pipelines/cogvideo/cogvideo_src/models/__init__.py ADDED
File without changes
src/videogen_hub/pipelines/cogvideo/cogvideo_src/models/cogvideo_cache_model.py ADDED
@@ -0,0 +1,695 @@
1
+ # -*- encoding: utf-8 -*-
2
+ '''
3
+ @File : cogvideo_cache_model.py
4
+ @Time : 2022/07/15 11:22:19
5
+ @Author : Wenyi Hong
6
+ @Version : 1.0
7
+ @Contact : [email protected]
8
+ '''
9
+
10
+ # here put the import lib
11
+
12
+ from multiprocessing import context
13
+ from tkinter import E
14
+ import torch
15
+ from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin
16
+
17
+ from SwissArmyTransformer.mpu.utils import split_tensor_along_last_dim
18
+ from SwissArmyTransformer.model.transformer import unscaled_init_method
19
+ from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear
20
+ import torch.nn.functional as F
21
+ from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker
22
+ import math
23
+
24
+
25
+ class PositionEmbeddingMixin(BaseMixin):
26
+ def __init__(self, additional_sequence_length, hidden_size,
27
+ init_method_std=0.02, reinit_slice=slice(512, 912),
28
+ ):
29
+ super(PositionEmbeddingMixin, self).__init__()
30
+ self.reinit_slice = reinit_slice
31
+ self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size)
32
+ torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
33
+
34
+ def reinit(self, parent_model=None):
35
+ old_weights = self.transformer.position_embeddings.weight.data[self.reinit_slice]
36
+ old_len, hidden_size = old_weights.shape
37
+ assert hidden_size == self.position_embeddings.weight.shape[-1]
38
+ self.position_embeddings.weight.data.view(-1, old_len, hidden_size).copy_(old_weights)
39
+
40
+
41
+ def window_partition(x, window_size):
42
+ """
43
+ Args:
44
+ x: (B, framenum, H, W, C)
45
+ window_size (int): window size
46
+ Returns:
47
+ windows: (num_windows*B, frame_num, window_size, window_size, C)
48
+ """
49
+ B, framenum, H, W, C = x.shape
50
+ x = x.view(B, framenum, H // window_size, window_size, W // window_size, window_size, C)
51
+ windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(-1, framenum, window_size, window_size, C)
52
+ return windows
53
+
54
+ def window_reverse(windows, window_size, H, W):
55
+ """
56
+ Args:
57
+ windows: (num_windows*B, frame_num, window_size, window_size, C)
58
+ window_size (int): Window size
59
+ H (int): Height of image
60
+ W (int): Width of image
61
+ Returns:
62
+ x: (B, frame_num, H, W, C)
63
+ """
64
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
65
+ framenum = windows.shape[1]
66
+ x = windows.view(B, H // window_size, W // window_size, framenum, window_size, window_size, -1)
67
+ x = x.permute(0, 3, 1, 4, 2, 5, 6).contiguous().view(B, framenum, H, W, -1)
68
+ return x
69
+
70
+ class WindowAttentionMixin(BaseMixin):
71
+ def __init__(self, num_layers,
72
+ hidden_size,
73
+ frame_resolution,
74
+ window_size,
75
+ shift_size,
76
+ n_head,
77
+ frame_num,
78
+ init_method=unscaled_init_method(0.02),
79
+ output_layer_init_method=unscaled_init_method(0.02),
80
+ time_dim_attend_length=0
81
+ ):
82
+ super(WindowAttentionMixin, self).__init__()
83
+ self.num_layers = num_layers # replace attention in the LAST n layers
84
+ self.query_key_value = torch.nn.ModuleList(
85
+ [ColumnParallelLinear(hidden_size, 3*hidden_size,stride=3,
86
+ gather_output=False,init_method=init_method)
87
+ for layer_id in range(num_layers)
88
+ ])
89
+ self.dense = torch.nn.ModuleList(
90
+ [RowParallelLinear(
91
+ hidden_size,
92
+ hidden_size,
93
+ input_is_parallel=True,
94
+ init_method=output_layer_init_method,
95
+ bias=True,
96
+ module=self,
97
+ name="dense")
98
+ for layer_id in range(num_layers)
99
+ ])
100
+
101
+ self.n_head = n_head
102
+ self.window_size = window_size
103
+ self.frame_resolution = frame_resolution
104
+ self.frame_len = frame_resolution * frame_resolution
105
+ self.time_dim_attend_length = time_dim_attend_length
106
+ assert frame_resolution % window_size == 0
107
+ assert 0 < shift_size < window_size
108
+ nW = (self.frame_resolution // self.window_size) ** 2
109
+ ws_squre = self.window_size * self.window_size
110
+
111
+ # odd non-shift, even shift
112
+ img_mask = torch.zeros((1, 1, frame_resolution, frame_resolution, 1))
113
+ h_slices = (slice(0, -shift_size),
114
+ slice(-shift_size, None))
115
+ w_slices = (slice(0, -shift_size),
116
+ slice(-shift_size, None))
117
+ cnt = 0
118
+ for h in h_slices:
119
+ for w in w_slices:
120
+ img_mask[:, :, h, w, :] = cnt
121
+ cnt += 1
122
+ mask_windows = window_partition(img_mask, self.window_size) # nW, 1, window_size, window_size, 1
123
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
124
+ sub_attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) #[nW, self.window_size * self.window_size, self.window_size * self.window_size]
125
+ sub_attn_mask = sub_attn_mask.masked_fill(sub_attn_mask != 0, float(0.0)).masked_fill(sub_attn_mask == 0, float(1.00))
126
+ attn_mask = sub_attn_mask.repeat(1, frame_num, frame_num)
127
+ attn_mask = attn_mask.tril()
128
+
129
+ causal_mask = torch.ones(ws_squre*frame_num, ws_squre*frame_num)
130
+ causal_mask = causal_mask.tril()
131
+
132
+ self.shift_sizes = [0, shift_size]
133
+ self.attn_mask = attn_mask
134
+ self.causal_mask = causal_mask
135
+ self.mask_initialized = False
136
+
137
+ self.attn_distribution = torch.nn.ParameterList([
138
+ torch.nn.Parameter(torch.zeros(hidden_size))
139
+ for _ in range(num_layers)
140
+ ])
141
+
142
+ def reinit(self, *pre_mixins):
143
+ start_layer = len(self.transformer.layers) - self.num_layers
144
+ assert start_layer >= 0
145
+ for layer_id in range(self.num_layers):
146
+ old_attention = self.transformer.layers[start_layer + layer_id].attention
147
+ self.query_key_value[layer_id].weight.data.copy_(old_attention.query_key_value.weight.data)
148
+ self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data)
149
+
150
+ def attention_extra_NAR_inference(self, frame_hidden_state, layer_id, attn_dropout=None, memkv_text=None, stage=1):
151
+ # frame_hidden_state [batchsize, frame_num*frame_size, n_head*hiddensize_perhead]
152
+ if not self.mask_initialized:
153
+ self.attn_mask = self.attn_mask.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype)
154
+ self.causal_mask = self.causal_mask.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype)
155
+ self.mask_initialized = True
156
+ b0, s1, h0 = frame_hidden_state.shape
157
+ h = h0 // self.n_head
158
+ frame_len = self.frame_resolution * self.frame_resolution
159
+ frame_num = s1 // frame_len
160
+ if stage == 2:
161
+ assert frame_num == 3
162
+ assert frame_num*frame_len == s1
163
+ wind_square = self.window_size * self.window_size
164
+ nW = frame_len // wind_square
165
+ bswin = b0 * nW
166
+
167
+ if memkv_text is not None:
168
+ s0 = memkv_text.shape[-2]
169
+ k_text = memkv_text[..., :h0].expand(b0, -1, -1).reshape(b0, s0, self.n_head, h).permute(0, 2, 1, 3)
170
+ v_text = memkv_text[..., h0:].expand(b0, -1, -1).reshape(b0, s0, self.n_head, h).permute(0, 2, 1, 3)
171
+
172
+ # shift
173
+ frame_hidden_state = frame_hidden_state.reshape(b0, frame_num, self.frame_resolution, self.frame_resolution, h0)
174
+ if self.shift_sizes[layer_id%2] > 0:
175
+ frame_hidden_state = torch.roll(frame_hidden_state, shifts=(-self.shift_sizes[layer_id%2], -self.shift_sizes[layer_id%2]), dims=(2,3))
176
+ # window partition
177
+ frame_hidden_state = window_partition(frame_hidden_state, self.window_size).reshape(bswin, frame_num*wind_square, h0)
178
+ qkv = self.query_key_value[layer_id](frame_hidden_state).reshape(bswin, frame_num*wind_square, 3, self.n_head, h)\
179
+ .permute(2, 0, 3, 1, 4) #[3, bswin, n_head, frame_num*wind_size*wind_size, h]
180
+ q, k, v = qkv[0], qkv[1], qkv[2]
181
+ attn = torch.matmul(q / math.sqrt(h), k.transpose(-1, -2))
182
+
183
+ if stage == 1:
184
+ if self.shift_sizes[layer_id%2] > 0:
185
+ attn = torch.mul(attn.view(bswin // nW, nW, self.n_head, frame_num*wind_square, frame_num*wind_square),
186
+ self.attn_mask[:,:frame_num*wind_square, :frame_num*wind_square].unsqueeze(1).unsqueeze(0))\
187
+ - 10000.0 * (1.0 - self.attn_mask[:,:frame_num*wind_square, :frame_num*wind_square].unsqueeze(1).unsqueeze(0))
188
+ attn = attn.view(bswin, self.n_head, frame_num*wind_square, frame_num*wind_square)
189
+ else:
190
+ attn = torch.mul(attn, self.causal_mask[:frame_num*wind_square, :frame_num*wind_square].unsqueeze(0).unsqueeze(0))\
191
+ - 10000.0 * (1.0 - self.causal_mask[:frame_num*wind_square, :frame_num*wind_square].unsqueeze(0).unsqueeze(0))
192
+
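+ # Masking is applied additively: the scores are multiplied by the 0/1 mask and
+ # 10000 is subtracted wherever the mask is 0, pushing those logits far negative
+ # so they contribute (almost) nothing after the softmax below.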
193
+ if memkv_text is None:
194
+ attn = F.softmax(attn, dim=-1)
195
+ if attn_dropout is not None:
196
+ with get_cuda_rng_tracker().fork():
197
+ attn = attn_dropout(attn)
198
+ context_swin = torch.matmul(attn, v).permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0)
199
+ else:
200
+ attn_frame2text = torch.matmul(q.reshape(b0, -1, self.n_head, frame_num*wind_square, h) / math.sqrt(h), k_text.unsqueeze(1).transpose(-1, -2))
201
+ attn_frame2text = attn_frame2text.reshape(bswin, self.n_head, frame_num*wind_square, s0)
202
+ attn = torch.cat((attn, attn_frame2text), dim=-1)
203
+ attn = F.softmax(attn, dim=-1)
204
+
205
+ if attn_dropout is not None:
206
+ with get_cuda_rng_tracker().fork():
207
+ attn = attn_dropout(attn)
208
+
209
+ context_swin = (torch.matmul(attn[..., :-s0], v) +
210
+ torch.matmul(attn[..., -s0:].reshape(b0, -1, self.n_head,frame_num*wind_square, s0), v_text.unsqueeze(1))\
211
+ .reshape(bswin, self.n_head, frame_num*wind_square, h))\
212
+ .permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0)
213
+
214
+ context_swin = window_reverse(context_swin, self.window_size, self.frame_resolution, self.frame_resolution)
215
+
216
+ # reverse cycle shift
217
+ if self.shift_sizes[layer_id%2] > 0:
218
+ context_swin = torch.roll(context_swin, shifts=(self.shift_sizes[layer_id%2], self.shift_sizes[layer_id%2]), dims=(2,3))
219
+ ret_context = context_swin.reshape(b0, s1, h0)
220
+
221
+ # for mem
222
+ memk = k.permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0)
223
+ memv = v.permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0)
224
+ memk = window_reverse(memk, self.window_size, self.frame_resolution, self.frame_resolution)
225
+ memv = window_reverse(memv, self.window_size, self.frame_resolution, self.frame_resolution)
226
+ if self.shift_sizes[layer_id%2] > 0:
227
+ memk = torch.roll(memk, shifts=(self.shift_sizes[layer_id%2], self.shift_sizes[layer_id%2]), dims=(2,3))
228
+ memv = torch.roll(memv, shifts=(self.shift_sizes[layer_id%2], self.shift_sizes[layer_id%2]), dims=(2,3))
229
+ memk, memv = memk.reshape(b0, s1, h0), memv.reshape(b0, s1, h0)
230
+
231
+ ret_mem = torch.cat((memk, memv), dim=-1)
232
+ return ret_context, ret_mem
233
+
234
+ def attention_extra_AR_inference(self, frame_hidden_state, memkv, pos, layer_id, log_text_attention_weights=0, attn_dropout=None, memkv_text=None, stage=1):
235
+ # frame_hidden_state [batchsize, 1, n_head*hiddensize_perhead]
236
+ # memkv [batchsize, pos, hidden_size*2] (include frames only)
237
+ # if memkv_text is not None: will attend to text
238
+ # pos: token's pos
239
+ b0, sin, h0 = frame_hidden_state.shape
240
+ h = h0 // self.n_head
241
+ assert sin == 1
242
+ this_qkv = self.query_key_value[layer_id](frame_hidden_state)
243
+ thisq, thisk, thisv = this_qkv[..., :h0], this_qkv[..., h0:2*h0], this_qkv[..., 2*h0:]
244
+ s1 = memkv.shape[1] if memkv is not None else 0
245
+ frame_len = self.frame_resolution * self.frame_resolution
246
+ frame_num_before = s1 // frame_len
247
+
248
+
249
+ if memkv is not None:
250
+ pos_inframe = pos - frame_num_before * frame_len
251
+
252
+ xpos = pos_inframe // self.frame_resolution # pos = xpos*self.frame_resolution + ypos
253
+ ypos = pos_inframe % self.frame_resolution
254
+ # [start, end)
255
+ if self.shift_sizes[layer_id%2] > 0:
256
+ xstart = ((xpos+self.shift_sizes[layer_id%2]) // self.window_size) * self.window_size - self.shift_sizes[layer_id%2]
257
+ ystart = ((ypos+self.shift_sizes[layer_id%2]) // self.window_size) * self.window_size - self.shift_sizes[layer_id%2]
258
+ xend = xstart + self.window_size
259
+ yend = ystart + self.window_size
260
+ xstart, ystart = max(0, xstart), max(0, ystart)
261
+ xend, yend = min(xend, self.frame_resolution), min(yend, self.frame_resolution)
262
+ else:
263
+ xstart = (xpos // self.window_size) * self.window_size
264
+ ystart = (ypos // self.window_size) * self.window_size
265
+ xend, yend = xstart + self.window_size, ystart+self.window_size
266
+
267
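+ # During single-token decoding the code gathers, via index_select, only the cached
+ # key/value rows this token may attend to: the same (possibly shifted) window in a
+ # limited number of recent frames (controlled by time_dim_attend_length) plus the
+ # already-generated positions of that window in the current frame.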
+ # select index
268
+ selected_index = list()
269
+ if frame_num_before > 0:
270
+ # frames before
271
+ frame_attended_start = max(0, frame_num_before-self.time_dim_attend_length+1) if self.time_dim_attend_length > 0 else 0
272
+ for x in range(xstart, xend):
273
+ for y in range(ystart, yend):
274
+ selected_index.append(x*self.frame_resolution+y+frame_len*frame_attended_start)
275
+ cnt_per_frame = len(selected_index)
276
+ for _ in range((frame_num_before-frame_attended_start-1)*cnt_per_frame):
277
+ selected_index.append(selected_index[-cnt_per_frame]+frame_len)
278
+
279
+ # the last frame
280
+ for x in range(xstart, xend):
281
+ for y in range(ystart, yend):
282
+ tmppos = x*self.frame_resolution+y + frame_num_before * frame_len
283
+ if tmppos < pos:
284
+ selected_index.append(tmppos)
285
+ else:
286
+ break
287
+ cnt_all = len(selected_index)+1
288
+ selected_index = torch.tensor(selected_index, device=memkv.device)
289
+ used_memkv = torch.index_select(memkv, 1, selected_index)
290
+ used_k, used_v = used_memkv[..., :h0], used_memkv[..., h0:]
291
+ used_k = torch.cat((used_k.expand(thisk.shape[0], -1, -1), thisk), dim=-2)
292
+ used_v = torch.cat((used_v.expand(thisv.shape[0], -1, -1), thisv), dim=-2)
293
+ if memkv_text is not None:
294
+ cnt_all += memkv_text.shape[-2]
295
+ used_k = torch.cat((memkv_text[..., :h0].expand(thisk.shape[0], -1, -1), used_k), dim=-2)
296
+ used_v = torch.cat((memkv_text[..., h0:].expand(thisv.shape[0], -1, -1), used_v), dim=-2)
297
+ used_k = used_k.reshape(b0, cnt_all, self.n_head, h).permute(0, 2, 1, 3)
298
+ used_v = used_v.reshape(b0, cnt_all, self.n_head, h).permute(0, 2, 1, 3)
299
+ else:
300
+ used_k = thisk
301
+ used_v = thisv
302
+
303
+ if memkv_text is not None:
304
+ used_k = torch.cat((memkv_text[..., :h0].expand(thisk.shape[0], -1, -1), used_k), dim=-2)
305
+ used_v = torch.cat((memkv_text[..., h0:].expand(thisv.shape[0], -1, -1), used_v), dim=-2)
306
+ used_k = used_k.reshape(b0, 1+memkv_text.shape[-2], self.n_head, h).permute(0, 2, 1, 3)
307
+ used_v = used_v.reshape(b0, 1+memkv_text.shape[-2], self.n_head, h).permute(0, 2, 1, 3)
308
+ else:
309
+ used_k = used_k.reshape(b0, 1, self.n_head, h).permute(0, 2, 1, 3)
310
+ used_v = used_v.reshape(b0, 1, self.n_head, h).permute(0, 2, 1, 3)
311
+
312
+ thisq = thisq.reshape(b0, 1, self.n_head, h).permute(0, 2, 1, 3) # [b0, n_head, 1, h]
313
+ attn = torch.matmul(thisq / math.sqrt(h), used_k.transpose(-1, -2))
314
+ if memkv_text is not None:
315
+ attn[..., :memkv_text.shape[-2]] += log_text_attention_weights
316
+ attn = F.softmax(attn, dim=-1)
317
+ context_swin = torch.matmul(attn, used_v).permute(0, 2, 1, 3).reshape(b0, 1, h0)
318
+
319
+ return context_swin, this_qkv[..., h0:]
320
+
321
+ class FullAttentionMixin(BaseMixin):
322
+ def __init__(self, num_layers,
323
+ hidden_size,
324
+ frame_resolution,
325
+ n_head,
326
+ frame_num,
327
+ init_method=unscaled_init_method(0.02),
328
+ output_layer_init_method=unscaled_init_method(0.02),
329
+ **kwargs,
330
+ ):
331
+ super(FullAttentionMixin, self).__init__()
332
+ self.num_layers = num_layers # replace attention in the LAST n layers
333
+ self.query_key_value = torch.nn.ModuleList(
334
+ [ColumnParallelLinear(hidden_size, 3*hidden_size,stride=3,
335
+ gather_output=False,init_method=init_method)
336
+ for layer_id in range(num_layers)
337
+ ])
338
+ self.dense = torch.nn.ModuleList(
339
+ [RowParallelLinear(
340
+ hidden_size,
341
+ hidden_size,
342
+ input_is_parallel=True,
343
+ init_method=output_layer_init_method,
344
+ bias=True,
345
+ module=self,
346
+ name="dense")
347
+ for layer_id in range(num_layers)
348
+ ])
349
+
350
+ self.n_head = n_head
351
+ self.frame_resolution = frame_resolution
352
+ self.frame_len = frame_resolution * frame_resolution
353
+
354
+ self.attn_distribution = torch.nn.ParameterList([
355
+ torch.nn.Parameter(torch.zeros(hidden_size))
356
+ for _ in range(num_layers)
357
+ ])
358
+
359
+ def reinit(self, *pre_mixins):
360
+ start_layer = len(self.transformer.layers) - self.num_layers
361
+ assert start_layer >= 0
362
+ for layer_id in range(self.num_layers):
363
+ old_attention = self.transformer.layers[start_layer + layer_id].attention
364
+ self.query_key_value[layer_id].weight.data.copy_(old_attention.query_key_value.weight.data)
365
+ self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data)
366
+
367
+
368
+ def attention_extra_NAR_inference(self, frame_hidden_state, layer_id, attn_dropout=None, memkv_text=None, stage=1):
369
+ # frame_hidden_state [batchsize, frame_num*frame_size, n_head*hiddensize_perhead]
370
+ assert stage == 1
371
+
372
+ b0, s1, h0 = frame_hidden_state.shape
373
+ h = h0 // self.n_head
374
+ frame_len = self.frame_resolution * self.frame_resolution
375
+ frame_num = s1 // frame_len
376
+ assert frame_num*frame_len == s1
377
+
378
+ if memkv_text is not None:
379
+ s0 = memkv_text.shape[-2]
380
+ k_text = memkv_text[..., :h0].expand(b0, -1, -1).reshape(b0, s0, self.n_head, h).permute(0, 2, 1, 3)
381
+ v_text = memkv_text[..., h0:].expand(b0, -1, -1).reshape(b0, s0, self.n_head, h).permute(0, 2, 1, 3)
382
+ qkv = self.query_key_value[layer_id](frame_hidden_state).reshape(b0, s1, 3, self.n_head, h)\
383
+ .permute(2, 0, 3, 1, 4) #[3, b0, n_head, s1, h]
384
+ q, k, v = qkv[0], qkv[1], qkv[2]
385
+ attn = torch.matmul(q / math.sqrt(h), k.transpose(-1, -2))
386
+ attn = attn - 10000.0 * (1.0-torch.ones(b0, self.n_head, s1, s1, device=attn.device, dtype=attn.dtype).tril())
387
+
388
+ if memkv_text is None:
389
+ attn = F.softmax(attn, dim=-1)
390
+ if attn_dropout is not None:
391
+ with get_cuda_rng_tracker().fork():
392
+ attn = attn_dropout(attn)
393
+ context_swin = torch.matmul(attn, v).permute(0, 2, 1, 3).reshape(b0, s1, h0)
394
+ else:
395
+ attn_frame2text = torch.matmul(q / math.sqrt(h), k_text.transpose(-1, -2)) #[b0, s1, s0]
396
+ attn = torch.cat((attn, attn_frame2text), dim=-1)
397
+ attn = F.softmax(attn, dim=-1)
398
+ if attn_dropout is not None:
399
+ with get_cuda_rng_tracker().fork():
400
+ attn = attn_dropout(attn)
401
+ context_swin = (torch.matmul(attn[..., :-s0], v) + torch.matmul(attn[..., -s0:], v_text))\
402
+ .permute(0, 2, 1, 3).reshape(b0, s1, h0)
403
+
404
+ # for mem
405
+ memk = k.permute(0, 2, 1, 3).reshape(b0, s1, h0)
406
+ memv = v.permute(0, 2, 1, 3).reshape(b0, s1, h0)
407
+ ret_mem = torch.cat((memk, memv), dim=-1)
408
+
409
+ return context_swin, ret_mem
410
+
411
+ def attention_extra_AR_inference(self, frame_hidden_state, memkv, pos, layer_id, log_text_attention_weights=0, attn_dropout=None, memkv_text=None, stage=1):
412
+ # pos: current token's pos
413
+ b0, sin, h0 = frame_hidden_state.shape
414
+ h = h0 // self.n_head
415
+ assert sin == 1
416
+ assert stage == 1
417
+
418
+ this_qkv = self.query_key_value[layer_id](frame_hidden_state)
419
+ thisq, thisk, thisv = this_qkv[..., :h0], this_qkv[..., h0:2*h0], this_qkv[..., 2*h0:]
420
+
421
+ if memkv is not None:
422
+ used_k, used_v = memkv[..., :h0], memkv[..., h0:]
423
+ used_k = torch.cat((used_k.expand(thisk.shape[0], -1, -1), thisk), dim=-2)
424
+ used_v = torch.cat((used_v.expand(thisv.shape[0], -1, -1), thisv), dim=-2)
425
+ else:
426
+ used_k, used_v = thisk, thisv
427
+
428
+ if memkv_text is not None:
429
+ used_k = torch.cat((memkv_text[..., :h0].expand(thisk.shape[0], -1, -1), used_k), dim=-2)
430
+ used_v = torch.cat((memkv_text[..., h0:].expand(thisv.shape[0], -1, -1), used_v), dim=-2)
431
+
432
+ used_k = used_k.reshape(b0, -1, self.n_head, h).permute(0, 2, 1, 3)
433
+ used_v = used_v.reshape(b0, -1, self.n_head, h).permute(0, 2, 1, 3)
434
+ thisq = thisq.reshape(b0, 1, self.n_head, h).permute(0, 2, 1, 3) # [b0, n_head, 1, h]
435
+ attn = torch.matmul(thisq / math.sqrt(h), used_k.transpose(-1, -2))
436
+ if memkv_text is not None:
437
+ attn[..., :memkv_text.shape[-2]] += log_text_attention_weights
438
+ attn = F.softmax(attn, dim=-1)
439
+
440
+ context_swin = torch.matmul(attn, used_v).permute(0, 2, 1, 3).reshape(b0, 1, h0)
441
+
442
+ return context_swin, this_qkv[..., h0:]
443
+
444
+
445
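+ # The helpers below factor full attention into cheaper pieces: text tokens attend
+ # only to the text prefix, while frame tokens attend to the text prefix plus the
+ # tokens of their own frame ("local frame" attention). The *_NAR variant scores a
+ # whole sequence at once; the *_AR variant scores one new token against cached
+ # keys/values.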
+ def attention_localframe_and_text_NAR(q0, k0, v0, attention_mask,
446
+ n_head, text_len, frame_len, frame_num,
447
+ attention_dropout=None, log_text_attention_weights=0, stage=1, **kwargs):
448
+ b, s0, h0 = q0.shape
449
+ s1 = s0 - text_len
450
+ h = h0 // n_head
451
+ assert q0.shape[1] == v0.shape[1] == k0.shape[1] == text_len+frame_len*frame_num
452
+ # attention_mask.shape [4, b or 1, 1, text_len+frame_len, text_len+frame_len]
453
+ if stage == 2:
454
+ assert frame_num == 3
455
+
456
+ q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
457
+ v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
458
+ k0 = k0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
459
+ k0T = k0.transpose(-1, -2)
460
+
461
+ score_any2text = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T[..., :text_len])
462
+ score_any2text += log_text_attention_weights
463
+ score_any2text_part1 = torch.mul(score_any2text[..., :text_len, :], attention_mask[..., :text_len, :text_len]) \
464
+ - 10000.0 * (1.0 - attention_mask[..., :text_len, :text_len])
465
+ # context for text
466
+ attention_probs_text = F.softmax(score_any2text_part1, dim=-1)
467
+ if attention_dropout is not None:
468
+ with get_cuda_rng_tracker().fork():
469
+ attention_probs_text = attention_dropout(attention_probs_text)
470
+ context_text2text = torch.matmul(attention_probs_text, v0[..., :text_len, :])
471
+ context_text2text = context_text2text.transpose(1, 2).reshape(b, text_len, h0)
472
+
473
+ if frame_num > 0:
474
+ score_any2text_part2 = score_any2text[..., text_len:, :]
475
+
476
+ # score: frame local
477
+ q0_frame = q0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h)
478
+ v0_frame = v0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h)
479
+ k0T_frame = k0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h).transpose(-1, -2)
480
+ score_frame_local0 = torch.matmul(q0_frame / math.sqrt(q0_frame.shape[-1]), k0T_frame)
481
+ if stage == 1:
482
+ score_frame_local0 = torch.mul(score_frame_local0, attention_mask[..., text_len:, text_len:].unsqueeze(1)) \
483
+ - 10000.0 * (1.0 - attention_mask[..., text_len:, text_len:].unsqueeze(1))
484
+
485
+ # context for frame
486
+ score_frame_all = torch.cat((score_any2text_part2,
487
+ score_frame_local0.view(b, n_head, s1, frame_len)), dim=-1)
488
+ attention_probs_frame = F.softmax(score_frame_all, dim=-1)
489
+ if attention_dropout is not None:
490
+ with get_cuda_rng_tracker().fork():
491
+ attention_probs_frame = attention_dropout(attention_probs_frame)
492
+ context_frame2text = torch.matmul(attention_probs_frame[..., :text_len], v0[..., :text_len, :]) # [b, n_head, s1, h]
493
+ context_frame_local0 = torch.matmul(attention_probs_frame[..., text_len:text_len+frame_len].\
494
+ view(b, n_head, frame_num, frame_len, frame_len), v0_frame).view(b, n_head, s1, h)
495
+
496
+ context_frame = (context_frame2text + context_frame_local0).transpose(1, 2).reshape(b, s1, h0)
497
+ else:
498
+ context_frame = None
499
+
500
+ return context_text2text, context_frame
501
+
502
+ def attention_localframe_and_text_AR(q0, k0, v0, n_head, text_len, frame_len, frame_num,
503
+ attention_dropout=None, log_text_attention_weights=0, layer_id=None, limited_spatial_channel_mem=False, stage=1, **kwargs):
504
+ # limited_spatial_channel_mem=True means: the mems in the spatial channel consist of {mem_text, mem_current_frame}
505
+ b, s0, h0 = k0.shape
506
+ frame_num_before = (s0-text_len-1) // frame_len # frame_num == frame_num_before or frame_num == frame_num_before+1
507
+ h = h0 // n_head
508
+ assert q0.shape[1] == 1
509
+ assert v0.shape[1] == k0.shape[1]
510
+
511
+ q0 = q0.reshape(b, 1, n_head, h).permute(0, 2, 1, 3)
512
+ v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
513
+ k0T = k0.reshape(b, s0, n_head, h).permute(0, 2, 3, 1)
514
+
515
+ if limited_spatial_channel_mem:
516
+ assert frame_num_before == 0
517
+ assert stage == 1 # not implemented for stage-2 yet
518
+ score = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T)
519
+ score[..., :text_len] += log_text_attention_weights
520
+ attention_probs_frame = F.softmax(score, dim=-1)
521
+ context_frame = torch.matmul(attention_probs_frame, v0).transpose(1, 2).reshape(b, 1, h0)
522
+
523
+ else:
524
+ score_token2text = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T[..., :text_len])
525
+ score_token2text += log_text_attention_weights
526
+ score_frame_local0 = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T[..., text_len+frame_num_before*frame_len:])
527
+ score_frame_all = torch.cat((score_token2text,
528
+ score_frame_local0), dim=-1)
529
+ attention_probs_frame = F.softmax(score_frame_all, dim=-1)
530
+
531
+ context_token2text = torch.matmul(attention_probs_frame[..., :text_len], v0[..., :text_len, :]) # [b, n_head, s1, h]
532
+ context_frame_local0 = torch.matmul(attention_probs_frame[..., text_len:], \
533
+ v0[:, :, text_len+frame_num_before*frame_len:, :])
534
+ context_frame = (context_token2text + context_frame_local0).transpose(1, 2).reshape(b, 1, h0)
535
+
536
+ return context_frame
537
+
538
+
539
+ class CogVideoCacheModel(BaseModel):
540
+ def __init__(self, args, transformer=None, parallel_output=True, window_size=None, cogvideo_stage=None):
541
+ super().__init__(args, transformer=transformer, parallel_output=parallel_output)
542
+ self.layout = args.layout # [64, 64+1024, 64+6*1024]
543
+ self.stage = cogvideo_stage if cogvideo_stage is not None else args.cogvideo_stage # 1 or 2
544
+ self.n_head = args.num_attention_heads
545
+ self.window_size = window_size if window_size is not None else args.window_size
546
+
547
+ frame_resolution = int(math.sqrt(self.layout[1]-self.layout[0]))
548
+ self.add_mixin('extra_position_embedding', PositionEmbeddingMixin(
549
+ args.additional_seqlen, args.hidden_size
550
+ ))
551
+
552
+ if self.stage == 1:
553
+ self.add_mixin('attention_plus', FullAttentionMixin(
554
+ num_layers=args.num_layers,
555
+ hidden_size=args.hidden_size,
556
+ frame_resolution=frame_resolution,
557
+ n_head=args.num_attention_heads,
558
+ frame_num=(args.layout[2]-args.layout[0])//(args.layout[1]-args.layout[0]),
559
+ ))
560
+ else:
561
+ self.add_mixin('attention_plus', WindowAttentionMixin(
562
+ num_layers=args.num_layers,
563
+ hidden_size=args.hidden_size,
564
+ frame_resolution=frame_resolution,
565
+ window_size=self.window_size,
566
+ shift_size=self.window_size//2,
567
+ n_head=args.num_attention_heads,
568
+ frame_num=(args.layout[2]-args.layout[0])//(args.layout[1]-args.layout[0]),
569
+ ))
570
+
571
+
572
+ @classmethod
573
+ def add_model_specific_args(cls, parser):
574
+ group = parser.add_argument_group('VideoSwinLocalModel', 'video swin local model configurations')
575
+ group.add_argument("--layout", type=str, default='64, 464, 2064')
576
+ group.add_argument("--window-size", type=int, default=10) # 优先级在直接参数赋值之后
577
+ group.add_argument("--additional-seqlen", type=int, default=2000)
578
+ group.add_argument("--cogvideo-stage", type=int, default=1, choices=[1,2]) # 优先级在直接参数赋值之后
579
+ return parser
580
+
581
+ def disable_untrainable_params(self):
582
+ pass
583
+
584
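+ # Positions below 512+400 (the text prefix plus the first frame) use the base
+ # transformer's position table; later positions are shifted back by 512+400 and
+ # looked up in the 'extra_position_embedding' mixin added above.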
+ def position_embedding_forward(self, position_ids, **kw_args):
585
+ if position_ids.shape[-1] > 1:
586
+ if self.stage == 1:
587
+ if position_ids[0,-1] >= (512+400):
588
+ frame_num = position_ids.shape[-1] // 400
589
+ position_embeddings = torch.cat(
590
+ (
591
+ self.transformer.position_embeddings(position_ids[..., :-400*(frame_num-1)]),
592
+ self.get_mixin('extra_position_embedding').position_embeddings(position_ids[..., -400*(frame_num-1):]-(512+400))
593
+ ),
594
+ dim=-2
595
+ )
596
+ else:
597
+ position_embeddings = self.transformer.position_embeddings(position_ids)
598
+ else:
599
+ # given 3 frames, interpolate the 2 frames in between
600
+ position_embeddings = torch.cat(
601
+ (
602
+ self.transformer.position_embeddings(position_ids[..., :-800]),
603
+ self.get_mixin('extra_position_embedding').position_embeddings(position_ids[..., -800:]-(512+400))
604
+ ),
605
+ dim=-2
606
+ )
607
+ else:
608
+ if position_ids[0, 0] >= (512+400):
609
+ position_embeddings = self.get_mixin('extra_position_embedding').position_embeddings(position_ids-(512+400))
610
+ else:
611
+ position_embeddings = self.transformer.position_embeddings(position_ids)
612
+ return position_embeddings
613
+
614
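+ # When mems is None this is the non-autoregressive pass over the full sequence,
+ # which also returns (memkv0, memkv1) caches for the base attention and for the
+ # 'attention_plus' mixin; otherwise a single token is decoded against those caches.
+ # In both branches a per-layer sigmoid gate (attn_distribution) blends the base
+ # local-frame/text attention with the mixin's swin/full attention.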
+ def attention_forward(self, hidden_states, mask, layer_id, mems=None, log_text_attention_weights=0, text_len=0, frame_len=0, counter=0, enforce_no_swin=False, limited_spatial_channel_mem=False, **kw_args):
615
+ attn_module = self.transformer.layers[layer_id].attention
616
+ hidden_size = hidden_states.shape[-1]
617
+
618
+ # base model qkv
619
+ if mems is None:
620
+ mixed_raw_layer = attn_module.query_key_value(hidden_states)
621
+ q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer, 3)
622
+ assert (q0.shape[1]-text_len) % frame_len == 0
623
+ memkv0 = torch.cat((k0, v0), dim=-1)
624
+ context_text, context_frame_local_text = attention_localframe_and_text_NAR(
625
+ q0, k0, v0,
626
+ mask,
627
+ n_head=attn_module.num_attention_heads_per_partition,
628
+ text_len=text_len,
629
+ frame_len=frame_len,
630
+ frame_num=(q0.shape[1]-text_len)//frame_len,
631
+ log_text_attention_weights=log_text_attention_weights,
632
+ stage=self.stage
633
+ )
634
+
635
+ # change: self.swin_attend_to_text defaults to True:
636
+ memkv1_text = self.get_mixin('attention_plus').query_key_value[layer_id](hidden_states[..., :text_len, :])[..., hidden_size:]
637
+ output_text = attn_module.dense(context_text)
638
+
639
+ if (q0.shape[1]-text_len)//frame_len > 0:
640
+ assert (q0.shape[1]-text_len) % frame_len == 0
641
+ context_frame_swin, memkv1_frame = self.get_mixin('attention_plus').attention_extra_NAR_inference(
642
+ hidden_states[:,text_len:], layer_id, memkv_text=memkv1_text, stage=self.stage)
643
+ if not enforce_no_swin:
644
+ attn_distrib = torch.sigmoid(self.get_mixin('attention_plus').attn_distribution[layer_id])
645
+ attn_distrib = attn_distrib.unsqueeze(0).unsqueeze(0)
646
+ output_frame = torch.mul(attn_module.dense(context_frame_local_text), attn_distrib)\
647
+ +torch.mul(self.get_mixin('attention_plus').dense[layer_id](context_frame_swin), 1-attn_distrib)
648
+ else:
649
+ output_frame = attn_module.dense(context_frame_local_text[..., :frame_len, :])
650
+ output = torch.cat((output_text, output_frame), dim=-2)
651
+ memkv1 = torch.cat((memkv1_text, memkv1_frame), dim=-2) if memkv1_text is not None else memkv1_frame
652
+ else:
653
+ output = output_text
654
+ memkv1 = memkv1_text
655
+ kw_args['output_this_layer']['mem_kv'] = (memkv0, memkv1)
656
+
657
+
658
+ else:
659
+ mixed_raw_layer = attn_module.query_key_value(hidden_states)
660
+ q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer, 3)
661
+ new_memkv0 = torch.cat((k0, v0), dim=-1)
662
+ old_k0, old_v0 = mems[0][layer_id][..., :hidden_size], mems[0][layer_id][..., hidden_size:]
663
+
664
+ context_frame_local_text = attention_localframe_and_text_AR(
665
+ q0,
666
+ torch.cat((old_k0.expand(k0.shape[0], -1, -1), k0), dim=-2),
667
+ torch.cat((old_v0.expand(v0.shape[0], -1, -1), v0), dim=-2),
668
+ n_head=attn_module.num_attention_heads_per_partition,
669
+ text_len=text_len,
670
+ frame_len=frame_len,
671
+ frame_num=None,
672
+ log_text_attention_weights=log_text_attention_weights,
673
+ layer_id=layer_id,
674
+ limited_spatial_channel_mem=limited_spatial_channel_mem,
675
+ )
676
+
677
+ old_memkv1 = mems[1][layer_id] if mems[1] is not None else None
678
+
679
+ context_frame_swin, new_memkv1 = self.get_mixin('attention_plus').attention_extra_AR_inference(hidden_states,
680
+ old_memkv1[..., text_len:, :] if old_memkv1.shape[-2]>text_len else None,
681
+ counter-text_len,
682
+ layer_id,
683
+ memkv_text=old_memkv1[..., :text_len, :],
684
+ log_text_attention_weights=log_text_attention_weights)
685
+ if not enforce_no_swin:
686
+ attn_distrib = torch.sigmoid(self.get_mixin('attention_plus').attn_distribution[layer_id])
687
+ attn_distrib = attn_distrib.unsqueeze(0).unsqueeze(0)
688
+ output = torch.mul(attn_module.dense(context_frame_local_text), attn_distrib)\
689
+ +torch.mul(self.get_mixin('attention_plus').dense[layer_id](context_frame_swin), 1-attn_distrib)
690
+ else:
691
+ output = attn_module.dense(context_frame_local_text)
692
+
693
+ kw_args['output_this_layer']['mem_kv'] = (new_memkv0, new_memkv1)
694
+
695
+ return output
src/videogen_hub/pipelines/cogvideo/cogvideo_src/models/cogvideo_model.py ADDED
@@ -0,0 +1,543 @@
1
+ # -*- encoding: utf-8 -*-
2
+ '''
3
+ @File : cogvideo_model.py
4
+ @Time : 2022/07/11 16:12:05
5
+ @Author : Wenyi Hong
6
+ @Version : 1.0
7
+ @Contact : [email protected]
8
+ '''
9
+
10
+ # here put the import lib
11
+
12
+ import torch
13
+ from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin
14
+
15
+ from SwissArmyTransformer.mpu.utils import split_tensor_along_last_dim
16
+ from SwissArmyTransformer.model.transformer import unscaled_init_method
17
+ from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear
18
+ import torch.nn.functional as F
19
+ from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker
20
+ import math
21
+
22
+ class PositionEmbeddingMixin(BaseMixin):
23
+ def __init__(self, additional_sequence_length, hidden_size,
24
+ init_method_std=0.02, reinit_slice=slice(512, 912),
25
+ ):
26
+ super(PositionEmbeddingMixin, self).__init__()
27
+ self.reinit_slice = reinit_slice
28
+ self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size)
29
+ torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
30
+
31
+ def reinit(self, parent_model=None):
32
+ old_weights = self.transformer.position_embeddings.weight.data[self.reinit_slice]
33
+ old_len, hidden_size = old_weights.shape
34
+ assert hidden_size == self.position_embeddings.weight.shape[-1]
35
+ self.position_embeddings.weight.data.view(-1, old_len, hidden_size).copy_(old_weights)
36
+
37
+ def window_partition(x, window_size):
38
+ """
39
+ Args:
40
+ x: (B, framenum, H, W, C)
41
+ window_size (int): window size
42
+ Returns:
43
+ windows: (num_windows*B, frame_num, window_size, window_size, C)
44
+ """
45
+ B, framenum, H, W, C = x.shape
46
+ x = x.view(B, framenum, H // window_size, window_size, W // window_size, window_size, C)
47
+ windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(-1, framenum, window_size, window_size, C)
48
+ return windows
49
+
50
+ def window_reverse(windows, window_size, H, W):
51
+ """
52
+ Args:
53
+ windows: (num_windows*B, frame_num, window_size, window_size, C)
54
+ window_size (int): Window size
55
+ H (int): Height of image
56
+ W (int): Width of image
57
+ Returns:
58
+ x: (B, frame_num, H, W, C)
59
+ """
60
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
61
+ framenum = windows.shape[1]
62
+ x = windows.view(B, H // window_size, W // window_size, framenum, window_size, window_size, -1)
63
+ x = x.permute(0, 3, 1, 4, 2, 5, 6).contiguous().view(B, framenum, H, W, -1)
64
+ return x
65
+
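+ # Shape example (hypothetical values): with B=2, framenum=5, H=W=32 and
+ # window_size=8 there are 4*4=16 windows per sample, so window_partition returns
+ # a (32, 5, 8, 8, C) tensor and window_reverse(windows, 8, 32, 32) restores the
+ # original (2, 5, 32, 32, C) layout.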
66
+ class WindowAttentionMixin(BaseMixin):
67
+ def __init__(self, num_layers,
68
+ hidden_size,
69
+ frame_resolution,
70
+ window_size,
71
+ shift_size,
72
+ n_head,
73
+ frame_num,
74
+ init_method=unscaled_init_method(0.02),
75
+ output_layer_init_method=unscaled_init_method(0.02),
76
+ ):
77
+ super(WindowAttentionMixin, self).__init__()
78
+ self.num_layers = num_layers # replace attention in the LAST n layers
79
+ self.query_key_value = torch.nn.ModuleList(
80
+ [ColumnParallelLinear(hidden_size, 3*hidden_size,stride=3,
81
+ gather_output=False,init_method=init_method)
82
+ for layer_id in range(num_layers)
83
+ ])
84
+ self.dense = torch.nn.ModuleList(
85
+ [RowParallelLinear(
86
+ hidden_size,
87
+ hidden_size,
88
+ input_is_parallel=True,
89
+ init_method=output_layer_init_method,
90
+ bias=True,
91
+ module=self,
92
+ name="dense",
93
+ )
94
+ for layer_id in range(num_layers)
95
+ ])
96
+
97
+ self.n_head = n_head
98
+ self.window_size = window_size
99
+ self.frame_resolution = frame_resolution
100
+ self.frame_len = frame_resolution * frame_resolution
101
+ assert frame_resolution % window_size == 0
102
+ assert 0 < shift_size < window_size
103
+ nW = (self.frame_resolution // self.window_size) ** 2
104
+ ws_squre = self.window_size * self.window_size
105
+
106
+ # odd non-shift, even shift
107
+ img_mask = torch.zeros((1, 1, frame_resolution, frame_resolution, 1))
108
+ h_slices = (slice(0, -shift_size),
109
+ slice(-shift_size, None))
110
+ w_slices = (slice(0, -shift_size),
111
+ slice(-shift_size, None))
112
+ cnt = 0
113
+ for h in h_slices:
114
+ for w in w_slices:
115
+ img_mask[:, :, h, w, :] = cnt
116
+ cnt += 1
117
+ mask_windows = window_partition(img_mask, self.window_size) # nW, 1, window_size, window_size, 1
118
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
119
+ sub_attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) #[nW, self.window_size * self.window_size, self.window_size * self.window_size]
120
+ sub_attn_mask = sub_attn_mask.masked_fill(sub_attn_mask != 0, float(0.0)).masked_fill(sub_attn_mask == 0, float(1.00))
121
+ attn_mask = sub_attn_mask.repeat(1, frame_num, frame_num)
122
+
123
+ self.attn_mask_sequential = attn_mask.clone().tril()
124
+ self.causal_mask_sequential = torch.ones(1, ws_squre*frame_num, ws_squre*frame_num).tril()
125
+
126
+ self.causal_mask_interp = torch.ones(1, ws_squre*frame_num, ws_squre*frame_num)
127
+ self.attn_mask_interp = attn_mask.clone()
128
+
129
+ # bi-dir
130
+ for bi_idx in range(0, frame_num, 2):
131
+ for uni_idx in range(1, frame_num, 2):
132
+ self.attn_mask_interp[:, bi_idx*ws_squre:(bi_idx+1)*ws_squre, uni_idx*ws_squre:(uni_idx+1)*ws_squre] = 0
133
+ self.causal_mask_interp[:, bi_idx*ws_squre:(bi_idx+1)*ws_squre, uni_idx*ws_squre:(uni_idx+1)*ws_squre] = 0
134
+ # uni-dir
135
+ for uni_idx in range(1, frame_num, 2):
136
+ self.attn_mask_interp[:, ws_squre*uni_idx:ws_squre*(uni_idx+1), ws_squre*uni_idx:ws_squre*(uni_idx+1)].tril_()
137
+ self.causal_mask_interp[:, ws_squre*uni_idx:ws_squre*(uni_idx+1), ws_squre*uni_idx:ws_squre*(uni_idx+1)].tril_()
138
+ for uni_idx2 in range(uni_idx+2, frame_num, 2):
139
+ self.attn_mask_interp[:, ws_squre*uni_idx:ws_squre*(uni_idx+1), ws_squre*uni_idx2:ws_squre*(uni_idx2+1)] = 0
140
+ self.causal_mask_interp[:, ws_squre*uni_idx:ws_squre*(uni_idx+1), ws_squre*uni_idx2:ws_squre*(uni_idx2+1)] = 0
141
+
142
+ # expand dim
143
+ self.attn_mask_sequential = self.attn_mask_sequential[None, None, :, None]
144
+ self.attn_mask_interp = self.attn_mask_interp[None, None, :, None]
145
+ self.causal_mask_sequential = self.causal_mask_sequential[None, None, :, None]
146
+ self.causal_mask_interp = self.causal_mask_interp[None, None, :, None]
147
+
148
+ self.shift_sizes = [0, shift_size]
149
+ # self.register_buffer("attn_mask", attn_mask)
150
+ # self.register_buffer("causal_mask", causal_mask)
151
+ self.mask_initialized = False
152
+
153
+ self.attn_distribution = torch.nn.ParameterList([
154
+ torch.nn.Parameter(torch.zeros(hidden_size))
155
+ for _ in range(num_layers)
156
+ ])
157
+
158
+ def reinit(self, *pre_mixins):
159
+ start_layer = len(self.transformer.layers) - self.num_layers
160
+ assert start_layer >= 0
161
+ for layer_id in range(self.num_layers):
162
+ old_attention = self.transformer.layers[start_layer + layer_id].attention
163
+ self.query_key_value[layer_id].weight.data.copy_(old_attention.query_key_value.weight.data)
164
+ self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data)
165
+
166
+ def attention_extra(self, frame_hidden_state, layer_id, attn_dropout, text_hidden_state=None,
167
+ text_attn_mask=None, mode_sequential=True):
168
+ # pb relax
169
+ swin_pb_relax = True
170
+ alpha = 16
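+ # This appears to be the PB-relax stabilization (as used in CogView): queries are
+ # additionally scaled by 1/alpha before the matmul, and later the per-row maximum
+ # is subtracted before scaling the scores back by alpha, keeping attention logits
+ # within fp16 range.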
171
+
172
+ # frame_hidden_state [batchsize, frame_num*frame_size, n_head*hiddensize_perhead]
173
+ if not self.mask_initialized:
174
+ self.attn_mask_sequential = self.attn_mask_sequential.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype)
175
+ self.causal_mask_sequential = self.causal_mask_sequential.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype)
176
+ self.attn_mask_interp = self.attn_mask_interp.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype)
177
+ self.causal_mask_interp = self.causal_mask_interp.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype)
178
+ self.mask_initialized = True
179
+ b0, s1, h0 = frame_hidden_state.shape
180
+ h = h0 // self.n_head
181
+ frame_len = self.frame_resolution * self.frame_resolution
182
+ frame_num = s1 // frame_len
183
+ assert frame_num*frame_len == s1
184
+ wind_square = self.window_size * self.window_size
185
+ nW = frame_len // wind_square
186
+ bswin = b0 * nW
187
+
188
+ causal_mask = self.causal_mask_sequential if mode_sequential else self.causal_mask_interp
189
+ attn_mask = self.attn_mask_sequential if mode_sequential else self.attn_mask_interp
190
+ if text_hidden_state is not None:
191
+ s0 = text_hidden_state.shape[1]
192
+ qkv_text = self.query_key_value[layer_id](text_hidden_state).reshape(b0, s0, 3, self.n_head, h).permute(2, 0, 3, 1, 4) #[3, b0, n_head, s0, h]
193
+ q_text, k_text, v_text = qkv_text[0], qkv_text[1], qkv_text[2]
194
+
195
+ # shift
196
+ frame_hidden_state = frame_hidden_state.reshape(b0, frame_num, self.frame_resolution, self.frame_resolution, h0)
197
+ if self.shift_sizes[layer_id%2] > 0:
198
+ frame_hidden_state = torch.roll(frame_hidden_state, shifts=(-self.shift_sizes[layer_id%2], -self.shift_sizes[layer_id%2]), dims=(2,3))
199
+ # window partition
200
+ frame_hidden_state = window_partition(frame_hidden_state, self.window_size).reshape(bswin, frame_num*wind_square, h0)
201
+ qkv = self.query_key_value[layer_id](frame_hidden_state).reshape(bswin, frame_num*wind_square, 3, self.n_head, h)\
202
+ .permute(2, 0, 3, 1, 4) #[3, bswin, n_head, frame_num*wind_size*wind_size, h]
203
+ q, k, v = qkv[0], qkv[1], qkv[2]
204
+
205
+ # pb-relax
206
+ if swin_pb_relax:
207
+ attn = torch.matmul(q / (math.sqrt(h)*alpha), k.transpose(-1, -2))
208
+ else:
209
+ attn = torch.matmul(q / math.sqrt(h), k.transpose(-1, -2))
210
+
211
+ if self.shift_sizes[layer_id%2] > 0:
212
+ # attn = attn.view(bswin // nW, nW, self.n_head, frame_num*wind_square, frame_num*wind_square) + self.attn_mask.unsqueeze(1).unsqueeze(0)
213
+ attn = torch.mul(attn.view(bswin // nW, nW, self.n_head, frame_num*wind_square, frame_num*wind_square), attn_mask)\
214
+ - 10000.0 * (1.0 - attn_mask)
215
+ attn = attn.view(bswin, self.n_head, frame_num*wind_square, frame_num*wind_square)
216
+ else:
217
+ attn = torch.mul(attn.view(bswin // nW, nW, self.n_head, frame_num*wind_square, frame_num*wind_square), causal_mask)\
218
+ - 10000.0 * (1.0 - causal_mask)
219
+ attn = attn.view(bswin, self.n_head, frame_num*wind_square, frame_num*wind_square)
220
+ if swin_pb_relax:
221
+ swin_pb_relax_const = torch.max(attn.reshape(bswin, self.n_head, -1), dim=-1, keepdim=True)[0].detach().unsqueeze(-1)
222
+ attn = (attn - swin_pb_relax_const)*alpha
223
+
224
+ if text_hidden_state is None:
225
+ attn = F.softmax(attn, dim=-1)
226
+ if attn_dropout is not None:
227
+ with get_cuda_rng_tracker().fork():
228
+ attn = attn_dropout(attn)
229
+ context_swin = torch.matmul(attn, v).permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0)
230
+ else:
231
+ assert text_attn_mask is not None
232
+ text_attn_mask = text_attn_mask.unsqueeze(2).unsqueeze(2)
233
+ # pb-relax
234
+ if swin_pb_relax:
235
+ attn_frame2text = torch.matmul(q.reshape(b0, -1, self.n_head, frame_num*wind_square, h) / (math.sqrt(h)*alpha), k_text.unsqueeze(1).transpose(-1, -2))
236
+ attn_frame2text = (attn_frame2text-swin_pb_relax_const.reshape(b0, -1, self.n_head, 1, 1))*alpha
237
+ else:
238
+ attn_frame2text = torch.matmul(q.reshape(b0, -1, self.n_head, frame_num*wind_square, h) / math.sqrt(h), k_text.unsqueeze(1).transpose(-1, -2))
239
+
240
+ attn_frame2text = torch.mul(text_attn_mask, attn_frame2text) - 10000.0 * (1.0 - text_attn_mask)
241
+ attn_frame2text = attn_frame2text.reshape(bswin, self.n_head, frame_num*wind_square, s0)
242
+ attn = torch.cat((attn, attn_frame2text), dim=-1)
243
+ attn = F.softmax(attn, dim=-1)
244
+
245
+ if attn_dropout is not None:
246
+ with get_cuda_rng_tracker().fork():
247
+ attn = attn_dropout(attn)
248
+
249
+ context_swin = (torch.matmul(attn[..., :-s0], v) +
250
+ torch.matmul(attn[..., -s0:].reshape(b0, -1, self.n_head,frame_num*wind_square, s0), v_text.unsqueeze(1))\
251
+ .reshape(bswin, self.n_head, frame_num*wind_square, h))\
252
+ .permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0)
253
+
254
+ context_swin = window_reverse(context_swin, self.window_size, self.frame_resolution, self.frame_resolution)
255
+ # reverse cycle shift
256
+ if self.shift_sizes[layer_id%2] > 0:
257
+ context_swin = torch.roll(context_swin, shifts=(self.shift_sizes[layer_id%2], self.shift_sizes[layer_id%2]), dims=(2,3))
258
+ context_swin = context_swin.reshape(b0, s1, h0)
259
+
260
+ return context_swin
261
+
262
+
263
+ class FullAttentionMixin(BaseMixin):
264
+ def __init__(self, num_layers,
265
+ hidden_size,
266
+ frame_resolution,
267
+ n_head,
268
+ frame_num,
269
+ init_method=unscaled_init_method(0.02),
270
+ output_layer_init_method=unscaled_init_method(0.02),
271
+ ):
272
+ super(FullAttentionMixin, self).__init__()
273
+ self.num_layers = num_layers # replace attention in the LAST n layers
274
+ self.query_key_value = torch.nn.ModuleList(
275
+ [ColumnParallelLinear(hidden_size, 3*hidden_size,stride=3,
276
+ gather_output=False,init_method=init_method)
277
+ for layer_id in range(num_layers)
278
+ ])
279
+ self.dense = torch.nn.ModuleList(
280
+ [RowParallelLinear(
281
+ hidden_size,
282
+ hidden_size,
283
+ input_is_parallel=True,
284
+ init_method=output_layer_init_method,
285
+ bias=True,
286
+ module=self,
287
+ name="dense",)
288
+ for layer_id in range(num_layers)
289
+ ])
290
+
291
+ self.n_head = n_head
292
+ self.frame_resolution = frame_resolution
293
+ self.frame_len = frame_resolution * frame_resolution
294
+ self.causal_mask = torch.ones(1, 1, self.frame_len*frame_num, self.frame_len*frame_num).tril()
295
+
296
+ self.mask_initialized = False
297
+
298
+ self.attn_distribution = torch.nn.ParameterList([
299
+ torch.nn.Parameter(torch.zeros(hidden_size))
300
+ for _ in range(num_layers)
301
+ ])
302
+
303
+ def reinit(self, *pre_mixins):
304
+ start_layer = len(self.transformer.layers) - self.num_layers
305
+ assert start_layer >= 0
306
+ for layer_id in range(self.num_layers):
307
+ base_attention = self.transformer.layers[start_layer + layer_id].attention
308
+ self.query_key_value[layer_id].weight.data.copy_(base_attention.query_key_value.weight.data)
309
+ self.query_key_value[layer_id].bias.data.copy_(base_attention.query_key_value.bias.data)
310
+
311
+ def attention_extra(self, frame_hidden_state, layer_id, attn_dropout, text_hidden_state=None,
312
+ text_attn_mask=None, mode_sequential=False):
313
+ # pb relax
314
+ # frame_hidden_state [batchsize, frame_num*frame_size, n_head*hiddensize_perhead]
315
+ assert mode_sequential == True # full attention is only used in the sequential (stage-1) mode
316
+ swin_pb_relax = True
317
+ alpha = 16
318
+
319
+ if not self.mask_initialized:
320
+ self.causal_mask = self.causal_mask.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype)
321
+ self.mask_initialized = True
322
+ b0, s1, h0 = frame_hidden_state.shape
323
+ h = h0 // self.n_head
324
+ frame_len = self.frame_resolution * self.frame_resolution
325
+ frame_num = s1 // frame_len
326
+ assert frame_num*frame_len == s1
327
+
328
+ qkv = self.query_key_value[layer_id](frame_hidden_state).reshape(b0, s1, 3, self.n_head, h)\
329
+ .permute(2, 0, 3, 1, 4) #[3, b0, n_head, s1, h]
330
+ q, k, v = qkv[0], qkv[1], qkv[2]
331
+
332
+ # frames-to-frames
333
+ if swin_pb_relax:
334
+ attn = torch.matmul(q / (math.sqrt(h)*alpha), k.transpose(-1, -2))
335
+ else:
336
+ attn = torch.matmul(q / math.sqrt(h), k.transpose(-1, -2))
337
+ attn = torch.mul(attn, self.causal_mask) - 10000.0 * (1.0 - self.causal_mask)
338
+ if swin_pb_relax:
339
+ swin_pb_relax_const = torch.max(attn.reshape(b0, self.n_head, -1), dim=-1, keepdim=True)[0].detach().unsqueeze(-1)
340
+ attn = (attn - swin_pb_relax_const)*alpha
341
+
342
+ if text_hidden_state is None:
343
+ attn = F.softmax(attn, dim=-1)
344
+ if attn_dropout is not None:
345
+ with get_cuda_rng_tracker().fork():
346
+ attn = attn_dropout(attn)
347
+ context_swin = torch.matmul(attn, v).permute(0, 2, 1, 3).reshape(b0, s1, h0)
348
+ else:
349
+ # frame-to-text
350
+ assert text_attn_mask is not None
351
+ s0 = text_hidden_state.shape[1]
352
+ qkv_text = self.query_key_value[layer_id](text_hidden_state).reshape(b0, s0, 3, self.n_head, h).permute(2, 0, 3, 1, 4) #[3, b0, n_head, s0, h]
353
+ q_text, k_text, v_text = qkv_text[0], qkv_text[1], qkv_text[2]
354
+ text_attn_mask = text_attn_mask.unsqueeze(2)
355
+ if swin_pb_relax:
356
+ attn_frame2text = torch.matmul(q.reshape(b0, self.n_head, s1, h) / (math.sqrt(h)*alpha), k_text.transpose(-1, -2))
357
+ attn_frame2text = (attn_frame2text-swin_pb_relax_const.reshape(b0, self.n_head, 1, 1))*alpha
358
+ else:
359
+ attn_frame2text = torch.matmul(q.reshape(b0, self.n_head, s1, h) / math.sqrt(h), k_text.transpose(-1, -2))
360
+ attn_frame2text = torch.mul(text_attn_mask, attn_frame2text) - 10000.0 * (1.0 - text_attn_mask)
361
+ attn_frame2text = attn_frame2text.reshape(b0, self.n_head, s1, s0)
362
+
363
+ attn = torch.cat((attn, attn_frame2text), dim=-1)
364
+ attn = F.softmax(attn, dim=-1)
365
+
366
+ if attn_dropout is not None:
367
+ with get_cuda_rng_tracker().fork():
368
+ attn = attn_dropout(attn)
369
+
370
+ context_frame = (torch.matmul(attn[..., :-s0], v) +
371
+ torch.matmul(attn[..., -s0:].reshape(b0, self.n_head,s1, s0), v_text))\
372
+ .permute(0, 2, 1, 3).reshape(b0, s1, h0)
373
+
374
+ return context_frame
375
+
376
+
377
+ def attention_localframe_and_text(q0, k0, v0, attention_mask_totxt, attention_mask_local,
378
+ n_head, text_len, frame_len, frame_num, attention_dropout=None, layer_id=0, **kwargs):
379
+ b, s0, h0 = q0.shape
380
+ s1 = s0 - text_len
381
+ h = h0 // n_head
382
+ assert q0.shape[1] == v0.shape[1] == k0.shape[1] == text_len+frame_len*frame_num
383
+ # attention_mask_totxt [b, 1, 1, text_len]
384
+ # attention_mask_local [1, 1, frame_num, frame_len, frame_len]
385
+ # attention_mask: [1, 1, text_len+frame_len, text_len+frame_len]
386
+
387
+ q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
388
+ v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
389
+ k0 = k0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
390
+ k0T = k0.transpose(-1, -2)
391
+
392
+ # score: any2text
393
+ score_any2text = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T[..., :text_len])
394
+ score_any2text_part1 = torch.mul(score_any2text[..., :text_len, :], attention_mask_totxt) \
395
+ - 10000.0 * (1.0 - attention_mask_totxt)
396
+ score_any2text_part2 = torch.mul(score_any2text[..., text_len:, :], attention_mask_totxt) - \
397
+ 10000.0 * (1.0 - attention_mask_totxt)
398
+
399
+ # score: frame local
400
+ q0_frame = q0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h)
401
+ v0_frame = v0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h)
402
+ k0T_frame = k0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h).transpose(-1, -2)
403
+ score_frame_local0 = torch.matmul(q0_frame / math.sqrt(q0_frame.shape[-1]), k0T_frame)
404
+ score_frame_local0 = torch.mul(score_frame_local0, attention_mask_local) \
405
+ - 10000.0 * (1.0 - attention_mask_local)
406
+
407
+ # context for frame
408
+ score_frame_all = torch.cat((score_any2text_part2,
409
+ score_frame_local0.view(b, n_head, s1, frame_len)), dim=-1)
410
+ attention_probs_frame = F.softmax(score_frame_all, dim=-1)
411
+
412
+ if attention_dropout is not None:
413
+ with get_cuda_rng_tracker().fork():
414
+ attention_probs_frame = attention_dropout(attention_probs_frame)
415
+
416
+ context_frame2text = torch.matmul(attention_probs_frame[..., :text_len], v0[..., :text_len, :]) # [b, n_head, s1, h]
417
+ context_frame_local0 = torch.matmul(attention_probs_frame[..., text_len:text_len+frame_len].\
418
+ view(b, n_head, frame_num, frame_len, frame_len), v0_frame).view(b, n_head, s1, h)
419
+ context_frame = (context_frame2text + context_frame_local0).transpose(1, 2).reshape(b, s1, h0)
420
+
421
+ # context for text
422
+ attention_probs_text = F.softmax(score_any2text_part1, dim=-1)
423
+ if attention_dropout is not None:
424
+ with get_cuda_rng_tracker().fork():
425
+ attention_probs_text = attention_dropout(attention_probs_text)
426
+ context_text2text = torch.matmul(attention_probs_text, v0[..., :text_len, :])
427
+ context_text2text = context_text2text.transpose(1, 2).reshape(b, text_len, h0)
428
+
429
+ return context_text2text, context_frame
430
+
431
+
432
+ class CogVideoModel(BaseModel):
433
+ def __init__(self, args, transformer=None, parallel_output=True):
434
+ super().__init__(args, transformer=transformer, parallel_output=parallel_output)
435
+ self.stage = args.cogvideo_stage # 1 or 2
436
+ self.mode_sequential = True if self.stage==1 else False
437
+ self.layout = args.layout # [64, 64+400, 64+5*400]
438
+ self.n_head = args.num_attention_heads
439
+ frame_resolution = int(math.sqrt(self.layout[1]-self.layout[0]))
440
+ frame_num = (args.layout[2]-args.layout[0])//(args.layout[1]-args.layout[0])
441
+ frame_len = self.layout[1]-self.layout[0]
442
+
443
+ self.add_mixin('extra_position_embedding', PositionEmbeddingMixin(
444
+ args.additional_seqlen, args.hidden_size
445
+ ))
446
+
447
+ if args.window_size == -1:
448
+ # full attention
449
+ assert self.stage == 1
450
+ self.add_mixin('attention_plus', FullAttentionMixin(
451
+ num_layers=args.num_layers,
452
+ hidden_size=args.hidden_size,
453
+ frame_resolution=frame_resolution,
454
+ n_head=args.num_attention_heads,
455
+ frame_num=frame_num,
456
+ ))
457
+ else:
458
+ self.add_mixin('attention_plus', WindowAttentionMixin(
459
+ num_layers=args.num_layers,
460
+ hidden_size=args.hidden_size,
461
+ frame_resolution=frame_resolution,
462
+ window_size=args.window_size,
463
+ shift_size=args.window_size//2,
464
+ n_head=args.num_attention_heads,
465
+ frame_num=frame_num,
466
+ ))
467
+ # attention_mask_local
468
+ self.attention_mask_local_sequential = torch.ones(1, 1, frame_num, frame_len, frame_len).tril().unsqueeze(0)
469
+ self.attention_mask_local_interp = torch.ones(1, 1, frame_num, frame_len, frame_len)
470
+
471
+ for idx in range(1, frame_num, 2):
472
+ self.attention_mask_local_interp[:, :, idx:idx+1].tril_()
473
+ self.attention_mask_local_interp = self.attention_mask_local_interp.unsqueeze(0)
474
+ self.mask_initialized = False
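+ # The sequential (stage-1) mask above is causal within every frame, while the
+ # interpolation (stage-2) mask appears to keep even-indexed frames fully visible
+ # and only makes the odd, to-be-interpolated frames causal.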
475
+
476
+ @classmethod
477
+ def add_model_specific_args(cls, parser):
478
+ group = parser.add_argument_group('CogVideoModel', 'CogVideo model configurations')
479
+ group.add_argument("--layout", type=str, default='64, 464, 2064', help='text_len, textlen+frame_len, textlen+frame_len*frame_num')
480
+ group.add_argument("--window-size", type=int, default=10, help="swin attention's window size in temperal channel, -1 represents full attention")
481
+ group.add_argument("--additional-seqlen", type=int, default=2000)
482
+ group.add_argument("--cogvideo-stage", type=int, default=1, choices=[1,2])
483
+ return parser
484
+
485
+ def disable_untrainable_params(self):
486
+ self.transformer.requires_grad_(False)
487
+
488
+ def position_embedding_forward(self, position_ids, **kw_args):
489
+ position = position_ids[..., :(64+400)]
490
+ position_plus = position_ids[..., (64+400):]
491
+ position_embeddings = torch.cat(
492
+ (
493
+ self.transformer.position_embeddings(position),
494
+ self.get_mixin('extra_position_embedding').position_embeddings(position_plus-(512+400))
495
+ ),
496
+ dim=-2
497
+ )
498
+ return position_embeddings
499
+
500
+ def attention_forward(self, hidden_states, mask, layer_id, **kw_args):
501
+ # mask.shape=[bs, 1, 1, 64]
502
+ if not self.mask_initialized:
503
+ self.attention_mask_local_sequential = self.attention_mask_local_sequential.to(device=hidden_states.device, dtype=hidden_states.dtype)
504
+ self.attention_mask_local_interp = self.attention_mask_local_interp.to(device=hidden_states.device, dtype=hidden_states.dtype)
505
+ self.mask_initialized = True
506
+
507
+ attn_module = self.transformer.layers[layer_id].attention
508
+ hidden_size = hidden_states.shape[-1]
509
+ bs = hidden_states.shape[0]
510
+
511
+ # base model qkv
512
+ mixed_raw_layer = attn_module.query_key_value(hidden_states)
513
+ q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer, 3)
514
+ dropout_fn = self.transformer.layers[layer_id].attention.attention_dropout if self.training else None
515
+
516
+ attention_mask_local = self.attention_mask_local_sequential if self.mode_sequential else self.attention_mask_local_interp
517
+ context_text, context_frame_local_text = attention_localframe_and_text(
518
+ q0, k0, v0,
519
+ attention_mask_totxt=mask,
520
+ attention_mask_local=attention_mask_local,
521
+ n_head=attn_module.num_attention_heads_per_partition,
522
+ text_len=self.layout[0],
523
+ frame_len=self.layout[1]-self.layout[0],
524
+ frame_num=(self.layout[2]-self.layout[0])//(self.layout[1]-self.layout[0]),
525
+ attention_dropout=dropout_fn,
526
+ layer_id=layer_id,
527
+ )
528
+
529
+ context_frame_swin = self.get_mixin('attention_plus').attention_extra(
530
+ hidden_states[:, self.layout[0]:], layer_id, dropout_fn,
531
+ text_hidden_state=hidden_states[:, :self.layout[0]],
532
+ text_attn_mask=mask[..., 0, :],
533
+ mode_sequential=self.mode_sequential)
534
+
535
+ attn_distrib = torch.sigmoid(self.get_mixin('attention_plus').attn_distribution[layer_id])
536
+ attn_distrib = attn_distrib.unsqueeze(0).unsqueeze(0)
537
+
538
+ output_text = attn_module.dense(context_text)
539
+ output_frame = torch.mul(attn_module.dense(context_frame_local_text), attn_distrib)\
540
+ +torch.mul(self.get_mixin('attention_plus').dense[layer_id](context_frame_swin), 1-attn_distrib)
541
+ output = torch.cat((output_text, output_frame), dim=-2)
542
+
543
+ return output
src/videogen_hub/pipelines/cogvideo/cogvideo_src/pretrain_cogvideo.py ADDED
@@ -0,0 +1,184 @@
1
+ # -*- encoding: utf-8 -*-
2
+ '''
3
+ @File : pretrain_cogvideo.py
4
+ @Time : 2021/10/06 00:58:32
5
+ @Author : Wenyi Hong
6
+ @Contact : [email protected]
7
+ '''
8
+
9
+ # here put the import lib
10
+ import os
11
+ import sys
12
+ import math
13
+ import random
14
+ import torch
15
+ import argparse
16
+ import numpy as np
17
+ from videogen_hub.depend.icetk import icetk as tokenizer
18
+ tokenizer.add_special_tokens(['<start_of_image>', '<start_of_english>', '<start_of_chinese>'])
19
+
20
+ from models.cogvideo_model import CogVideoModel
21
+ from SwissArmyTransformer import mpu, get_args
22
+ from SwissArmyTransformer.training.deepspeed_training import training_main
23
+ from SwissArmyTransformer.data_utils import BinaryDataset
24
+
25
+ def get_masks_and_position_ids_video(data, attention_mask_totxt=None, args=None):
26
+ # Extract batch size and sequence length.
27
+ batch_size, seq_length = data.size()
28
+ assert attention_mask_totxt is not None
29
+ layout = args.layout
30
+ assert seq_length == layout[-1]
31
+ n_pads = layout[0] - attention_mask_totxt.sum(dim=-1).long()
32
+ frame_len = layout[1]-layout[0]
33
+ position_ids = torch.zeros(batch_size, layout[2], dtype=torch.long,
34
+ device=data.device)
35
+ for i in range(batch_size):
36
+ torch.arange(layout[0] - n_pads[i], out=position_ids[i, n_pads[i]:layout[0]],
37
+ dtype=torch.long, device=data.device)
38
+ torch.arange(512, 512+layout[2]-layout[0],
39
+ out=position_ids[i, layout[0]:], dtype=torch.long, device=data.device)
40
+ return position_ids
41
+
42
+
43
+ def get_batch(data_iterator, args, timers):
44
+ # Items and their type.
45
+ keys = ['text', 'loss_mask', 'attention_mask_totxt']
46
+ datatype = torch.int64
47
+
48
+ # Broadcast data.
49
+ timers('data loader').start()
50
+ if data_iterator is not None:
51
+ data = next(data_iterator)
52
+ else:
53
+ data = None
54
+ timers('data loader').stop()
55
+
56
+ data_b = mpu.broadcast_data(keys, data, datatype)
57
+ # Unpack.
58
+ tokens_ = data_b['text'].long()
59
+ loss_mask = data_b['loss_mask'].float()
60
+ attention_mask_totxt = data_b['attention_mask_totxt'].float()
61
+
62
+ labels = tokens_[:, 1:].clone().contiguous()
63
+ loss_mask = loss_mask[:, 1:].contiguous()
64
+ tokens = tokens_[:, :-1].clone().contiguous()
65
+
66
+ for idx in range(args.layout[0], args.layout[2], 400):
67
+ tokens[:, idx] = tokenizer['<start_of_image>']
68
+ # Get the masks and position ids.
69
+ position_ids = get_masks_and_position_ids_video(
70
+ tokens,
71
+ attention_mask_totxt=attention_mask_totxt,
72
+ args=args
73
+ )
74
+ attention_mask_totxt = attention_mask_totxt.unsqueeze(1).unsqueeze(1)
75
+ # Convert
76
+ if args.fp16:
77
+ attention_mask_totxt = attention_mask_totxt.half()
78
+ return tokens, labels, loss_mask, attention_mask_totxt, position_ids
79
+
80
+
81
+ def forward_step(data_iterator, model, args, timers):
82
+ """Forward step."""
83
+
84
+ # Get the batch.
85
+ timers('batch generator').start()
86
+ tokens, labels, loss_mask, attention_mask_totxt, position_ids = get_batch(
87
+ data_iterator, args, timers)
88
+ timers('batch generator').stop()
89
+
90
+ # Forward model.
91
+ logits, *mems = model(tokens, position_ids, attention_mask_totxt)
92
+ # ======= hyper params =======#
93
+ perframe_len = 400
94
+ text_len=64
95
+ frame_num = 5
96
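+ # Only the image-token slice of the vocabulary and only the frame positions (after
+ # the 64 text tokens) enter the cross-entropy; loss_mask then selects which frame
+ # tokens are actually predicted in the current stage.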
+ logits_img_tokens = logits[:, text_len:, :tokenizer.num_image_tokens].float().contiguous()
97
+ losses = mpu.vocab_parallel_cross_entropy(logits_img_tokens, labels[:, text_len:])
98
+ # scaling loss mask
99
+ loss_mask = loss_mask[:, text_len:].reshape(-1)
100
+
101
+ losses_1d = losses.reshape(-1) * loss_mask
102
+ loss = torch.sum(losses_1d) / loss_mask.sum()
103
+ # ===================== Log partial losses ======================== #
104
+ log_loss_dict = {}
105
+ bs = losses.shape[0]
106
+
107
+ if args.cogvideo_stage == 1:
108
+ for i in range(frame_num):
109
+ log_loss_dict[f'AR_f{i}_loss'] = losses[:, i*perframe_len:(i+1)*perframe_len].contiguous().reshape(-1).detach().sum() / max((perframe_len*bs), 1)
110
+ else:
111
+ for i in range(1, frame_num-1):
112
+ log_loss_dict[f'ITP_f{i}_loss'] = losses[:, i*perframe_len:(i+1)*perframe_len].contiguous().reshape(-1).detach().sum() / max((perframe_len*bs), 1)
113
+
114
+ # ===================== END OF BLOCK ======================= #
115
+ return loss, log_loss_dict
116
+
117
+
118
+ def create_dataset_function(path, args):
119
+ dataset_layout = [64, 464, 2064]
120
+ input_layout = [64, 464, 2064]
121
+ # frame_num = 6
122
+ # frame_interval = 2 # DEBUG!!!
123
+ def process_fn(row):
124
+ row = row.astype(np.int64)
125
+ text = row[:dataset_layout[0]]
126
+ frames = row[dataset_layout[0]:]
127
+
128
+ if text[0] == tokenizer['<pad>']:
129
+ text = text[1:] # due to our way of data processing
130
+ if args.cogvideo_stage == 1:
131
+ text, loss_mask, frames = make_text_video_generation(text, frames)
132
+ else:
133
+ text, loss_mask, frames = mask_video_frame_interpolation(text, frames)
134
+
135
+ n_pad = input_layout[0] - len(text)
136
+ parts = [
137
+ np.array([tokenizer['<pad>']] * n_pad, dtype=np.int64),
138
+ text,
139
+ np.array([tokenizer['<start_of_image>']], dtype=np.int64),
140
+ frames,
141
+ ]
142
+ ret = np.concatenate(parts, axis=0)
143
+
144
+ attention_mask_totxt = np.array([0] * n_pad + [1] * (input_layout[0]-n_pad))
145
+ return {'text': ret,
146
+ 'loss_mask': loss_mask,
147
+ 'attention_mask_totxt': attention_mask_totxt,
148
+ }
149
+ return BinaryDataset(path, process_fn, length_per_sample=dataset_layout[-1])
150
+
151
+ def make_text_video_generation(text, frames):
152
+ input_layout = [64, 464, 2064]
153
+ text = text[text!= tokenizer['<pad>']][:input_layout[0]] # dataset format: 1.0秒<n>{text}<pad><pad> ...
154
+ loss_mask = np.array([0] * (input_layout[1]+1) + [1] * (input_layout[2] - input_layout[1])) # aligned with the input; the loss_mask is shifted left by one position later
155
+ return text, loss_mask, frames
156
+
157
+ def mask_video_frame_interpolation(text, frames):
158
+ input_layout = [64, 464, 2064]
159
+ frame_len = input_layout[1]-input_layout[0]
160
+ # text format: <pad> 1.0秒 <n> {text} <pad> <pad>
161
+ text = text[text!= tokenizer['<pad>']][:input_layout[0]]
162
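+ # interpolation-stage loss mask: the 2nd and 4th frames are predicted (1s), while the 1st, 3rd and 5th frames are given as conditioning (0s)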
+ loss_mask = np.array([0] * (input_layout[1]+1)
163
+ + [1] * (input_layout[1]-input_layout[0])
164
+ + [0] * (input_layout[1]-input_layout[0])
165
+ + [1] * (input_layout[1]-input_layout[0])
166
+ + [0] * (input_layout[1]-input_layout[0]) ) # aligned with the input; the loss_mask is shifted left by one position later
167
+
168
+ return text, loss_mask, frames
169
+
170
+
171
+
172
+ if __name__ == '__main__':
173
+ py_parser = argparse.ArgumentParser(add_help=False)
174
+ py_parser.add_argument('--txt-loss-scale', type=float, default=1)
175
+ CogVideoModel.add_model_specific_args(py_parser)
176
+
177
+ known, args_list = py_parser.parse_known_args()
178
+
179
+ args = get_args(args_list)
180
+ args = argparse.Namespace(**vars(args), **vars(known))
181
+
182
+ args.layout = [int(x) for x in args.layout.split(',')]
183
+
184
+ training_main(args, model_cls=CogVideoModel, forward_step_function=forward_step, create_dataset_function=create_dataset_function)
src/videogen_hub/pipelines/cogvideo/cogvideo_src/requirements.txt ADDED
@@ -0,0 +1,4 @@
1
+ SwissArmyTransformer==0.2.9
2
+ icetk
3
+ gifmaker
4
+ torchvision
src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ # -*- encoding: utf-8 -*-
2
+ '''
3
+ @File : __init__.py
4
+ @Time : 2022/03/02 13:57:09
5
+ @Author : Ming Ding
6
+ @Contact : [email protected]
7
+ '''
8
+
9
+ # here put the import lib
10
+ import os
11
+ import sys
12
+ import math
13
+ import random
14
+
15
+ from .direct_sr import DirectSuperResolution
16
+ from .iterative_sr import IterativeSuperResolution
17
+ from .sr_group import SRGroup
src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/cluster_label2.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b87880fdbe89670f12844377b9cf97a9733b1f54e3a9b73cbb9835084c4e02ec
3
+ size 160128
src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/direct_sr.py ADDED
@@ -0,0 +1,117 @@
1
+ # -*- encoding: utf-8 -*-
2
+ '''
3
+ @File : direct_sr.py
4
+ @Time : 2022/03/02 13:58:11
5
+ @Author : Ming Ding
6
+ @Contact : [email protected]
7
+ '''
8
+
9
+ # here put the import lib
10
+ import os
11
+ import sys
12
+ import math
13
+ import random
14
+ import torch
15
+
16
+ from PIL import ImageEnhance, Image
+
32
+ import argparse
33
+ from torchvision import transforms
34
+
35
+ from SwissArmyTransformer import get_args
36
+ from SwissArmyTransformer.training.model_io import load_checkpoint
37
+ from .dsr_sampling import filling_sequence_dsr, IterativeEntfilterStrategy
38
+ from SwissArmyTransformer.generation.utils import timed_name, save_multiple_images, generate_continually
39
+
40
+ from .dsr_model import DsrModel
41
+
42
+ from videogen_hub.depend.icetk import icetk as tokenizer
43
+
44
+ class DirectSuperResolution:
45
+ def __init__(self, args, path, max_bz=4, topk=6, onCUDA=False):
46
+ args.load = path
47
+ args.kernel_size = 5
48
+ args.kernel_size2 = 5
49
+ args.new_sequence_length = 4624
50
+ args.layout = [96,496,4096]
51
+
52
+ model = DsrModel(args)
53
+ if args.fp16:
54
+ model = model.half()
55
+
56
+ load_checkpoint(model, args) # on cpu
57
+ model.eval()
58
+ self.model = model
59
+ self.onCUDA = onCUDA
60
+ if onCUDA:
61
+ self.model = self.model.cuda()
62
+
63
+ invalid_slices = [slice(tokenizer.num_image_tokens, None)]
64
+
65
+ self.strategy = IterativeEntfilterStrategy(invalid_slices,
66
+ temperature=1.0, topk=topk) # temperature not used; it is frozen here
67
+ self.max_bz = max_bz
68
+
69
+ def __call__(self, text_tokens, image_tokens, enhance=False):
70
+ if len(text_tokens.shape) == 1:
71
+ text_tokens.unsqueeze_(0)
72
+ if len(image_tokens.shape) == 1:
73
+ image_tokens.unsqueeze_(0)
74
+ # ===================== Debug ======================== #
75
+ # new_image_tokens = []
76
+ # for small_img in image_tokens:
77
+ # decoded = tokenizer.decode(image_ids=small_img)
78
+ # decoded = torch.nn.functional.interpolate(decoded, size=(480, 480)).squeeze(0)
79
+ # ndarr = decoded.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
80
+ # image_pil_raw = ImageEnhance.Sharpness(Image.fromarray(ndarr))
81
+ # small_img2 = tokenizer.encode(image_pil=image_pil_raw.enhance(1.5), image_size=480).view(-1)
82
+ # new_image_tokens.append(small_img2)
83
+ # image_tokens = torch.stack(new_image_tokens)
84
+ # return image_tokens
85
+ # ===================== END OF BLOCK ======================= #
86
+ if enhance:
87
+ new_image_tokens = []
88
+ for small_img in image_tokens:
89
+ decoded = tokenizer.decode(image_ids=small_img).squeeze(0)
90
+ ndarr = decoded.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
91
+ image_pil_raw = ImageEnhance.Sharpness(Image.fromarray(ndarr))
92
+ small_img2 = tokenizer.encode(image_pil=image_pil_raw.enhance(1.), image_size=160).view(-1)
93
+ new_image_tokens.append(small_img2)
94
+ image_tokens = torch.stack(new_image_tokens)
95
+
96
+ seq = torch.cat((text_tokens,image_tokens), dim=1)
97
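+ # seq1 is the placeholder canvas for the high-resolution output: 60*60 = 3600 tokens plus one extra, all initialized to <start_of_image>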
+ seq1 = torch.tensor([tokenizer['<start_of_image>']]*3601, device=image_tokens.device).unsqueeze(0).expand(text_tokens.shape[0], -1)
98
+ if not self.onCUDA:
99
+ print('Converting Dsr model...')
100
+ model = self.model.cuda()
101
+ else:
102
+ model = self.model
103
+ print('Direct super-resolution...')
104
+ output_list = []
105
+ for tim in range(max((text_tokens.shape[0]+self.max_bz-1) // self.max_bz, 1)):
106
+ output1 = filling_sequence_dsr(model,
107
+ seq[tim*self.max_bz:(tim+1)*self.max_bz],
108
+ seq1[tim*self.max_bz:(tim+1)*self.max_bz],
109
+ warmup_steps=1, block_hw=(1, 0),
110
+ strategy=self.strategy
111
+ )
112
+ output_list.extend(output1[1:])
113
+ if not self.onCUDA:
114
+ print('Moving back Dsr to cpu...')
115
+ model = model.cpu()
116
+ torch.cuda.empty_cache()
117
+ return torch.cat(output_list, dim=0)
src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/dsr_model.py ADDED
@@ -0,0 +1,225 @@
1
+ # -*- encoding: utf-8 -*-
2
+ '''
3
+ @File : dsr_model.py
4
+ @Time : 2021/10/02 01:36:32
5
+ @Author : Ming Ding
6
+ @Contact : [email protected]
7
+ '''
8
+
9
+ # here put the import lib
10
+ import os
11
+ import sys
12
+ import math
13
+ import random
14
+ import torch
15
+ import torch.nn.functional as F
16
+
17
+
18
+ from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin
19
+
20
+ from SwissArmyTransformer.model.transformer import split_tensor_along_last_dim, unscaled_init_method
21
+ from SwissArmyTransformer.mpu.utils import sqrt
22
+ from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker
23
+ from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear
24
+
25
+ class PositionEmbeddingMixin(BaseMixin):
26
+ def __init__(self, additional_sequence_length, hidden_size,
27
+ init_method_std=0.02, reinit_slice=slice(512, 512+400)
28
+ ):
29
+ super(PositionEmbeddingMixin, self).__init__()
30
+ self.reinit_slice = reinit_slice
31
+ self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size)
32
+ torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
33
+
34
+ def reinit(self, parent_model=None):
35
+ old_weights = self.transformer.position_embeddings.weight.data[self.reinit_slice]
36
+ old_len, hidden_size = old_weights.shape
37
+ assert hidden_size == self.position_embeddings.weight.shape[-1]
38
+ old_edge, new_edge = sqrt(old_len), sqrt(self.position_embeddings.weight.shape[-2])
39
+ assert new_edge % old_edge == 0
40
+ self.position_embeddings.weight.data.view(new_edge // old_edge, old_edge, new_edge // old_edge, old_edge, hidden_size).copy_(old_weights.view(1, old_edge, 1, old_edge, hidden_size))
41
+ # self.position_embeddings.weight.data.view(-1, old_len, hidden_size).copy_(old_weights)
42
+
43
+
44
+ class AttentionMixin(BaseMixin):
45
+ def __init__(self, num_layers,
46
+ hidden_size,
47
+ init_method=unscaled_init_method(0.02),
48
+ output_layer_init_method=unscaled_init_method(0.02)
49
+ ):
50
+ super(AttentionMixin, self).__init__()
51
+ self.num_layers = num_layers # replace attention in the LAST n layers
52
+ self.query_key_value = torch.nn.ModuleList(
53
+ [ColumnParallelLinear(hidden_size, 3 * hidden_size, stride=3,
54
+ gather_output=False, init_method=init_method)
55
+ for layer_id in range(num_layers)
56
+ ])
57
+ self.dense = torch.nn.ModuleList(
58
+ [RowParallelLinear(hidden_size,
59
+ hidden_size,
60
+ input_is_parallel=True,
61
+ init_method=output_layer_init_method)
62
+ for layer_id in range(num_layers)
63
+ ])
64
+
65
+ def reinit(self, parent_model=None):
66
+ start_layer = len(self.transformer.layers) - self.num_layers
67
+ assert start_layer >= 0
68
+ for layer_id in range(self.num_layers):
69
+ old_attention = self.transformer.layers[start_layer + layer_id].attention
70
+ self.query_key_value[layer_id].weight.data.copy_(old_attention.query_key_value.weight.data)
71
+ self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data)
72
+ self.dense[layer_id].weight.data.copy_(old_attention.dense.weight.data)
73
+ self.dense[layer_id].bias.data.copy_(old_attention.dense.bias.data)
74
+
75
+ class DsrModel(BaseModel):
76
+ def __init__(self, args, transformer=None):
77
+ super().__init__(args, transformer=transformer)
78
+ self.original_sequence_length = args.max_sequence_length
79
+ additional_seqlen = args.new_sequence_length - args.max_sequence_length
80
+ self.add_mixin('extra_position_embedding', PositionEmbeddingMixin(
81
+ additional_seqlen, args.hidden_size
82
+ ))
83
+ self.add_mixin('attention_plus', AttentionMixin(
84
+ num_layers=args.num_layers,
85
+ hidden_size=args.hidden_size
86
+ ))
87
+ self.layout = args.layout
88
+ # [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1] 4095 {layout[2]}
89
+ self.kernel_size = args.kernel_size
90
+ self.kernel_size2 = args.kernel_size2
91
+ self.log_attention_weights = None
92
+
93
+ def position_embedding_forward(self, position_ids, **kw_args):
94
+ position = position_ids[..., :self.layout[1]]
95
+ position_plus = position_ids[..., self.layout[1]:] - self.original_sequence_length
96
+ position_embeddings = torch.cat(
97
+ (
98
+ self.transformer.position_embeddings(position),
99
+ self.get_mixin('extra_position_embedding').position_embeddings(position_plus)
100
+ ),
101
+ dim=-2
102
+ )
103
+ return position_embeddings
104
+
105
+ def attention_forward(self, hidden_states, mask,
106
+ layer_id=None, log_attention_weights=None, **kw_args):
107
+ attn_module = self.transformer.layers[layer_id].attention
108
+ # attention_plus on all layers
109
+ query_key_value_plus = self.get_mixin('attention_plus').query_key_value[layer_id]
110
+ dense_plus = self.get_mixin('attention_plus').dense[layer_id]
111
+ # split two parts
112
+ hidden_states_plus = hidden_states[:, self.layout[1]:]
113
+ hidden_states = hidden_states[:, :self.layout[1]]
114
+ # base model qkv
115
+ mixed_raw_layer = attn_module.query_key_value(hidden_states)
116
+ q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer, 3)
117
+ # cuda2d model qkv
118
+ mixed_raw_layer = query_key_value_plus(hidden_states_plus)
119
+ q1, k1, v1 = split_tensor_along_last_dim(mixed_raw_layer, 3)
120
+
121
+ dropout_fn = attn_module.attention_dropout if self.training else None
122
+
123
+ # cuda2d attention
124
+ context_layer0, context_layer1 = sparse_attention_2d_light(
125
+ q0, k0, v0,
126
+ q1, k1, v1,
127
+ mask,
128
+ n_head=attn_module.num_attention_heads_per_partition,
129
+ text_len=self.layout[0],
130
+ kernel_size=self.kernel_size,
131
+ kernel_size2=self.kernel_size2,
132
+ attention_dropout=dropout_fn,
133
+ log_attention_weights=log_attention_weights,
134
+ add_scalar=(kw_args['add_scalar'] if 'add_scalar' in kw_args else 0)
135
+ )
136
+
137
+ output_0 = attn_module.dense(context_layer0)
138
+ output_1 = dense_plus(context_layer1)
139
+ output = torch.cat((output_0, output_1), dim=1)
140
+
141
+ return output
142
+
143
+ def final_forward(self, logits, **kwargs):
144
+ logits_parallel = logits
145
+ logits_parallel = torch.nn.functional.linear(logits_parallel.float(), self.transformer.word_embeddings.weight[:20000].float())
146
+ # logits_parallel = torch.nn.functional.linear(logits_parallel, self.transformer.word_embeddings.weight[:20000])
147
+ return logits_parallel
148
+
149
+ def disable_untrainable_params(self):
150
+ self.transformer.requires_grad_(False)
151
+
152
+ @classmethod
153
+ def add_model_specific_args(cls, parser):
154
+ group = parser.add_argument_group('Cuda2dModel', 'cuda2d model configurations')
155
+ group.add_argument("--kernel-size", type=int, default=5)
156
+ group.add_argument("--kernel-size2", type=int, default=5)
157
+ group.add_argument("--layout", type=str, default='96,496,4096')
158
+ group.add_argument("--new-sequence-length", type=int, default=4096)
159
+ return parser
160
+
161
+ def sparse_attention_2d_light(q0, k0, v0, q1, k1, v1, attention_mask, n_head, text_len, kernel_size=9, kernel_size2=7, attention_dropout=None, log_attention_weights = None, add_scalar=0, **kwargs):
162
+ '''
163
+ q0, k0, v0: [batch_size, 1088, hidden_size]
164
+ q1, k1, v1: [batch_size, 4096, h2]
165
+ n_head: int
166
+ attention_mask: [batch_size, 1088, 1088]
167
+ '''
168
+ from SwissArmyTransformer.ops.local_attention_function import f_similar, f_weighting
169
+
170
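+ # level 0 (text + low-res image tokens) uses full attention; level 1 (high-res tokens) uses local window attention via f_similar/f_weighting plus cross attention onto the low-res image grid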
+ b, s0, h0 = q0.shape
171
+ b, s1, h1 = q1.shape
172
+ h, l0, l1 = h0 // n_head, sqrt(s0-text_len), sqrt(s1)
173
+
174
+ q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
175
+ v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
176
+ k0T = k0.reshape(b, s0, n_head, h).permute(0, 2, 3, 1)
177
+
178
+ # standard attention for level 0
179
+ attention_scores = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T)
180
+
181
+ if log_attention_weights is not None:
182
+ attention_scores += log_attention_weights
183
+ attention_scores = torch.mul(attention_scores, attention_mask) - \
184
+ 10000.0 * (1.0 - attention_mask)
185
+
186
+ attention_probs0 = F.softmax(attention_scores, dim=-1)
187
+
188
+ # local attention for level 1
189
+ q1 = (q1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1) / math.sqrt(h1//n_head)).contiguous().view(b*n_head, h1//n_head, l1, l1)
190
+ k1 = k1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1)
191
+ v1 = v1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1)
192
+ # scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, True)
193
+ scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, False)
194
+
195
+ # cross attention
196
+ k0T = k0T[..., -l0**2:].reshape(b*n_head, h, l0, l0).contiguous()
197
+ scores_1_to_0 = f_similar(q1, k0T, kernel_size2, kernel_size2, False) # [b*n_head, l1, l1, field]
198
+ scores_1 = torch.cat(
199
+ (
200
+ scores_1_to_0.view(b*n_head, -1, scores_1_to_0.shape[3]) + add_scalar,
201
+ scores_1_to_1.view(b*n_head, -1, scores_1_to_1.shape[3])
202
+ ),
203
+ dim=-1)
204
+ attention_probs1 = F.softmax(scores_1, dim=-1)
205
+
206
+ if attention_dropout is not None:
207
+ # with get_cuda_rng_tracker().fork():
208
+ attention_probs0 = attention_dropout(attention_probs0)
209
+ attention_probs1 = attention_dropout(attention_probs1)
210
+
211
+ # weighting for level 0
212
+ context0 = torch.matmul(attention_probs0, v0) # [b, n_head, s0, h]
213
+ # weighting for level 1
214
+ probs_1_to_1 = attention_probs1[:, :, -scores_1_to_1.shape[3]:].view_as(scores_1_to_1)
215
+ # context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size*2-1, kernel_size, True)
216
+ context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size*2-1, kernel_size, False)
217
+
218
+ context1 = context1_to_1.view(b, n_head * h, l1**2)
219
+ # weighting for cross attention
220
+ probs_1_to_0 = attention_probs1[:, :, :scores_1_to_0.shape[3]].view_as(scores_1_to_0)
221
+ v0_part = v0[:, :, -l0**2:].transpose(-1, -2).contiguous().view(b*n_head, h, l0, l0)
222
+ context1_to_0 = f_weighting(v0_part, probs_1_to_0.contiguous(), kernel_size2, kernel_size2, False)
223
+ context1_to_0 = context1_to_0.view(b, n_head * h, l1**2)
224
+ context1 = context1 + context1_to_0
225
+ return context0.transpose(1, 2).reshape(b, s0, h0), context1.transpose(-1, -2)
src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/dsr_sampling.py ADDED
@@ -0,0 +1,204 @@
1
+ # -*- encoding: utf-8 -*-
2
+ """
3
+ @File : dsr_sampling.py
4
+ @Time : 2021/10/09 00:46:04
5
+ @Author : Ming Ding
6
+ @Contact : [email protected]
7
+ """
8
+
9
+ # here put the import lib
10
+ import os
11
+ import sys
12
+ import math
13
+ import random
14
+ import torch
+
18
+ import torch.nn.functional as F
19
+ import numpy as np
20
+
21
+
22
+ def top_k_logits_(logits, top_k=0, filter_value=-float("Inf")):
23
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
24
+ logits[indices_to_remove] = filter_value
25
+ return logits
26
+
27
+
28
+ class IterativeEntfilterStrategy:
29
+ def __init__(self, invalid_slices=[], temperature=1.0, topk=6):
30
+ self.invalid_slices = invalid_slices
31
+ self.temperature = temperature
32
+ self.topk = topk
33
+ device = "cpu"
34
+ if torch.cuda.is_available():
35
+ device = "cuda"
36
+ self.cluster_labels = torch.tensor(
37
+ np.load("cluster_label2.npy"), device=device, dtype=torch.long
38
+ )
39
+
40
+ def forward(
41
+ self,
42
+ logits_,
43
+ tokens,
44
+ temperature=None,
45
+ entfilter=None,
46
+ filter_topk=5,
47
+ temperature2=None,
48
+ ):
49
+ # In the iterative strategy, logits are of shape [batch_size, seq_length, vocab_size]
50
+ if temperature is None:
51
+ temperature = self.temperature
52
+
53
+ logits = logits_.float() / temperature
54
+ for invalid_slice in self.invalid_slices:
55
+ logits[..., invalid_slice] = -float("Inf")
56
+ logits = logits.view(-1, logits.shape[-1])
57
+
58
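+ # two-stage sampling: first pick one of the 500 token clusters (cluster_label2.npy maps the 20000 image tokens to clusters), then sample a token within the chosen cluster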
+ rprobs = F.softmax(logits.float(), dim=-1)
59
+ c = self.cluster_labels.expand(*rprobs.shape)
60
+ cprobs = torch.zeros(logits.shape[0], 500, device=logits.device).scatter_add_(
61
+ 1, c, rprobs
62
+ )
63
+
64
+ best_scores, best_clusters = cprobs.topk(self.topk)
65
+ bz = logits.shape[0]
66
+ best_scores = best_scores / best_scores.sum(dim=-1, keepdim=True)
67
+ sampled_ids = torch.multinomial(best_scores, num_samples=1)
68
+ selected_clusters = torch.gather(best_clusters, dim=1, index=sampled_ids)
69
+ selected_mask = (
70
+ self.cluster_labels.unsqueeze(0).expand(bz, -1) != selected_clusters
71
+ ) # cluster_labels [1, 20000] \in [0,500)
72
+ logits[selected_mask] = -65504
73
+ # for i in range(bz):
74
+ # selected_cluster = best_clusters[i][torch.multinomial(best_scores[i] / best_scores[i].sum(), num_samples=1)]
75
+ # logits[i, self.cluster_labels != selected_cluster] = -65504
76
+
77
+ # logits = top_k_logits(logits, self.topk, self.top_p)
78
+ probs = F.softmax(
79
+ logits.float() / 0.6, dim=-1
80
+ ) # float is essential, due to a bug in PyTorch
81
+ pred = torch.multinomial(probs, num_samples=1).view(*logits_.shape[:2])
82
+
83
+ assert tokens.shape[1] == pred.shape[1] + 1
84
+ tokens = torch.cat((tokens[:, :1], pred), dim=1)
85
+ return tokens
86
+
87
+
88
+ def filling_sequence_dsr(
89
+ model,
90
+ seq0,
91
+ seq1,
92
+ warmup_steps=3,
93
+ block_hw=(4, 4),
94
+ strategy=IterativeEntfilterStrategy(topk=10),
95
+ ):
96
+ """
97
+ seq: [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1]
98
+ 4095 {layout[2]} final_token.
99
+ Attention:
100
+ The sampling temperatures change across steps; temporarily we hard-code them here.
101
+ The temperature in the strategy is not used.
102
+ """
103
+ assert hasattr(model, "layout")
104
+ layout = model.layout
105
+ assert (
106
+ len(seq0.shape) == 2 and len(seq1.shape) == 2 and seq0.shape[0] == seq1.shape[0]
107
+ )
108
+ assert len(layout) == 3
109
+ assert seq1.shape[1] == layout[-1] - layout[-2] + 1
110
+ assert (seq1 >= 0).all() and (seq0 >= 0).all()
111
+ device = seq0.device
112
+ # concat and pad sequences
113
+ batch_size = seq0.shape[0]
114
+ n_pad = layout[1] - seq0.shape[1]
115
+ assert n_pad > 0, "You should truncate long input before filling."
116
+ seq = torch.cat(
117
+ (
118
+ torch.tensor([0] * n_pad, device=device, dtype=seq0.dtype)
119
+ .unsqueeze(0)
120
+ .expand(batch_size, n_pad),
121
+ seq0,
122
+ seq1,
123
+ ),
124
+ dim=1,
125
+ ) # [b, layout[-1]+1]
126
+ assert seq.shape[1] == layout[-1] + 1
127
+
128
+ # build initial tokens, attention_mask, and position_ids
129
+ tokens = seq.clone()
130
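+ # attention mask over the low-res part: text positions do not attend to image positions, and the left-padding positions are masked out as keys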
+ attention_mask = torch.ones(layout[1], layout[1]).to(device)
131
+ attention_mask[: layout[0], layout[0] :] = 0
132
+ attention_mask[n_pad:, :n_pad] = 0
133
+ attention_mask = attention_mask.type_as(next(model.parameters())) # if fp16
134
+ position_ids = torch.cat(
135
+ (
136
+ torch.zeros(n_pad, dtype=torch.long),
137
+ torch.arange(0, layout[0] - n_pad),
138
+ torch.arange(513, 513 + layout[1] - layout[0]),
139
+ torch.arange(1024, 1024 + layout[2] - layout[1]),
140
+ )
141
+ ).to(device)
142
+ log_attention_weights = torch.zeros(layout[1], layout[1], device=device).type_as(
143
+ next(model.parameters())
144
+ )
145
+ log_attention_weights[layout[0] :, n_pad : layout[0]] = 0.0
146
+
147
+ # prepare for iteration
148
+ unfixed = tokens < 0 # just init an all-False tensor
149
+ unfixed[:, -layout[-1] + layout[-2] :] = True
150
+
151
+ ll, rr = block_hw
152
+ edge_len = int(math.sqrt(layout[-1] - layout[-2]) + 1e-4)
153
+ num_steps = warmup_steps + ll - 1 + rr
154
+ # iterative refining
155
+
156
+ # unfixed[..., -(layout[-1] - layout[-2]):].view(
157
+ # batch_size, edge_len//ll, ll, edge_len//rr, rr)[:, :, :, :, -1] = False
158
+
159
+ ret = []
160
+ ret.append(tokens[:, layout[-2] + 1 :].clone())
161
+ for step_cnt in range(1, num_steps + 1):
162
+ if step_cnt <= warmup_steps:
163
+ logits, *_dump = model(
164
+ tokens[:, :-1],
165
+ position_ids,
166
+ attention_mask,
167
+ log_attention_weights=log_attention_weights,
168
+ )
169
+ real_temp = 1.0
170
+ new_tokens = strategy.forward(logits, tokens, real_temp)
171
+ tokens[unfixed] = new_tokens[unfixed]
172
+ else:
173
+ logits, *_dump = model(
174
+ tokens[:, :-1],
175
+ position_ids,
176
+ attention_mask,
177
+ log_attention_weights=log_attention_weights,
178
+ )
179
+ real_temp = 1.0
180
+ new_tokens = strategy.forward(
181
+ logits,
182
+ tokens,
183
+ real_temp,
184
+ entfilter=1.3,
185
+ filter_topk=5,
186
+ temperature2=0.6,
187
+ )
188
+ # tokens[unfixed] = new_tokens[unfixed]
189
+ # fixed tokens (update unfixed)
190
+ unfixed2 = tokens > 10000000
191
+ for x in range(min(ll, step_cnt - warmup_steps)):
192
+ y = step_cnt - warmup_steps - x - 1
193
+ if y < rr:
194
+ unfixed[..., -(layout[-1] - layout[-2]) :].view(
195
+ batch_size, edge_len // ll, ll, edge_len // rr, rr
196
+ )[:, :, x, :, y] = False
197
+ unfixed2[..., -(layout[-1] - layout[-2]) :].view(
198
+ batch_size, edge_len // ll, ll, edge_len // rr, rr
199
+ )[:, :, x, :, y] = True
200
+ tokens[unfixed2] = new_tokens[unfixed2]
201
+
202
+ ret.append(tokens[:, layout[-2] + 1 :].clone())
203
+
204
+ return ret
src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/iterative_sr.py ADDED
@@ -0,0 +1,118 @@
1
+ # -*- encoding: utf-8 -*-
2
+ '''
3
+ @File : iterative_sr.py
4
+ @Time : 2022/03/02 15:57:45
5
+ @Author : Ming Ding
6
+ @Contact : [email protected]
7
+ '''
8
+
9
+ # here put the import lib
10
+ import os
11
+ import sys
12
+ import math
13
+ import random
14
+
20
+ from PIL import ImageEnhance, Image
21
+
22
+ import torch
23
+ import argparse
24
+ from torchvision import transforms
25
+
26
+ from SwissArmyTransformer.training.model_io import load_checkpoint
27
+ from SwissArmyTransformer import get_args
28
+ from .itersr_sampling import filling_sequence_itersr, IterativeEntfilterStrategy
29
+ from SwissArmyTransformer.generation.utils import timed_name, save_multiple_images, generate_continually
30
+
31
+ from .itersr_model import ItersrModel
32
+
33
+ from videogen_hub.depend.icetk import icetk as tokenizer
34
+
35
+ class IterativeSuperResolution:
36
+ def __init__(self, args, path, max_bz=4, shared_transformer=None):
37
+ args.load = path
38
+ args.kernel_size = 5
39
+ args.kernel_size2 = 5
40
+ args.new_sequence_length = 4624
41
+ args.layout = [16,3616]
42
+
43
+ model = ItersrModel(args, transformer=shared_transformer)
44
+ if args.fp16:
45
+ model = model.half()
46
+
47
+ load_checkpoint(model, args) # on cpu
48
+ model.eval()
49
+ self.model = model.cuda()
50
+
51
+ # save cpu weights
52
+ self.saved_weights = dict((k,v.cpu())
53
+ for k, v in model.named_parameters()
54
+ if 'transformer' in k
55
+ )
56
+
57
+ invalid_slices = [slice(tokenizer.num_image_tokens, None)]
58
+
59
+ self.strategy = IterativeEntfilterStrategy(invalid_slices,
60
+ temperature=args.temp_all_itersr, topk=args.topk_itersr)
61
+ self.max_bz = max_bz
62
+
63
+ def _restore_transformer_from_cpu(self, non_blocking=False):
64
+ for k, v in self.model.named_parameters():
65
+ if k in self.saved_weights:
66
+ v.copy_(self.saved_weights[k])
67
+
68
+ def __call__(self, text_tokens, image_tokens, enhance=False, input_mask=None):
69
+ if len(text_tokens.shape) == 1:
70
+ text_tokens.unsqueeze_(0)
71
+ text_tokens = text_tokens.clone()[..., :16]
72
+ if len(image_tokens.shape) == 1:
73
+ image_tokens.unsqueeze_(0)
74
+ if enhance:
75
+ new_image_tokens = []
76
+ for big_img in image_tokens:
77
+ decoded = tokenizer.decode(image_ids=big_img).squeeze(0)
78
+ ndarr = decoded.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
79
+ image_pil_raw = ImageEnhance.Sharpness(Image.fromarray(ndarr))
80
+ big_img2 = tokenizer.encode(image_pil=image_pil_raw.enhance(1.5), image_size=480).view(-1)
81
+ new_image_tokens.append(big_img2)
82
+ image_tokens = torch.stack(new_image_tokens)
83
+ print('Converting Itersr model...')
84
+ self._restore_transformer_from_cpu()
85
+ model = self.model
86
+ print('iterative super-resolution...')
87
+ output_list = []
88
+ for tim in range(max(text_tokens.shape[0] // self.max_bz, 1)):
89
+ big_img = image_tokens[tim*self.max_bz:(tim+1)*self.max_bz]
90
+ text_seq = text_tokens[tim*self.max_bz:(tim+1)*self.max_bz]
91
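+ # 6x6 re-masking schedule tiled over the 60x60 token grid: in round mask_ratio, tokens whose schedule value >= mask_ratio are reset to <start_of_image> and re-sampled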
+ mask_raw = torch.tensor(
92
+ [
93
+ -1, 0, 1, 2, 3, 4,
94
+ 0, -1, 2, -1, -2, 5,
95
+ 1, -2, 3, 4, 5, 6,
96
+ 2, 3, 4, 5, -1, 1,
97
+ 3, -1, -2, 0, -1, 2,
98
+ 4, 5, 6, 1, 3, -2
99
+ ]
100
+ ).view(1, 6, 1, 6).expand(10, 6, 10, 6).reshape(-1).contiguous()
101
+
102
+ topks = [60, 40, 40, 40, 20, 20, 10]
103
+
104
+ for mask_ratio in range(1, 7):
105
+ self.strategy.topk = topks[mask_ratio]
106
+ mask = (mask_raw.to(big_img.device) >= mask_ratio)
107
+ if input_mask is not None:
108
+ mask = mask & input_mask
109
+ big_img.masked_fill_(mask, tokenizer['<start_of_image>'])
110
+ seq1 = big_img
111
+ output1 = filling_sequence_itersr(model, text_seq, seq1,
112
+ warmup_steps=1, block_hw=(1, 0),
113
+ strategy=self.strategy
114
+ )
115
+ big_img = output1
116
+ print(f'Finished iteration {mask_ratio}.')
117
+ output_list.append(output1.clone())
118
+ return torch.cat(output_list, dim=0)
src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/itersr_model.py ADDED
@@ -0,0 +1,232 @@
1
+ # -*- encoding: utf-8 -*-
2
+ '''
3
+ @File : itersr_model.py
4
+ @Time : 2021/10/02 01:36:32
5
+ @Author : Ming Ding
6
+ @Contact : [email protected]
7
+ '''
8
+
9
+ # here put the import lib
10
+ import os
11
+ import sys
12
+ import math
13
+ import random
14
+ import torch
15
+ import torch.nn.functional as F
16
+
17
+
18
+ from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin
19
+
20
+ from SwissArmyTransformer.mpu.utils import sqrt
21
+ from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker
22
+ from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear
23
+ from SwissArmyTransformer.model.transformer import unscaled_init_method, split_tensor_along_last_dim
24
+
25
+ class PositionEmbeddingMixin(BaseMixin):
26
+ def __init__(self, additional_sequence_length, hidden_size,
27
+ init_method_std=0.02, reinit_slice=slice(512, 512+400)
28
+ ):
29
+ super(PositionEmbeddingMixin, self).__init__()
30
+ self.reinit_slice = reinit_slice
31
+ self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size)
32
+ torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
33
+
34
+ def reinit(self, parent_model=None):
35
+ old_weights = self.transformer.position_embeddings.weight.data[self.reinit_slice]
36
+ old_len, hidden_size = old_weights.shape
37
+ assert hidden_size == self.position_embeddings.weight.shape[-1]
38
+ old_edge, new_edge = sqrt(old_len), sqrt(self.position_embeddings.weight.shape[-2])
39
+ assert new_edge % old_edge == 0
40
+ self.position_embeddings.weight.data.view(new_edge // old_edge, old_edge, new_edge // old_edge, old_edge, hidden_size).copy_(old_weights.view(1, old_edge, 1, old_edge, hidden_size))
41
+
42
+ class ItersrModel(BaseModel):
43
+ def __init__(self, args, transformer=None):
44
+ super().__init__(args, transformer=transformer)
45
+ self.original_sequence_length = args.max_sequence_length
46
+ additional_seqlen = args.new_sequence_length - args.max_sequence_length
47
+ self.add_mixin('extra_position_embedding', PositionEmbeddingMixin(
48
+ additional_seqlen, args.hidden_size
49
+ ))
50
+ # self.add_mixin('attention_plus', AttentionMixin(
51
+ # num_layers=args.num_layers,
52
+ # hidden_size=args.hidden_size
53
+ # ))
54
+ self.layout = args.layout
55
+ # [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1] 4095 {layout[2]}
56
+ self.kernel_size = args.kernel_size
57
+ self.kernel_size2 = args.kernel_size2
58
+ self.log_attention_weights = None
59
+
60
+ def position_embedding_forward(self, position_ids, **kw_args):
61
+ position = position_ids[..., :self.layout[0]]
62
+ position_plus = position_ids[..., self.layout[0]:] - self.original_sequence_length
63
+ position_embeddings = torch.cat(
64
+ (
65
+ self.transformer.position_embeddings(position),
66
+ self.get_mixin('extra_position_embedding').position_embeddings(position_plus)
67
+ ),
68
+ dim=-2
69
+ )
70
+ return position_embeddings
71
+
72
+ def attention_forward(self, hidden_states, mask,
73
+ layer_id=None, log_attention_weights=None, **kw_args):
74
+ attn_module = self.transformer.layers[layer_id].attention
75
+ # base model qkv
76
+ mixed_raw_layer = attn_module.query_key_value(hidden_states)
77
+ q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer[:, :self.layout[0]], 3)
78
+ # cuda2d model qkv
79
+ q1, k1, v1 = split_tensor_along_last_dim(mixed_raw_layer[:, self.layout[0]:], 3)
80
+
81
+ dropout_fn = attn_module.attention_dropout if self.training else None
82
+
83
+ # cuda2d attention
84
+ context_layer = sparse_attention_2d_text(
85
+ q0, k0, v0,
86
+ q1, k1, v1,
87
+ mask,
88
+ n_head=attn_module.num_attention_heads_per_partition,
89
+ text_len=self.layout[0],
90
+ kernel_size=self.kernel_size,
91
+ attention_dropout=dropout_fn,
92
+ log_attention_weights=log_attention_weights,
93
+ )
94
+
95
+ output = attn_module.dense(context_layer)
96
+
97
+ return output
98
+
99
+ def final_forward(self, logits, **kwargs):
100
+ logits_parallel = logits
101
+ logits_parallel = torch.nn.functional.linear(logits_parallel, self.transformer.word_embeddings.weight[:20000]).float()
102
+ # logits_parallel = torch.nn.functional.linear(logits_parallel, self.transformer.word_embeddings.weight[:20000])
103
+ return logits_parallel
104
+
105
+ # def disable_untrainable_params(self):
106
+ # self.transformer.requires_grad_(False)
107
+
108
+ @classmethod
109
+ def add_model_specific_args(cls, parser):
110
+ group = parser.add_argument_group('Cuda2dModel', 'cuda2d model configurations')
111
+ group.add_argument("--kernel-size", type=int, default=5)
112
+ group.add_argument("--kernel-size2", type=int, default=5)
113
+ group.add_argument("--layout", type=str, default='16,3616')
114
+ group.add_argument("--new-sequence-length", type=int, default=4096)
115
+ return parser
116
+
117
+ def sparse_attention_2d_text(q0, k0, v0, q1, k1, v1, attention_mask, n_head, text_len, kernel_size=9, attention_dropout=None, log_attention_weights = None, **kwargs):
118
+ '''
119
+ q0, k0, v0: [batch_size, 16, hidden_size]
120
+ q1, k1, v1: [batch_size, 3600, hidden_size]
121
+ n_head: int
122
+ attention_mask: [batch_size, 16]
123
+ '''
124
+ from SwissArmyTransformer.ops.local_attention_function import f_similar, f_weighting
125
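+ # level 0 (text tokens) uses full attention; level 1 (image tokens) uses local window attention plus dense cross attention onto the text keys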
+ b, s0, h0 = q0.shape
126
+ b, s1, h1 = q1.shape
127
+ h, l1 = h0 // n_head, sqrt(s1)
128
+ assert attention_mask.shape[-1] == s0, f"Mask Shape: {attention_mask.shape}"
129
+
130
+ q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
131
+ v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
132
+ k0T = k0.reshape(b, s0, n_head, h).permute(0, 2, 3, 1)
133
+
134
+ # standard attention for level 0
135
+ attention_scores = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T)
136
+
137
+ attention_scores = torch.mul(attention_scores, attention_mask) - \
138
+ 10000.0 * (1.0 - attention_mask)
139
+
140
+ attention_probs0 = F.softmax(attention_scores, dim=-1)
141
+
142
+ # local attention for level 1
143
+ q1 = (q1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1) / math.sqrt(h1//n_head)).contiguous().view(b*n_head, h1//n_head, l1, l1)
144
+ k1 = k1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1)
145
+ v1 = v1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1)
146
+ scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, False)
147
+
148
+ # cross attention
149
+ scores_1_to_0 = torch.matmul(q1.view(b, n_head, h, s1).transpose(-1, -2), k0T)
150
+ if log_attention_weights is not None:
151
+ scores_1_to_0 += log_attention_weights
152
+ scores_1_to_0 = torch.mul(scores_1_to_0, attention_mask) - \
153
+ 10000.0 * (1.0 - attention_mask)
154
+ scores_1 = torch.cat(
155
+ (
156
+ scores_1_to_0.view(b*n_head, s1, s0),
157
+ scores_1_to_1.view(b*n_head, -1, scores_1_to_1.shape[3])
158
+ ),
159
+ dim=-1)
160
+ attention_probs1 = F.softmax(scores_1, dim=-1)
161
+
162
+ if attention_dropout is not None:
163
+ with get_cuda_rng_tracker().fork():
164
+ attention_probs1 = attention_dropout(attention_probs1)
165
+
166
+ # weighting for level 0
167
+ context0 = torch.matmul(attention_probs0, v0) # [b, n_head, s0, h]
168
+ # weighting for level 1
169
+ probs_1_to_1 = attention_probs1[:, :, -scores_1_to_1.shape[3]:].view_as(scores_1_to_1)
170
+ context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size*2-1, kernel_size, False)
171
+
172
+ context1 = context1_to_1.view(b, n_head, h, l1**2)
173
+ # weighting for cross attention
174
+ probs_1_to_0 = attention_probs1[:, :, :scores_1_to_0.shape[3]].view(b, n_head, -1, scores_1_to_0.shape[3])
175
+
176
+ context1_to_0 = torch.matmul(probs_1_to_0, v0)
177
+ context1 = context1.transpose(-1, -2) + context1_to_0
178
+
179
+ output = torch.cat((context0, context1), dim=2).transpose(1, 2).reshape(b, s0+s1, h0)
180
+
181
+ return output
182
+
183
+ def sparse_attention_2d_notext(q0, k0, v0, q1, k1, v1, attention_mask, n_head, text_len, kernel_size=9, attention_dropout=None, log_attention_weights = None, **kwargs):
184
+ '''
185
+ q0, k0, v0: [batch_size, 16, hidden_size]
186
+ q1, k1, v1: [batch_size, 3600, hidden_size]
187
+ n_head: int
188
+ attention_mask: [batch_size, 16]
189
+ '''
190
+ from SwissArmyTransformer.mpu.local_attention_function import f_similar, f_weighting
191
+ b, s0, h0 = q0.shape
192
+ b, s1, h1 = q1.shape
193
+ h, l1 = h0 // n_head, sqrt(s1)
194
+ assert len(attention_mask.shape) == 4 and attention_mask.shape[-1] == s0, f"Mask Shape: {attention_mask.shape}"
195
+
196
+ q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
197
+ v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
198
+ k0T = k0.reshape(b, s0, n_head, h).permute(0, 2, 3, 1)
199
+
200
+ # standard attention for level 0
201
+ attention_scores = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T)
202
+
203
+ attention_scores = torch.mul(attention_scores, attention_mask) - \
204
+ 10000.0 * (1.0 - attention_mask)
205
+
206
+ attention_probs0 = F.softmax(attention_scores, dim=-1)
207
+
208
+ # local attention for level 1
209
+ q1 = (q1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1) / math.sqrt(h1//n_head)).contiguous().view(b*n_head, h1//n_head, l1, l1)
210
+ k1 = k1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1)
211
+ v1 = v1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1)
212
+ scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, False)
213
+
214
+ attention_probs1 = F.softmax(scores_1_to_1, dim=-1)
215
+
216
+ if attention_dropout is not None:
217
+ with get_cuda_rng_tracker().fork():
218
+ attention_probs1 = attention_dropout(attention_probs1)
219
+
220
+ # weighting for level 0
221
+ context0 = torch.matmul(attention_probs0, v0) # [b, n_head, s0, h]
222
+ # weighting for level 1
223
+ probs_1_to_1 = attention_probs1
224
+ context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size*2-1, kernel_size, False)
225
+
226
+ context1 = context1_to_1.view(b, n_head, h, l1**2)
227
+ # weighting for cross attention
228
+ context1 = context1.transpose(-1, -2)
229
+
230
+ output = torch.cat((context0, context1), dim=2).transpose(1, 2).reshape(b, s0+s1, h0)
231
+
232
+ return output
src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/itersr_sampling.py ADDED
@@ -0,0 +1,168 @@
1
+ # -*- encoding: utf-8 -*-
2
+ '''
3
+ @File : itersr_sampling.py
4
+ @Time : 2022/03/03 14:24:28
5
+ @Author : Ming Ding
6
+ @Contact : [email protected]
7
+ '''
8
+
9
+ # here put the import lib
10
+ import os
11
+ import sys
12
+ import math
13
+ import random
14
+ import numpy as np
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from videogen_hub.depend.icetk import icetk as tokenizer
19
+
20
+ def top_k_logits_(logits, top_k=0, filter_value=-float('Inf')):
21
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
22
+ logits[indices_to_remove] = filter_value
23
+ return logits
24
+
25
+ # class IterativeEntfilterStrategy:
26
+ # def __init__(self, invalid_slices=[], temperature=1., topk=10):
27
+ # self.invalid_slices = invalid_slices
28
+ # self.temperature = temperature
29
+ # self.topk = topk
30
+ # self.cluster_labels = torch.tensor(np.load('cluster_label.npy'), device='cuda', dtype=torch.long)
31
+
32
+
33
+ # def forward(self, logits_, tokens, temperature=None, entfilter=None, filter_topk=5, temperature2=None):
34
+ # # In interative strategy, logits are of shape [batch_size, seq_length, hidden_size]
35
+ # if temperature is None:
36
+ # temperature = self.temperature
37
+
38
+ # logits = logits_.float() / temperature
39
+ # for invalid_slice in self.invalid_slices:
40
+ # logits[..., invalid_slice] = -float('Inf')
41
+ # logits = logits.view(-1, logits.shape[-1])
42
+
43
+ # rprobs = F.softmax(logits.float(), dim=-1)
44
+ # c = self.cluster_labels.expand(*rprobs.shape)
45
+ # cprobs = torch.zeros(logits.shape[0], 500, device=logits.device).scatter_add_(1, c, rprobs)
46
+
47
+ # best_scores, best_clusters = cprobs.topk(self.topk)
48
+ # bz = logits.shape[0]
49
+ # best_scores = best_scores / best_scores.sum(dim=-1, keepdim=True)
50
+ # sampled_ids = torch.multinomial(best_scores, num_samples=1)
51
+ # selected_clusters = torch.gather(best_clusters, dim=1, index=sampled_ids)
52
+ # selected_mask = (self.cluster_labels.unsqueeze(0).expand(bz, -1) != selected_clusters) # cluster_labels [1, 20000] \in [0,500)
53
+ # logits[selected_mask] = -65504
54
+ # # for i in range(bz):
55
+ # # selected_cluster = best_clusters[i][torch.multinomial(best_scores[i] / best_scores[i].sum(), num_samples=1)]
56
+ # # logits[i, self.cluster_labels != selected_cluster] = -65504
57
+
58
+ # # logits = top_k_logits(logits, self.topk, self.top_p)
59
+ # probs = F.softmax(logits.float(), dim=-1) # float is essetial, due to a bug in Pytorch
60
+ # pred = torch.multinomial(probs, num_samples=1).view(*logits_.shape[:2])
61
+
62
+ # assert tokens.shape[1] == pred.shape[1]
63
+ # tokens = pred
64
+ # return tokens
65
+
66
+ class IterativeEntfilterStrategy:
67
+ def __init__(self, invalid_slices=[], temperature=1., topk=10):
68
+ self.invalid_slices = invalid_slices
69
+ self.temperature = temperature
70
+ self.topk = topk
71
+
72
+ def forward(self, logits, tokens, temperature=None, entfilter=None, filter_topk=5, temperature2=None):
73
+ # In the iterative strategy, logits are of shape [batch_size, seq_length, vocab_size]
74
+ if temperature is None:
75
+ temperature = self.temperature
76
+ # check entropy filter
77
+ # if entfilter is not None:
78
+ # assert temperature2 is not None
79
+ # topraw = (torch.topk(logits, filter_topk, dim=-1)[0]).softmax(dim=-1)
80
+ # ent = -(topraw * topraw.log()).sum(dim=-1) # [batch_size, seq_length]
81
+ # temperature = torch.tensor([[[temperature - temperature2]]], device=logits.device).expand(*logits.shape[:2], 1) * (ent > entfilter).unsqueeze(-1) + temperature2
82
+
83
+ logits = logits.float() / temperature
84
+ for invalid_slice in self.invalid_slices:
85
+ logits[..., invalid_slice] = -float('Inf')
86
+
87
+ # debiased topk
88
+ # probs = F.softmax(logits, dim=-1)
89
+ # tk_value, tk_idx = torch.topk(probs, self.topk, dim=-1)
90
+ # pred = torch.multinomial(probs.view(-1, logits.shape[-1]), num_samples=1).view(*logits.shape[:2], 1)
91
+ # edge_idx = tk_idx[:, :, -1:]
92
+ # edge_value = tk_value[:, :, -1:]
93
+ # edge_mask = probs.gather(dim=-1, index=pred) < edge_value
94
+ # pred[edge_mask] = edge_idx[edge_mask] # replace outliers as the "filter_topk"-th token
95
+ # pred.squeeze_(-1) # [batch_size, seq_length]
96
+
97
+ top_k_logits_(logits, self.topk)
98
+ probs = F.softmax(logits, dim=-1)
99
+ pred = torch.multinomial(probs.view(-1, logits.shape[-1]), num_samples=1).view(*logits.shape[:2], 1)
100
+ pred.squeeze_(-1)
101
+
102
+ assert tokens.shape[1] == pred.shape[1]
103
+ tokens = pred
104
+ return tokens
105
+
106
+ def filling_sequence_itersr(
107
+ model,
108
+ seq0,
109
+ seq1,
110
+ warmup_steps=3,
111
+ block_hw=(4, 4),
112
+ strategy=IterativeEntfilterStrategy(topk=10),
113
+ ):
114
+ '''
115
+ seq: [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1]
116
+ 4095 {layout[2]} final_token.
117
+ Attention:
118
+ The sampling temperatures change across steps; temporarily we hard-code them here.
119
+ The temperature in the strategy is not used.
120
+ '''
121
+ assert hasattr(model, 'layout')
122
+ layout = model.layout
123
+
124
+ device = seq0.device
125
+ # concat and pad sequences
126
+ batch_size = seq0.shape[0]
127
+ n_pad = layout[0] - seq0.shape[1]
128
+ assert n_pad >= 0, "You should truncate long input before filling."
129
+ seq = torch.cat((
130
+ torch.tensor([0]*n_pad, device=device, dtype=seq0.dtype)
131
+ .unsqueeze(0).expand(batch_size, n_pad),
132
+ seq0, seq1), dim=1) # [b, layout[-1]+1]
133
+ assert seq.shape[1] == layout[-1]
134
+
135
+ # build initial tokens, attention_mask, and position_ids
136
+ tokens = seq.clone()
137
+ attention_mask = torch.ones(layout[0]).to(device)
138
+ attention_mask[:n_pad] = 0
139
+ attention_mask = attention_mask.unsqueeze(0).type_as(next(model.parameters())) # if fp16
140
+ position_ids = torch.cat((
141
+ torch.zeros(n_pad, dtype=torch.long),
142
+ torch.arange(0, layout[0] - n_pad),
143
+ torch.arange(1024, 1024+layout[1]-layout[0]))).to(device)
144
+ log_attention_weights = torch.zeros(layout[0], device=device).type_as(next(model.parameters()))
145
+ log_attention_weights[n_pad:layout[0]] = 0.
146
+ log_attention_weights = log_attention_weights.unsqueeze(0)
147
+
148
+ # prepare for iteration
149
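+ # only positions currently holding the <start_of_image> placeholder are re-sampled; all other tokens stay fixed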
+ unfixed = (tokens == tokenizer['<start_of_image>'])
150
+ ll, rr = block_hw
151
+ edge_len = int(math.sqrt(layout[-1] - layout[-2]) + 1e-4)
152
+ num_steps = 1
153
+ # iterative refining
154
+
155
+ # unfixed[..., -(layout[-1] - layout[-2]):].view(
156
+ # batch_size, edge_len//ll, ll, edge_len//rr, rr)[:, :, :, :, -1] = False
157
+
158
+
159
+ ret = []
160
+ # ret.append(tokens[:, layout[-2]:-1].clone())
161
+ for step_cnt in range(1, num_steps+1):
162
+ logits, *_dump = model(tokens, position_ids, attention_mask, log_attention_weights=log_attention_weights)
163
+ real_temp = 1.
164
+ new_tokens = strategy.forward(logits, tokens, real_temp)
165
+ tokens[unfixed] = new_tokens[unfixed]
166
+
167
+ ret.append(tokens[:, layout[-2]:].clone())
168
+ return torch.cat(ret, dim=0)
src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/sr_group.py ADDED
@@ -0,0 +1,49 @@
1
+ # -*- encoding: utf-8 -*-
2
+ '''
3
+ @File : sr_group.py
4
+ @Time : 2022/04/02 01:17:21
5
+ @Author : Ming Ding
6
+ @Contact : [email protected]
7
+ '''
8
+
9
+ # here put the import lib
10
+ import os
11
+ import sys
12
+ import math
13
+ import random
14
+
15
+ import numpy as np
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from SwissArmyTransformer.resources import auto_create
19
+ from .direct_sr import DirectSuperResolution
20
+ from .iterative_sr import IterativeSuperResolution
21
+
22
+ class SRGroup:
23
+ def __init__(self, args, home_path=None,):
24
+ dsr_path = auto_create('cogview2-dsr', path=home_path)
25
+ itersr_path = auto_create('cogview2-itersr', path=home_path)
26
+ dsr = DirectSuperResolution(args, dsr_path)
27
+ itersr = IterativeSuperResolution(args, itersr_path, shared_transformer=dsr.model.transformer)
28
+ self.dsr = dsr
29
+ self.itersr = itersr
30
+
31
+ def sr_base(self, img_tokens, txt_tokens):
32
+ assert img_tokens.shape[-1] == 400 and len(img_tokens.shape) == 2
33
+ batch_size = img_tokens.shape[0]
34
+ txt_len = txt_tokens.shape[-1]
35
+ if len(txt_tokens.shape) == 1:
36
+ txt_tokens = txt_tokens.unsqueeze(0).expand(batch_size, txt_len)
37
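+ # direct SR expands each 20x20 (400-token) frame to 60x60 (3600 tokens); iterative SR then refines those last 3600 tokens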
+ sred_tokens = self.dsr(txt_tokens, img_tokens)
38
+ iter_tokens = self.itersr(txt_tokens, sred_tokens[:, -3600:].clone())
39
+ return iter_tokens[-batch_size:]
40
+
41
+ # def sr_patch(self, img_tokens, txt_tokens):
42
+ # assert img_tokens.shape[-1] == 3600 and len(img_tokens.shape) == 2
43
+ # batch_size = img_tokens.shape[0] * 9
44
+ # txt_len = txt_tokens.shape[-1]
45
+ # if len(txt_tokens.shape) == 1:
46
+ # txt_tokens = txt_tokens.unsqueeze(0).expand(batch_size, txt_len)
47
+ # img_tokens = img_tokens.view(img_tokens.shape[0], 3, 20, 3, 20).permute(0, 1, 3, 2, 4).reshape(batch_size, 400)
48
+ # iter_tokens = self.sr_base(img_tokens, txt_tokens)
49
+ # return iter_tokens
src/videogen_hub/pipelines/consisti2v/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 TIGER Lab
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
src/videogen_hub/pipelines/consisti2v/__init__.py ADDED
File without changes
src/videogen_hub/pipelines/consisti2v/configs/__init__.py ADDED
File without changes
src/videogen_hub/pipelines/consisti2v/configs/inference/__init__.py ADDED
File without changes
src/videogen_hub/pipelines/consisti2v/configs/inference/inference.yaml ADDED
@@ -0,0 +1,48 @@
1
+ output_dir: "samples/inference"
2
+ output_name: "i2v"
3
+
4
+ pretrained_model_path: "TIGER-Lab/ConsistI2V"
5
+ unet_path: null
6
+ unet_ckpt_prefix: "module."
7
+ pipeline_pretrained_path: null
8
+
9
+ sampling_kwargs:
10
+ height: 256
11
+ width: 256
12
+ n_frames: 16
13
+ steps: 50
14
+ ddim_eta: 0.0
15
+ guidance_scale_txt: 7.5
16
+ guidance_scale_img: 1.0
17
+ guidance_rescale: 0.0
18
+ num_videos_per_prompt: 1
19
+ frame_stride: 3
20
+
21
+ unet_additional_kwargs:
22
+ variant: null
23
+ n_temp_heads: 8
24
+ augment_temporal_attention: true
25
+ temp_pos_embedding: "rotary" # "rotary" or "sinusoidal"
26
+ first_frame_condition_mode: "concat"
27
+ use_frame_stride_condition: true
28
+ noise_sampling_method: "pyoco_mixed" # "vanilla" or "pyoco_mixed" or "pyoco_progressive"
29
+ noise_alpha: 1.0
30
+
31
+ noise_scheduler_kwargs:
32
+ beta_start: 0.00085
33
+ beta_end: 0.012
34
+ beta_schedule: "linear"
35
+ steps_offset: 1
36
+ clip_sample: false
37
+ rescale_betas_zero_snr: false # true if using zero terminal snr
38
+ timestep_spacing: "leading" # "trailing" if using zero terminal snr
39
+ prediction_type: "epsilon" # "v_prediction" if using zero terminal snr
40
+
41
+ frameinit_kwargs:
42
+ enable: true
43
+ camera_motion: null
44
+ noise_level: 850
45
+ filter_params:
46
+ method: 'gaussian'
47
+ d_s: 0.25
48
+ d_t: 0.25
src/videogen_hub/pipelines/consisti2v/configs/inference/inference_autoregress.yaml ADDED
@@ -0,0 +1,49 @@
1
+ output_dir: "samples/inference"
2
+ output_name: "long_video"
3
+
4
+ pretrained_model_path: "TIGER-Lab/ConsistI2V"
5
+ unet_path: null
6
+ unet_ckpt_prefix: "module."
7
+ pipeline_pretrained_path: null
8
+
9
+ sampling_kwargs:
10
+ height: 256
11
+ width: 256
12
+ n_frames: 16
13
+ steps: 50
14
+ ddim_eta: 0.0
15
+ guidance_scale_txt: 7.5
16
+ guidance_scale_img: 1.0
17
+ guidance_rescale: 0.0
18
+ num_videos_per_prompt: 1
19
+ frame_stride: 3
20
+ autoregress_steps: 3
21
+
22
+ unet_additional_kwargs:
23
+ variant: null
24
+ n_temp_heads: 8
25
+ augment_temporal_attention: true
26
+ temp_pos_embedding: "rotary" # "rotary" or "sinusoidal"
27
+ first_frame_condition_mode: "concat"
28
+ use_frame_stride_condition: true
29
+ noise_sampling_method: "pyoco_mixed" # "vanilla" or "pyoco_mixed" or "pyoco_progressive"
30
+ noise_alpha: 1.0
31
+
32
+ noise_scheduler_kwargs:
33
+ beta_start: 0.00085
34
+ beta_end: 0.012
35
+ beta_schedule: "linear"
36
+ steps_offset: 1
37
+ clip_sample: false
38
+ rescale_betas_zero_snr: false # true if using zero terminal snr
39
+ timestep_spacing: "leading" # "trailing" if using zero terminal snr
40
+ prediction_type: "epsilon" # "v_prediction" if using zero terminal snr
41
+
42
+
43
+ frameinit_kwargs:
44
+ enable: true
45
+ noise_level: 850
46
+ filter_params:
47
+ method: 'gaussian'
48
+ d_s: 0.25
49
+ d_t: 0.25
src/videogen_hub/pipelines/consisti2v/configs/prompts/__init__.py ADDED
File without changes
src/videogen_hub/pipelines/consisti2v/configs/prompts/default.yaml ADDED
@@ -0,0 +1,16 @@
1
+ seeds: random
2
+
3
+ prompts:
4
+ - "timelapse at the snow land with aurora in the sky."
5
+ - "fireworks."
6
+ - "clown fish swimming through the coral reef."
7
+ - "melting ice cream dripping down the cone."
8
+
9
+ n_prompts:
10
+ - ""
11
+
12
+ path_to_first_frames:
13
+ - "assets/example/example_01.png"
14
+ - "assets/example/example_02.png"
15
+ - "assets/example/example_03.png"
16
+ - "assets/example/example_04.png"
src/videogen_hub/pipelines/consisti2v/configs/training/__init__.py ADDED
File without changes
src/videogen_hub/pipelines/consisti2v/configs/training/training.yaml ADDED
@@ -0,0 +1,92 @@
+ output_dir: "checkpoints"
+ pretrained_model_path: "stabilityai/stable-diffusion-2-1-base"
+
+ noise_scheduler_kwargs:
+   num_train_timesteps: 1000
+   beta_start: 0.00085
+   beta_end: 0.012
+   beta_schedule: "linear"
+   steps_offset: 1
+   clip_sample: false
+   rescale_betas_zero_snr: false # true if using zero terminal snr
+   timestep_spacing: "leading" # "trailing" if using zero terminal snr
+   prediction_type: "epsilon" # "v_prediction" if using zero terminal snr
+
+ train_data:
+   dataset: "joint"
+   pexels_config:
+     enable: false
+     json_path: null
+     caption_json_path: null
+     video_folder: null
+   webvid_config:
+     enable: true
+     json_path: "/path/to/webvid/annotation"
+     video_folder: "/path/to/webvid/data"
+   sample_size: 256
+   sample_duration: null
+   sample_fps: null
+   sample_stride: [1, 5]
+   sample_n_frames: 16
+
+ validation_data:
+   prompts:
+     - "timelapse at the snow land with aurora in the sky."
+     - "fireworks."
+     - "clown fish swimming through the coral reef."
+     - "melting ice cream dripping down the cone."
+
+   path_to_first_frames:
+     - "assets/example/example_01.jpg"
+     - "assets/example/example_02.jpg"
+     - "assets/example/example_03.jpg"
+     - "assets/example/example_04.jpg"
+
+   num_inference_steps: 50
+   ddim_eta: 0.0
+   guidance_scale_txt: 7.5
+   guidance_scale_img: 1.0
+   guidance_rescale: 0.0
+   frame_stride: 3
+
+ trainable_modules:
+   - "all"
+   # - "conv3ds."
+   # - "tempo_attns."
+
+ resume_from_checkpoint: null
+
+ unet_additional_kwargs:
+   variant: null
+   n_temp_heads: 8
+   augment_temporal_attention: true
+   temp_pos_embedding: "rotary" # "rotary" or "sinusoidal"
+   first_frame_condition_mode: "concat"
+   use_frame_stride_condition: true
+   noise_sampling_method: "pyoco_mixed" # "vanilla" or "pyoco_mixed" or "pyoco_progressive"
+   noise_alpha: 1.0
+
+ cfg_random_null_text_ratio: 0.1
+ cfg_random_null_img_ratio: 0.1
+
+ use_ema: false
+ ema_decay: 0.9999
+
+ learning_rate: 5.e-5
+ train_batch_size: 3
+ gradient_accumulation_steps: 1
+ max_grad_norm: 0.5
+
+ max_train_epoch: -1
+ max_train_steps: 200000
+ checkpointing_epochs: -1
+ checkpointing_steps: 2000
+ validation_steps: 1000
+
+ seed: 42
+ mixed_precision: "bf16"
+ num_workers: 32
+ enable_xformers_memory_efficient_attention: true
+
+ is_image: false
+ is_debug: false
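
The throughput fields above combine in the usual way. A quick back-of-the-envelope sketch; the GPU count is an assumption supplied by the launcher, not by this config:

train_batch_size = 3
gradient_accumulation_steps = 1
num_gpus = 8  # assumption: set by the accelerate/torchrun launcher, not by this file

effective_batch_size = train_batch_size * gradient_accumulation_steps * num_gpus
max_train_steps = 200_000
samples_seen = effective_batch_size * max_train_steps
print(effective_batch_size)  # 24 video clips per optimizer step
print(samples_seen)          # 4,800,000 clips sampled over the full run
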
src/videogen_hub/pipelines/consisti2v/consisti2v/__init__.py ADDED
File without changes
src/videogen_hub/pipelines/consisti2v/consisti2v/data/__init__.py ADDED
File without changes
src/videogen_hub/pipelines/consisti2v/consisti2v/data/dataset.py ADDED
@@ -0,0 +1,315 @@
1
+ import os, io, csv, math, random
2
+ import json
3
+ import numpy as np
4
+ from einops import rearrange
5
+ from decord import VideoReader
6
+
7
+ import torch
8
+ import torchvision.transforms as transforms
9
+ from torch.utils.data.dataset import Dataset
10
+
11
+ from diffusers.utils import logging
12
+
13
+ logger = logging.get_logger(__name__)
14
+
15
+ class WebVid10M(Dataset):
16
+ def __init__(
17
+ self,
18
+ json_path, video_folder=None,
19
+ sample_size=256, sample_stride=4, sample_n_frames=16,
20
+ is_image=False,
21
+ **kwargs,
22
+ ):
23
+ logger.info(f"loading annotations from {json_path} ...")
24
+ with open(json_path, 'rb') as json_file:
25
+ json_list = list(json_file)
26
+ self.dataset = [json.loads(json_str) for json_str in json_list]
27
+ self.length = len(self.dataset)
28
+ logger.info(f"data scale: {self.length}")
29
+
30
+ self.video_folder = video_folder
31
+ self.sample_stride = sample_stride if isinstance(sample_stride, int) else tuple(sample_stride)
32
+ self.sample_n_frames = sample_n_frames
33
+ self.is_image = is_image
34
+
35
+ sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
36
+ self.pixel_transforms = transforms.Compose([
37
+ transforms.RandomHorizontalFlip(),
38
+ transforms.Resize(sample_size[0], antialias=None),
39
+ transforms.CenterCrop(sample_size),
40
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
41
+ ])
42
+
43
+ def get_batch(self, idx):
44
+ video_dict = self.dataset[idx]
45
+ video_relative_path, name = video_dict['file'], video_dict['text']
46
+
47
+ if self.video_folder is not None:
48
+ if video_relative_path[0] == '/':
49
+ video_dir = os.path.join(self.video_folder, os.path.basename(video_relative_path))
50
+ else:
51
+ video_dir = os.path.join(self.video_folder, video_relative_path)
52
+ else:
53
+ video_dir = video_relative_path
54
+ video_reader = VideoReader(video_dir)
55
+ video_length = len(video_reader)
56
+
57
+ if not self.is_image:
58
+ if isinstance(self.sample_stride, int):
59
+ stride = self.sample_stride
60
+ elif isinstance(self.sample_stride, tuple):
61
+ stride = random.randint(self.sample_stride[0], self.sample_stride[1])
62
+ clip_length = min(video_length, (self.sample_n_frames - 1) * stride + 1)
63
+ start_idx = random.randint(0, video_length - clip_length)
64
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
65
+ else:
66
+ frame_difference = random.randint(2, self.sample_n_frames)
67
+ clip_length = min(video_length, (frame_difference - 1) * self.sample_stride + 1)
68
+ start_idx = random.randint(0, video_length - clip_length)
69
+ batch_index = [start_idx, start_idx + clip_length - 1]
70
+
71
+ pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
72
+ pixel_values = pixel_values / 255.
73
+ del video_reader
74
+
75
+ return pixel_values, name
76
+
77
+ def __len__(self):
78
+ return self.length
79
+
80
+ def __getitem__(self, idx):
81
+ while True:
82
+ try:
83
+ pixel_values, name = self.get_batch(idx)
84
+ break
85
+
86
+ except Exception as e:
87
+ idx = random.randint(0, self.length-1)
88
+
89
+ pixel_values = self.pixel_transforms(pixel_values)
90
+ sample = dict(pixel_values=pixel_values, text=name)
91
+ return sample
92
+
93
+
94
+ class Pexels(Dataset):
95
+ def __init__(
96
+ self,
97
+ json_path, caption_json_path, video_folder=None,
98
+ sample_size=256, sample_duration=1, sample_fps=8,
99
+ is_image=False,
100
+ **kwargs,
101
+ ):
102
+ logger.info(f"loading captions from {caption_json_path} ...")
103
+ with open(caption_json_path, 'rb') as caption_json_file:
104
+ caption_json_list = list(caption_json_file)
105
+ self.caption_dict = {json.loads(json_str)['id']: json.loads(json_str)['text'] for json_str in caption_json_list}
106
+
107
+ logger.info(f"loading annotations from {json_path} ...")
108
+ with open(json_path, 'rb') as json_file:
109
+ json_list = list(json_file)
110
+ dataset = [json.loads(json_str) for json_str in json_list]
111
+
112
+ self.dataset = []
113
+ for data in dataset:
114
+ data['text'] = self.caption_dict[data['id']]
115
+ if data['height'] / data['width'] < 0.625:
116
+ self.dataset.append(data)
117
+ self.length = len(self.dataset)
118
+ logger.info(f"data scale: {self.length}")
119
+
120
+ self.video_folder = video_folder
121
+ self.sample_duration = sample_duration
122
+ self.sample_fps = sample_fps
123
+ self.sample_n_frames = sample_duration * sample_fps
124
+ self.is_image = is_image
125
+
126
+ sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
127
+ self.pixel_transforms = transforms.Compose([
128
+ transforms.RandomHorizontalFlip(),
129
+ transforms.Resize(sample_size[0], antialias=None),
130
+ transforms.CenterCrop(sample_size),
131
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
132
+ ])
133
+
134
+ def get_batch(self, idx):
135
+ video_dict = self.dataset[idx]
136
+ video_relative_path, name = video_dict['file'], video_dict['text']
137
+ fps = video_dict['fps']
138
+
139
+ if self.video_folder is not None:
140
+ if video_relative_path[0] == '/':
141
+ video_dir = os.path.join(self.video_folder, os.path.basename(video_relative_path))
142
+ else:
143
+ video_dir = os.path.join(self.video_folder, video_relative_path)
144
+ else:
145
+ video_dir = video_relative_path
146
+ video_reader = VideoReader(video_dir)
147
+ video_length = len(video_reader)
148
+
149
+ if not self.is_image:
150
+ clip_length = min(video_length, math.ceil(fps * self.sample_duration))
151
+ start_idx = random.randint(0, video_length - clip_length)
152
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
153
+ else:
154
+ frame_difference = random.randint(2, self.sample_n_frames)
155
+ sample_stride = math.ceil((fps * self.sample_duration) / (self.sample_n_frames - 1) - 1)
156
+ clip_length = min(video_length, (frame_difference - 1) * sample_stride + 1)
157
+ start_idx = random.randint(0, video_length - clip_length)
158
+ batch_index = [start_idx, start_idx + clip_length - 1]
159
+
160
+ pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
161
+ pixel_values = pixel_values / 255.
162
+ del video_reader
163
+
164
+ return pixel_values, name
165
+
166
+ def __len__(self):
167
+ return self.length
168
+
169
+ def __getitem__(self, idx):
170
+ while True:
171
+ try:
172
+ pixel_values, name = self.get_batch(idx)
173
+ break
174
+
175
+ except Exception as e:
176
+ idx = random.randint(0, self.length-1)
177
+
178
+ pixel_values = self.pixel_transforms(pixel_values)
179
+ sample = dict(pixel_values=pixel_values, text=name)
180
+ return sample
181
+
182
+
183
+ class JointDataset(Dataset):
184
+ def __init__(
185
+ self,
186
+ webvid_config, pexels_config,
187
+ sample_size=256,
188
+ sample_duration=None, sample_fps=None, sample_stride=None, sample_n_frames=None,
189
+ is_image=False,
190
+ **kwargs,
191
+ ):
192
+ assert (sample_duration is None and sample_fps is None) or (sample_duration is not None and sample_fps is not None), "sample_duration and sample_fps should be both None or not None"
193
+ if sample_duration is not None and sample_fps is not None:
194
+ assert sample_stride is None, "when sample_duration and sample_fps are not None, sample_stride should be None"
195
+ if sample_stride is not None:
196
+ assert sample_fps is None and sample_duration is None, "when sample_stride is not None, sample_duration and sample_fps should be both None"
197
+
198
+ self.dataset = []
199
+
200
+ if pexels_config.enable:
201
+ logger.info(f"loading pexels dataset")
202
+ logger.info(f"loading captions from {pexels_config.caption_json_path} ...")
203
+ with open(pexels_config.caption_json_path, 'rb') as caption_json_file:
204
+ caption_json_list = list(caption_json_file)
205
+ self.caption_dict = {json.loads(json_str)['id']: json.loads(json_str)['text'] for json_str in caption_json_list}
206
+
207
+ logger.info(f"loading annotations from {pexels_config.json_path} ...")
208
+ with open(pexels_config.json_path, 'rb') as json_file:
209
+ json_list = list(json_file)
210
+ dataset = [json.loads(json_str) for json_str in json_list]
211
+
212
+ for data in dataset:
213
+ data['text'] = self.caption_dict[data['id']]
214
+ data['dataset'] = 'pexels'
215
+ if data['height'] / data['width'] < 0.625:
216
+ self.dataset.append(data)
217
+
218
+ if webvid_config.enable:
219
+ logger.info(f"loading webvid dataset")
220
+ logger.info(f"loading annotations from {webvid_config.json_path} ...")
221
+ with open(webvid_config.json_path, 'rb') as json_file:
222
+ json_list = list(json_file)
223
+ dataset = [json.loads(json_str) for json_str in json_list]
224
+ for data in dataset:
225
+ data['dataset'] = 'webvid'
226
+ self.dataset.extend(dataset)
227
+
228
+ self.length = len(self.dataset)
229
+ logger.info(f"data scale: {self.length}")
230
+
231
+ self.pexels_folder = pexels_config.video_folder
232
+ self.webvid_folder = webvid_config.video_folder
233
+ self.sample_duration = sample_duration
234
+ self.sample_fps = sample_fps
235
+ self.sample_n_frames = sample_duration * sample_fps if sample_n_frames is None else sample_n_frames
236
+ self.sample_stride = sample_stride if (sample_stride is None) or (sample_stride is not None and isinstance(sample_stride, int)) else tuple(sample_stride)
237
+ self.is_image = is_image
238
+
239
+ sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
240
+ self.pixel_transforms = transforms.Compose([
241
+ transforms.RandomHorizontalFlip(),
242
+ transforms.Resize(sample_size[0], antialias=None),
243
+ transforms.CenterCrop(sample_size),
244
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
245
+ ])
246
+
247
+ def get_batch(self, idx):
248
+ video_dict = self.dataset[idx]
249
+ video_relative_path, name = video_dict['file'], video_dict['text']
250
+
251
+ if video_dict['dataset'] == 'pexels':
252
+ video_folder = self.pexels_folder
253
+ elif video_dict['dataset'] == 'webvid':
254
+ video_folder = self.webvid_folder
255
+ else:
256
+ raise NotImplementedError
257
+
258
+ if video_folder is not None:
259
+ if video_relative_path[0] == '/':
260
+ video_dir = os.path.join(video_folder, os.path.basename(video_relative_path))
261
+ else:
262
+ video_dir = os.path.join(video_folder, video_relative_path)
263
+ else:
264
+ video_dir = video_relative_path
265
+ video_reader = VideoReader(video_dir)
266
+ video_length = len(video_reader)
267
+
268
+ stride = None
269
+ if not self.is_image:
270
+ if self.sample_duration is not None:
271
+ fps = video_dict['fps']
272
+ clip_length = min(video_length, math.ceil(fps * self.sample_duration))
273
+ elif self.sample_stride is not None:
274
+ if isinstance(self.sample_stride, int):
275
+ stride = self.sample_stride
276
+ elif isinstance(self.sample_stride, tuple):
277
+ stride = random.randint(self.sample_stride[0], self.sample_stride[1])
278
+ clip_length = min(video_length, (self.sample_n_frames - 1) * stride + 1)
279
+
280
+ start_idx = random.randint(0, video_length - clip_length)
281
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
282
+
283
+ else:
284
+ frame_difference = random.randint(2, self.sample_n_frames)
285
+ if self.sample_duration is not None:
286
+ fps = video_dict['fps']
287
+ sample_stride = math.ceil((fps * self.sample_duration) / (self.sample_n_frames - 1) - 1)
288
+ elif self.sample_stride is not None:
289
+ sample_stride = self.sample_stride
290
+
291
+ clip_length = min(video_length, (frame_difference - 1) * sample_stride + 1)
292
+ start_idx = random.randint(0, video_length - clip_length)
293
+ batch_index = [start_idx, start_idx + clip_length - 1]
294
+
295
+ pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
296
+ pixel_values = pixel_values / 255.
297
+ del video_reader
298
+
299
+ return pixel_values, name, stride
300
+
301
+ def __len__(self):
302
+ return self.length
303
+
304
+ def __getitem__(self, idx):
305
+ while True:
306
+ try:
307
+ pixel_values, name, stride = self.get_batch(idx)
308
+ break
309
+
310
+ except Exception as e:
311
+ idx = random.randint(0, self.length-1)
312
+
313
+ pixel_values = self.pixel_transforms(pixel_values)
314
+ sample = dict(pixel_values=pixel_values, text=name, stride=stride)
315
+ return sample
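
A minimal usage sketch for the datasets above, with placeholder paths; the annotation file is assumed to be JSON-lines with one {"file": ..., "text": ...} record per line, which is the format get_batch reads:

from torch.utils.data import DataLoader

# Placeholder paths; a sample_stride given as a range is resampled per clip.
dataset = WebVid10M(
    json_path="/path/to/webvid/annotation.jsonl",
    video_folder="/path/to/webvid/data",
    sample_size=256,
    sample_stride=[1, 5],
    sample_n_frames=16,
)
loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4)
batch = next(iter(loader))
print(batch["pixel_values"].shape)  # torch.Size([2, 16, 3, 256, 256])
print(batch["text"])
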
src/videogen_hub/pipelines/consisti2v/consisti2v/models/__init__.py ADDED
File without changes
src/videogen_hub/pipelines/consisti2v/consisti2v/models/rotary_embedding.py ADDED
@@ -0,0 +1,280 @@
1
+ from math import pi, log
2
+
3
+ import torch
4
+ from torch.nn import Module, ModuleList
5
+ from torch.cuda.amp import autocast
6
+ from torch import nn, einsum, broadcast_tensors, Tensor
7
+
8
+ from einops import rearrange, repeat
9
+
10
+ from beartype import beartype
11
+ from beartype.typing import Literal, Union, Optional
12
+
13
+ # helper functions
14
+
15
+ def exists(val):
16
+ return val is not None
17
+
18
+ def default(val, d):
19
+ return val if exists(val) else d
20
+
21
+ # broadcat, as tortoise-tts was using it
22
+
23
+ def broadcat(tensors, dim = -1):
24
+ broadcasted_tensors = broadcast_tensors(*tensors)
25
+ return torch.cat(broadcasted_tensors, dim = dim)
26
+
27
+ # rotary embedding helper functions
28
+
29
+ def rotate_half(x):
30
+ x = rearrange(x, '... (d r) -> ... d r', r = 2)
31
+ x1, x2 = x.unbind(dim = -1)
32
+ x = torch.stack((-x2, x1), dim = -1)
33
+ return rearrange(x, '... d r -> ... (d r)')
34
+
35
+ @autocast(enabled = False)
36
+ def apply_rotary_emb(freqs, t, start_index = 0, scale = 1., seq_dim = -2):
37
+ if t.ndim == 3:
38
+ seq_len = t.shape[seq_dim]
39
+ freqs = freqs[-seq_len:].to(t)
40
+
41
+ rot_dim = freqs.shape[-1]
42
+ end_index = start_index + rot_dim
43
+
44
+ assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
45
+
46
+ t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
47
+ t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
48
+ return torch.cat((t_left, t, t_right), dim = -1)
49
+
50
+ # learned rotation helpers
51
+
52
+ def apply_learned_rotations(rotations, t, start_index = 0, freq_ranges = None):
53
+ if exists(freq_ranges):
54
+ rotations = einsum('..., f -> ... f', rotations, freq_ranges)
55
+ rotations = rearrange(rotations, '... r f -> ... (r f)')
56
+
57
+ rotations = repeat(rotations, '... n -> ... (n r)', r = 2)
58
+ return apply_rotary_emb(rotations, t, start_index = start_index)
59
+
60
+ # classes
61
+
62
+ class RotaryEmbedding(Module):
63
+ @beartype
64
+ def __init__(
65
+ self,
66
+ dim,
67
+ custom_freqs: Optional[Tensor] = None,
68
+ freqs_for: Union[
69
+ Literal['lang'],
70
+ Literal['pixel'],
71
+ Literal['constant']
72
+ ] = 'lang',
73
+ theta = 10000,
74
+ max_freq = 10,
75
+ num_freqs = 1,
76
+ learned_freq = False,
77
+ use_xpos = False,
78
+ xpos_scale_base = 512,
79
+ interpolate_factor = 1.,
80
+ theta_rescale_factor = 1.,
81
+ seq_before_head_dim = False
82
+ ):
83
+ super().__init__()
84
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
85
+ # has some connection to NTK literature
86
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
87
+
88
+ theta *= theta_rescale_factor ** (dim / (dim - 2))
89
+
90
+ self.freqs_for = freqs_for
91
+
92
+ if exists(custom_freqs):
93
+ freqs = custom_freqs
94
+ elif freqs_for == 'lang':
95
+ freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
96
+ elif freqs_for == 'pixel':
97
+ freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
98
+ elif freqs_for == 'constant':
99
+ freqs = torch.ones(num_freqs).float()
100
+
101
+ self.tmp_store('cached_freqs', None)
102
+ self.tmp_store('cached_scales', None)
103
+
104
+ self.freqs = nn.Parameter(freqs, requires_grad = learned_freq)
105
+
106
+ self.learned_freq = learned_freq
107
+
108
+ # dummy for device
109
+
110
+ self.tmp_store('dummy', torch.tensor(0))
111
+
112
+ # default sequence dimension
113
+
114
+ self.seq_before_head_dim = seq_before_head_dim
115
+ self.default_seq_dim = -3 if seq_before_head_dim else -2
116
+
117
+ # interpolation factors
118
+
119
+ assert interpolate_factor >= 1.
120
+ self.interpolate_factor = interpolate_factor
121
+
122
+ # xpos
123
+
124
+ self.use_xpos = use_xpos
125
+ if not use_xpos:
126
+ self.tmp_store('scale', None)
127
+ return
128
+
129
+ scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
130
+ self.scale_base = xpos_scale_base
131
+ self.tmp_store('scale', scale)
132
+
133
+ @property
134
+ def device(self):
135
+ return self.dummy.device
136
+
137
+ def tmp_store(self, key, value):
138
+ self.register_buffer(key, value, persistent = False)
139
+
140
+ def get_seq_pos(self, seq_len, device, dtype, offset = 0):
141
+ return (torch.arange(seq_len, device = device, dtype = dtype) + offset) / self.interpolate_factor
142
+
143
+ def rotate_queries_or_keys(self, t, seq_dim = None, offset = 0, freq_seq_len = None, seq_pos = None):
144
+ seq_dim = default(seq_dim, self.default_seq_dim)
145
+
146
+ assert not self.use_xpos, 'you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings'
147
+
148
+ device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]
149
+
150
+ if exists(freq_seq_len):
151
+ assert freq_seq_len >= seq_len
152
+ seq_len = freq_seq_len
153
+
154
+ if seq_pos is None:
155
+ seq_pos = self.get_seq_pos(seq_len, device = device, dtype = dtype, offset = offset)
156
+ else:
157
+ assert seq_pos.shape[0] == seq_len
158
+
159
+ freqs = self.forward(seq_pos, seq_len = seq_len, offset = offset)
160
+
161
+ if seq_dim == -3:
162
+ freqs = rearrange(freqs, 'n d -> n 1 d')
163
+
164
+ return apply_rotary_emb(freqs, t, seq_dim = seq_dim)
165
+
166
+ def rotate_queries_with_cached_keys(self, q, k, seq_dim = None, offset = 0):
167
+ seq_dim = default(seq_dim, self.default_seq_dim)
168
+
169
+ q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
170
+ assert q_len <= k_len
171
+ rotated_q = self.rotate_queries_or_keys(q, seq_dim = seq_dim, freq_seq_len = k_len)
172
+ rotated_k = self.rotate_queries_or_keys(k, seq_dim = seq_dim)
173
+
174
+ rotated_q = rotated_q.type(q.dtype)
175
+ rotated_k = rotated_k.type(k.dtype)
176
+
177
+ return rotated_q, rotated_k
178
+
179
+ def rotate_queries_and_keys(self, q, k, seq_dim = None):
180
+ seq_dim = default(seq_dim, self.default_seq_dim)
181
+
182
+ assert self.use_xpos
183
+ device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]
184
+
185
+ seq = self.get_seq_pos(seq_len, dtype = dtype, device = device)
186
+
187
+ freqs = self.forward(seq, seq_len = seq_len)
188
+ scale = self.get_scale(seq, seq_len = seq_len).to(dtype)
189
+
190
+ if seq_dim == -3:
191
+ freqs = rearrange(freqs, 'n d -> n 1 d')
192
+ scale = rearrange(scale, 'n d -> n 1 d')
193
+
194
+ rotated_q = apply_rotary_emb(freqs, q, scale = scale, seq_dim = seq_dim)
195
+ rotated_k = apply_rotary_emb(freqs, k, scale = scale ** -1, seq_dim = seq_dim)
196
+
197
+ rotated_q = rotated_q.type(q.dtype)
198
+ rotated_k = rotated_k.type(k.dtype)
199
+
200
+ return rotated_q, rotated_k
201
+
202
+ @beartype
203
+ def get_scale(
204
+ self,
205
+ t: Tensor,
206
+ seq_len: Optional[int] = None,
207
+ offset = 0
208
+ ):
209
+ assert self.use_xpos
210
+
211
+ should_cache = exists(seq_len)
212
+
213
+ if (
214
+ should_cache and \
215
+ exists(self.cached_scales) and \
216
+ (seq_len + offset) <= self.cached_scales.shape[0]
217
+ ):
218
+ return self.cached_scales[offset:(offset + seq_len)]
219
+
220
+ scale = 1.
221
+ if self.use_xpos:
222
+ power = (t - len(t) // 2) / self.scale_base
223
+ scale = self.scale ** rearrange(power, 'n -> n 1')
224
+ scale = torch.cat((scale, scale), dim = -1)
225
+
226
+ if should_cache:
227
+ self.tmp_store('cached_scales', scale)
228
+
229
+ return scale
230
+
231
+ def get_axial_freqs(self, *dims):
232
+ Colon = slice(None)
233
+ all_freqs = []
234
+
235
+ for ind, dim in enumerate(dims):
236
+ if self.freqs_for == 'pixel':
237
+ pos = torch.linspace(-1, 1, steps = dim, device = self.device)
238
+ else:
239
+ pos = torch.arange(dim, device = self.device)
240
+
241
+ freqs = self.forward(pos, seq_len = dim)
242
+
243
+ all_axis = [None] * len(dims)
244
+ all_axis[ind] = Colon
245
+
246
+ new_axis_slice = (Ellipsis, *all_axis, Colon)
247
+ all_freqs.append(freqs[new_axis_slice])
248
+
249
+ all_freqs = broadcast_tensors(*all_freqs)
250
+ return torch.cat(all_freqs, dim = -1)
251
+
252
+ @autocast(enabled = False)
253
+ def forward(
254
+ self,
255
+ t: Tensor,
256
+ seq_len = None,
257
+ offset = 0
258
+ ):
259
+ # should_cache = (
260
+ # not self.learned_freq and \
261
+ # exists(seq_len) and \
262
+ # self.freqs_for != 'pixel'
263
+ # )
264
+
265
+ # if (
266
+ # should_cache and \
267
+ # exists(self.cached_freqs) and \
268
+ # (offset + seq_len) <= self.cached_freqs.shape[0]
269
+ # ):
270
+ # return self.cached_freqs[offset:(offset + seq_len)].detach()
271
+
272
+ freqs = self.freqs
273
+
274
+ freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
275
+ freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
276
+
277
+ # if should_cache:
278
+ # self.tmp_store('cached_freqs', freqs.detach())
279
+
280
+ return freqs
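
A short sketch of how this class is used by the temporal attention layers later in this commit: a RotaryEmbedding is built over half of the channel dimension and applied to queries and keys before attention. Shapes below are illustrative:

import torch

dim = 64                              # per-token channel dimension
rotary = RotaryEmbedding(dim // 2)    # mirrors RotaryEmbedding(inner_dim // 2) in the attention code

q = torch.randn(2, 16, dim)           # (batch, frames, channels)
k = torch.randn(2, 16, dim)

# Rotates the first dim // 2 channels of each token by a position-dependent angle.
q_rot = rotary.rotate_queries_or_keys(q)
k_rot = rotary.rotate_queries_or_keys(k)
print(q_rot.shape, k_rot.shape)       # torch.Size([2, 16, 64]) for both
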
src/videogen_hub/pipelines/consisti2v/consisti2v/models/videoldm_attention.py ADDED
@@ -0,0 +1,809 @@
1
+ from importlib import import_module
2
+ from typing import Callable, Optional, Union
3
+ import math
4
+
5
+ from einops import rearrange, repeat
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch import nn
10
+
11
+ from diffusers.utils import deprecate, logging
12
+ from diffusers.utils.import_utils import is_xformers_available
13
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
14
+ from diffusers.models.lora import LoRACompatibleLinear, LoRALinearLayer
15
+ from diffusers.models.attention_processor import (
16
+ Attention,
17
+ AttnAddedKVProcessor,
18
+ AttnAddedKVProcessor2_0,
19
+ AttnProcessor,
20
+ AttnProcessor2_0,
21
+ SpatialNorm,
22
+ LORA_ATTENTION_PROCESSORS,
23
+ CustomDiffusionAttnProcessor,
24
+ CustomDiffusionXFormersAttnProcessor,
25
+ SlicedAttnAddedKVProcessor,
26
+ XFormersAttnAddedKVProcessor,
27
+ LoRAAttnAddedKVProcessor,
28
+ XFormersAttnProcessor,
29
+ LoRAXFormersAttnProcessor,
30
+ LoRAAttnProcessor,
31
+ LoRAAttnProcessor2_0,
32
+ SlicedAttnProcessor,
33
+ AttentionProcessor
34
+ )
35
+
36
+ from .rotary_embedding import RotaryEmbedding
37
+
38
+
39
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
40
+
41
+
42
+ if is_xformers_available():
43
+ import xformers
44
+ import xformers.ops
45
+ else:
46
+ xformers = None
47
+
48
+ @maybe_allow_in_graph
49
+ class ConditionalAttention(nn.Module):
50
+ r"""
51
+ A cross attention layer.
52
+
53
+ Parameters:
54
+ query_dim (`int`): The number of channels in the query.
55
+ cross_attention_dim (`int`, *optional*):
56
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
57
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
58
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
59
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
60
+ bias (`bool`, *optional*, defaults to False):
61
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
62
+ """
63
+
64
+ def __init__(
65
+ self,
66
+ query_dim: int,
67
+ cross_attention_dim: Optional[int] = None,
68
+ heads: int = 8,
69
+ dim_head: int = 64,
70
+ dropout: float = 0.0,
71
+ bias=False,
72
+ upcast_attention: bool = False,
73
+ upcast_softmax: bool = False,
74
+ cross_attention_norm: Optional[str] = None,
75
+ cross_attention_norm_num_groups: int = 32,
76
+ added_kv_proj_dim: Optional[int] = None,
77
+ norm_num_groups: Optional[int] = None,
78
+ spatial_norm_dim: Optional[int] = None,
79
+ out_bias: bool = True,
80
+ scale_qk: bool = True,
81
+ only_cross_attention: bool = False,
82
+ eps: float = 1e-5,
83
+ rescale_output_factor: float = 1.0,
84
+ residual_connection: bool = False,
85
+ _from_deprecated_attn_block=False,
86
+ processor: Optional["AttnProcessor"] = None,
87
+ ):
88
+ super().__init__()
89
+ self.inner_dim = dim_head * heads
90
+ self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
91
+ self.upcast_attention = upcast_attention
92
+ self.upcast_softmax = upcast_softmax
93
+ self.rescale_output_factor = rescale_output_factor
94
+ self.residual_connection = residual_connection
95
+ self.dropout = dropout
96
+
97
+ # we make use of this private variable to know whether this class is loaded
98
+ # with an deprecated state dict so that we can convert it on the fly
99
+ self._from_deprecated_attn_block = _from_deprecated_attn_block
100
+
101
+ self.scale_qk = scale_qk
102
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
103
+
104
+ self.heads = heads
105
+ # for slice_size > 0 the attention score computation
106
+ # is split across the batch axis to save memory
107
+ # You can set slice_size with `set_attention_slice`
108
+ self.sliceable_head_dim = heads
109
+
110
+ self.added_kv_proj_dim = added_kv_proj_dim
111
+ self.only_cross_attention = only_cross_attention
112
+
113
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
114
+ raise ValueError(
115
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
116
+ )
117
+
118
+ if norm_num_groups is not None:
119
+ self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
120
+ else:
121
+ self.group_norm = None
122
+
123
+ if spatial_norm_dim is not None:
124
+ self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
125
+ else:
126
+ self.spatial_norm = None
127
+
128
+ if cross_attention_norm is None:
129
+ self.norm_cross = None
130
+ elif cross_attention_norm == "layer_norm":
131
+ self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
132
+ elif cross_attention_norm == "group_norm":
133
+ if self.added_kv_proj_dim is not None:
134
+ # The given `encoder_hidden_states` are initially of shape
135
+ # (batch_size, seq_len, added_kv_proj_dim) before being projected
136
+ # to (batch_size, seq_len, cross_attention_dim). The norm is applied
137
+ # before the projection, so we need to use `added_kv_proj_dim` as
138
+ # the number of channels for the group norm.
139
+ norm_cross_num_channels = added_kv_proj_dim
140
+ else:
141
+ norm_cross_num_channels = self.cross_attention_dim
142
+
143
+ self.norm_cross = nn.GroupNorm(
144
+ num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
145
+ )
146
+ else:
147
+ raise ValueError(
148
+ f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
149
+ )
150
+
151
+ self.to_q = LoRACompatibleLinear(query_dim, self.inner_dim, bias=bias)
152
+
153
+ if not self.only_cross_attention:
154
+ # only relevant for the `AddedKVProcessor` classes
155
+ self.to_k = LoRACompatibleLinear(self.cross_attention_dim, self.inner_dim, bias=bias)
156
+ self.to_v = LoRACompatibleLinear(self.cross_attention_dim, self.inner_dim, bias=bias)
157
+ else:
158
+ self.to_k = None
159
+ self.to_v = None
160
+
161
+ if self.added_kv_proj_dim is not None:
162
+ self.add_k_proj = LoRACompatibleLinear(added_kv_proj_dim, self.inner_dim)
163
+ self.add_v_proj = LoRACompatibleLinear(added_kv_proj_dim, self.inner_dim)
164
+
165
+ self.to_out = nn.ModuleList([])
166
+ self.to_out.append(LoRACompatibleLinear(self.inner_dim, query_dim, bias=out_bias))
167
+ self.to_out.append(nn.Dropout(dropout))
168
+
169
+ # set attention processor
170
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
171
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
172
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
173
+ if processor is None:
174
+ processor = (
175
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
176
+ )
177
+ self.set_processor(processor)
178
+
179
+ def set_use_memory_efficient_attention_xformers(
180
+ self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
181
+ ):
182
+ is_lora = hasattr(self, "processor") and isinstance(
183
+ self.processor,
184
+ LORA_ATTENTION_PROCESSORS,
185
+ )
186
+ is_custom_diffusion = hasattr(self, "processor") and isinstance(
187
+ self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
188
+ )
189
+ is_added_kv_processor = hasattr(self, "processor") and isinstance(
190
+ self.processor,
191
+ (
192
+ AttnAddedKVProcessor,
193
+ AttnAddedKVProcessor2_0,
194
+ SlicedAttnAddedKVProcessor,
195
+ XFormersAttnAddedKVProcessor,
196
+ LoRAAttnAddedKVProcessor,
197
+ ),
198
+ )
199
+
200
+ if use_memory_efficient_attention_xformers:
201
+ if is_added_kv_processor and (is_lora or is_custom_diffusion):
202
+ raise NotImplementedError(
203
+ f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}"
204
+ )
205
+ if not is_xformers_available():
206
+ raise ModuleNotFoundError(
207
+ (
208
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
209
+ " xformers"
210
+ ),
211
+ name="xformers",
212
+ )
213
+ elif not torch.cuda.is_available():
214
+ raise ValueError(
215
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
216
+ " only available for GPU "
217
+ )
218
+ else:
219
+ try:
220
+ # Make sure we can run the memory efficient attention
221
+ _ = xformers.ops.memory_efficient_attention(
222
+ torch.randn((1, 2, 40), device="cuda"),
223
+ torch.randn((1, 2, 40), device="cuda"),
224
+ torch.randn((1, 2, 40), device="cuda"),
225
+ )
226
+ except Exception as e:
227
+ raise e
228
+
229
+ if is_lora:
230
+ # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
231
+ # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
232
+ processor = LoRAXFormersAttnProcessor(
233
+ hidden_size=self.processor.hidden_size,
234
+ cross_attention_dim=self.processor.cross_attention_dim,
235
+ rank=self.processor.rank,
236
+ attention_op=attention_op,
237
+ )
238
+ processor.load_state_dict(self.processor.state_dict())
239
+ processor.to(self.processor.to_q_lora.up.weight.device)
240
+ elif is_custom_diffusion:
241
+ processor = CustomDiffusionXFormersAttnProcessor(
242
+ train_kv=self.processor.train_kv,
243
+ train_q_out=self.processor.train_q_out,
244
+ hidden_size=self.processor.hidden_size,
245
+ cross_attention_dim=self.processor.cross_attention_dim,
246
+ attention_op=attention_op,
247
+ )
248
+ processor.load_state_dict(self.processor.state_dict())
249
+ if hasattr(self.processor, "to_k_custom_diffusion"):
250
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
251
+ elif is_added_kv_processor:
252
+ # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
253
+ # which uses this type of cross attention ONLY because the attention mask of format
254
+ # [0, ..., -10.000, ..., 0, ...,] is not supported
255
+ # throw warning
256
+ logger.info(
257
+ "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
258
+ )
259
+ processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
260
+ else:
261
+ processor = XFormersAttnProcessor(attention_op=attention_op)
262
+ else:
263
+ if is_lora:
264
+ attn_processor_class = (
265
+ LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
266
+ )
267
+ processor = attn_processor_class(
268
+ hidden_size=self.processor.hidden_size,
269
+ cross_attention_dim=self.processor.cross_attention_dim,
270
+ rank=self.processor.rank,
271
+ )
272
+ processor.load_state_dict(self.processor.state_dict())
273
+ processor.to(self.processor.to_q_lora.up.weight.device)
274
+ elif is_custom_diffusion:
275
+ processor = CustomDiffusionAttnProcessor(
276
+ train_kv=self.processor.train_kv,
277
+ train_q_out=self.processor.train_q_out,
278
+ hidden_size=self.processor.hidden_size,
279
+ cross_attention_dim=self.processor.cross_attention_dim,
280
+ )
281
+ processor.load_state_dict(self.processor.state_dict())
282
+ if hasattr(self.processor, "to_k_custom_diffusion"):
283
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
284
+ else:
285
+ # set attention processor
286
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
287
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
288
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
289
+ processor = (
290
+ AttnProcessor2_0()
291
+ if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
292
+ else AttnProcessor()
293
+ )
294
+
295
+ self.set_processor(processor)
296
+
297
+ def set_attention_slice(self, slice_size):
298
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
299
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
300
+
301
+ if slice_size is not None and self.added_kv_proj_dim is not None:
302
+ processor = SlicedAttnAddedKVProcessor(slice_size)
303
+ elif slice_size is not None:
304
+ processor = SlicedAttnProcessor(slice_size)
305
+ elif self.added_kv_proj_dim is not None:
306
+ processor = AttnAddedKVProcessor()
307
+ else:
308
+ # set attention processor
309
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
310
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
311
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
312
+ processor = (
313
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
314
+ )
315
+
316
+ self.set_processor(processor)
317
+
318
+ def set_processor(self, processor: "AttnProcessor"):
319
+ if (
320
+ hasattr(self, "processor")
321
+ and not isinstance(processor, LORA_ATTENTION_PROCESSORS)
322
+ and self.to_q.lora_layer is not None
323
+ ):
324
+ deprecate(
325
+ "set_processor to offload LoRA",
326
+ "0.26.0",
327
+ "In detail, removing LoRA layers via calling `set_processor` or `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.",
328
+ )
329
+ # (Patrick, Sayak) - this can be deprecated once PEFT LoRA integration is complete
330
+ # We need to remove all LoRA layers
331
+ for module in self.modules():
332
+ if hasattr(module, "set_lora_layer"):
333
+ module.set_lora_layer(None)
334
+
335
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
336
+ # pop `processor` from `self._modules`
337
+ if (
338
+ hasattr(self, "processor")
339
+ and isinstance(self.processor, torch.nn.Module)
340
+ and not isinstance(processor, torch.nn.Module)
341
+ ):
342
+ logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
343
+ self._modules.pop("processor")
344
+
345
+ self.processor = processor
346
+
347
+ def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
348
+ if not return_deprecated_lora:
349
+ return self.processor
350
+
351
+ # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible
352
+ # serialization format for LoRA Attention Processors. It should be deleted once the integration
353
+ # with PEFT is completed.
354
+ is_lora_activated = {
355
+ name: module.lora_layer is not None
356
+ for name, module in self.named_modules()
357
+ if hasattr(module, "lora_layer")
358
+ }
359
+
360
+ # 1. if no layer has a LoRA activated we can return the processor as usual
361
+ if not any(is_lora_activated.values()):
362
+ return self.processor
363
+
364
+ # Don't require `add_k_proj` or `add_v_proj` to have LoRA applied
365
+ is_lora_activated.pop("add_k_proj", None)
366
+ is_lora_activated.pop("add_v_proj", None)
367
+ # 2. else it is not possible that only some layers have LoRA activated
368
+ if not all(is_lora_activated.values()):
369
+ raise ValueError(
370
+ f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
371
+ )
372
+
373
+ # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
374
+ non_lora_processor_cls_name = self.processor.__class__.__name__
375
+ lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name)
376
+
377
+ hidden_size = self.inner_dim
378
+
379
+ # now create a LoRA attention processor from the LoRA layers
380
+ if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]:
381
+ kwargs = {
382
+ "cross_attention_dim": self.cross_attention_dim,
383
+ "rank": self.to_q.lora_layer.rank,
384
+ "network_alpha": self.to_q.lora_layer.network_alpha,
385
+ "q_rank": self.to_q.lora_layer.rank,
386
+ "q_hidden_size": self.to_q.lora_layer.out_features,
387
+ "k_rank": self.to_k.lora_layer.rank,
388
+ "k_hidden_size": self.to_k.lora_layer.out_features,
389
+ "v_rank": self.to_v.lora_layer.rank,
390
+ "v_hidden_size": self.to_v.lora_layer.out_features,
391
+ "out_rank": self.to_out[0].lora_layer.rank,
392
+ "out_hidden_size": self.to_out[0].lora_layer.out_features,
393
+ }
394
+
395
+ if hasattr(self.processor, "attention_op"):
396
+ kwargs["attention_op"] = self.prcoessor.attention_op
397
+
398
+ lora_processor = lora_processor_cls(hidden_size, **kwargs)
399
+ lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
400
+ lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
401
+ lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
402
+ lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
403
+ elif lora_processor_cls == LoRAAttnAddedKVProcessor:
404
+ lora_processor = lora_processor_cls(
405
+ hidden_size,
406
+ cross_attention_dim=self.add_k_proj.weight.shape[0],
407
+ rank=self.to_q.lora_layer.rank,
408
+ network_alpha=self.to_q.lora_layer.network_alpha,
409
+ )
410
+ lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
411
+ lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
412
+ lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
413
+ lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
414
+
415
+ # only save if used
416
+ if self.add_k_proj.lora_layer is not None:
417
+ lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict())
418
+ lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict())
419
+ else:
420
+ lora_processor.add_k_proj_lora = None
421
+ lora_processor.add_v_proj_lora = None
422
+ else:
423
+ raise ValueError(f"{lora_processor_cls} does not exist.")
424
+
425
+ return lora_processor
426
+
427
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
428
+ # The `Attention` class can call different attention processors / attention functions
429
+ # here we simply pass along all tensors to the selected processor class
430
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
431
+ return self.processor(
432
+ self,
433
+ hidden_states,
434
+ encoder_hidden_states=encoder_hidden_states,
435
+ attention_mask=attention_mask,
436
+ **cross_attention_kwargs,
437
+ )
438
+
439
+ def batch_to_head_dim(self, tensor):
440
+ head_size = self.heads
441
+ batch_size, seq_len, dim = tensor.shape
442
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
443
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
444
+ return tensor
445
+
446
+ def head_to_batch_dim(self, tensor, out_dim=3):
447
+ head_size = self.heads
448
+ batch_size, seq_len, dim = tensor.shape
449
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
450
+ tensor = tensor.permute(0, 2, 1, 3)
451
+
452
+ if out_dim == 3:
453
+ tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
454
+
455
+ return tensor
456
+
457
+ def get_attention_scores(self, query, key, attention_mask=None):
458
+ dtype = query.dtype
459
+ if self.upcast_attention:
460
+ query = query.float()
461
+ key = key.float()
462
+
463
+ if attention_mask is None:
464
+ baddbmm_input = torch.empty(
465
+ query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
466
+ )
467
+ beta = 0
468
+ else:
469
+ baddbmm_input = attention_mask
470
+ beta = 1
471
+
472
+ attention_scores = torch.baddbmm(
473
+ baddbmm_input,
474
+ query,
475
+ key.transpose(-1, -2),
476
+ beta=beta,
477
+ alpha=self.scale,
478
+ )
479
+ del baddbmm_input
480
+
481
+ if self.upcast_softmax:
482
+ attention_scores = attention_scores.float()
483
+
484
+ attention_probs = attention_scores.softmax(dim=-1)
485
+ del attention_scores
486
+
487
+ attention_probs = attention_probs.to(dtype)
488
+
489
+ return attention_probs
490
+
491
+ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3):
492
+ if batch_size is None:
493
+ deprecate(
494
+ "batch_size=None",
495
+ "0.22.0",
496
+ (
497
+ "Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect"
498
+ " attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to"
499
+ " `prepare_attention_mask` when preparing the attention_mask."
500
+ ),
501
+ )
502
+ batch_size = 1
503
+
504
+ head_size = self.heads
505
+ if attention_mask is None:
506
+ return attention_mask
507
+
508
+ current_length: int = attention_mask.shape[-1]
509
+ if current_length != target_length:
510
+ if attention_mask.device.type == "mps":
511
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
512
+ # Instead, we can manually construct the padding tensor.
513
+ padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
514
+ padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
515
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
516
+ else:
517
+ # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
518
+ # we want to instead pad by (0, remaining_length), where remaining_length is:
519
+ # remaining_length: int = target_length - current_length
520
+ # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
521
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
522
+
523
+ if out_dim == 3:
524
+ if attention_mask.shape[0] < batch_size * head_size:
525
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
526
+ elif out_dim == 4:
527
+ attention_mask = attention_mask.unsqueeze(1)
528
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
529
+
530
+ return attention_mask
531
+
532
+ def norm_encoder_hidden_states(self, encoder_hidden_states):
533
+ assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
534
+
535
+ if isinstance(self.norm_cross, nn.LayerNorm):
536
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
537
+ elif isinstance(self.norm_cross, nn.GroupNorm):
538
+ # Group norm norms along the channels dimension and expects
539
+ # input to be in the shape of (N, C, *). In this case, we want
540
+ # to norm along the hidden dimension, so we need to move
541
+ # (batch_size, sequence_length, hidden_size) ->
542
+ # (batch_size, hidden_size, sequence_length)
543
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
544
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
545
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
546
+ else:
547
+ assert False
548
+
549
+ return encoder_hidden_states
550
+
551
+
552
+ class TemporalConditionalAttention(Attention):
553
+ def __init__(self, n_frames=8, rotary_emb=False, *args, **kwargs):
554
+ super().__init__(processor=RotaryEmbAttnProcessor2_0() if rotary_emb else None, *args, **kwargs)
555
+
556
+ if not rotary_emb:
557
+ self.pos_enc = PositionalEncoding(self.inner_dim)
558
+ else:
559
+ rotary_bias = RelativePositionBias(heads=kwargs['heads'], max_distance=32)
560
+ self.rotary_bias = rotary_bias
561
+ self.rotary_emb = RotaryEmbedding(self.inner_dim // 2)
562
+
563
+ self.use_rotary_emb = rotary_emb
564
+ self.n_frames = n_frames
565
+
566
+ def forward(
567
+ self,
568
+ hidden_states,
569
+ encoder_hidden_states=None,
570
+ attention_mask=None,
571
+ adjacent_slices=None,
572
+ **cross_attention_kwargs):
573
+
574
+ key_pos_idx = None
575
+
576
+ bt, hw, c = hidden_states.shape
577
+ hidden_states = rearrange(hidden_states, '(b t) hw c -> b hw t c', t=self.n_frames)
578
+ if not self.use_rotary_emb:
579
+ pos_embed = self.pos_enc(self.n_frames)
580
+ hidden_states = hidden_states + pos_embed
581
+ hidden_states = rearrange(hidden_states, 'b hw t c -> (b hw) t c')
582
+
583
+ if encoder_hidden_states is not None:
584
+ assert adjacent_slices is None
585
+ encoder_hidden_states = encoder_hidden_states[::self.n_frames]
586
+ encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b hw) n c', hw=hw)
587
+
588
+ if adjacent_slices is not None:
589
+ assert encoder_hidden_states is None
590
+ adjacent_slices = rearrange(adjacent_slices, 'b c h w n -> b (h w) n c')
591
+ if not self.use_rotary_emb:
592
+ first_frame_pos_embed = pos_embed[0:1, :]
593
+ adjacent_slices = adjacent_slices + first_frame_pos_embed
594
+ else:
595
+ pos_idx = torch.arange(self.n_frames, device=hidden_states.device, dtype=hidden_states.dtype)
596
+ first_frame_pos_pad = torch.zeros(adjacent_slices.shape[2], device=hidden_states.device, dtype=hidden_states.dtype)
597
+ key_pos_idx = torch.cat([pos_idx, first_frame_pos_pad], dim=0)
598
+ adjacent_slices = rearrange(adjacent_slices, 'b hw n c -> (b hw) n c')
599
+ encoder_hidden_states = torch.cat([hidden_states, adjacent_slices], dim=1)
600
+
601
+ if not self.use_rotary_emb:
602
+ out = self.processor(
603
+ self,
604
+ hidden_states,
605
+ encoder_hidden_states=encoder_hidden_states,
606
+ attention_mask=attention_mask,
607
+ **cross_attention_kwargs,
608
+ )
609
+ else:
610
+ out = self.processor(
611
+ self,
612
+ hidden_states,
613
+ encoder_hidden_states=encoder_hidden_states,
614
+ attention_mask=attention_mask,
615
+ key_pos_idx=key_pos_idx,
616
+ **cross_attention_kwargs,
617
+ )
618
+
619
+ out = rearrange(out, '(b hw) t c -> (b t) hw c', hw=hw)
620
+
621
+ return out
622
+
623
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers, attention_op=None):
624
+ if use_memory_efficient_attention_xformers:
625
+ try:
626
+ # Make sure we can run the memory efficient attention
627
+ _ = xformers.ops.memory_efficient_attention(
628
+ torch.randn((1, 2, 40), device="cuda"),
629
+ torch.randn((1, 2, 40), device="cuda"),
630
+ torch.randn((1, 2, 40), device="cuda"),
631
+ )
632
+ except Exception as e:
633
+ raise e
634
+ processor = XFormersAttnProcessor(attention_op=attention_op)
635
+ else:
636
+ processor = (
637
+ AttnProcessor2_0()
638
+ if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
639
+ else AttnProcessor()
640
+ )
641
+ self.set_processor(processor)
642
+
643
+
644
+ class PositionalEncoding(nn.Module):
645
+ def __init__(self, dim, max_pos=512):
646
+ super().__init__()
647
+
648
+ pos = torch.arange(max_pos)
649
+
650
+ freq = torch.arange(dim//2) / dim
651
+ freq = (freq * torch.tensor(10000).log()).exp()
652
+
653
+ x = rearrange(pos, 'L -> L 1') / freq
654
+ x = rearrange(x, 'L d -> L d 1')
655
+
656
+ pe = torch.cat((x.sin(), x.cos()), dim=-1)
657
+ self.pe = rearrange(pe, 'L d sc -> L (d sc)')
658
+
659
+ self.dummy = nn.Parameter(torch.rand(1))
660
+
661
+ def forward(self, length):
662
+ enc = self.pe[:length]
663
+ enc = enc.to(self.dummy.device, self.dummy.dtype)
664
+ return enc
665
+
666
+
667
+ # code taken from https://github.com/Vchitect/LaVie/blob/main/base/models/temporal_attention.py
668
+ class RelativePositionBias(nn.Module):
669
+ def __init__(
670
+ self,
671
+ heads=8,
672
+ num_buckets=32,
673
+ max_distance=128,
674
+ ):
675
+ super().__init__()
676
+ self.num_buckets = num_buckets
677
+ self.max_distance = max_distance
678
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
679
+
680
+ @staticmethod
681
+ def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
682
+ ret = 0
683
+ n = -relative_position
684
+
685
+ num_buckets //= 2
686
+ ret += (n < 0).long() * num_buckets
687
+ n = torch.abs(n)
688
+
689
+ max_exact = num_buckets // 2
690
+ is_small = n < max_exact
691
+
692
+ val_if_large = max_exact + (
693
+ torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
694
+ ).long()
695
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
696
+
697
+ ret += torch.where(is_small, n, val_if_large)
698
+ return ret
699
+
700
+ def forward(self, qlen, klen, device, dtype):
701
+ q_pos = torch.arange(qlen, dtype = torch.long, device = device)
702
+ k_pos = torch.arange(klen, dtype = torch.long, device = device)
703
+ rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
704
+ rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
705
+ values = self.relative_attention_bias(rp_bucket)
706
+ values = values.to(device, dtype)
707
+ return rearrange(values, 'i j h -> h i j') # num_heads, num_frames, num_frames
708
+
709
+
710
+ class RotaryEmbAttnProcessor2_0:
711
+ r"""
712
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
713
+ Add rotary embedding support
714
+ """
715
+
716
+ def __init__(self):
717
+
718
+ if not hasattr(F, "scaled_dot_product_attention"):
719
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
720
+
721
+ def __call__(
722
+ self,
723
+ attn: Attention,
724
+ hidden_states,
725
+ encoder_hidden_states=None,
726
+ attention_mask=None,
727
+ temb=None,
728
+ scale: float = 1.0,
729
+ key_pos_idx: Optional[torch.Tensor] = None,
730
+ ):
731
+ assert attention_mask is None
732
+ residual = hidden_states
733
+
734
+ if attn.spatial_norm is not None:
735
+ hidden_states = attn.spatial_norm(hidden_states, temb)
736
+
737
+ input_ndim = hidden_states.ndim
738
+
739
+ if input_ndim == 4:
740
+ batch_size, channel, height, width = hidden_states.shape
741
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
742
+
743
+ batch_size, sequence_length, _ = (
744
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
745
+ )
746
+
747
+ # if attention_mask is not None:
748
+ # attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
749
+ # # scaled_dot_product_attention expects attention_mask shape to be
750
+ # # (batch, heads, source_length, target_length)
751
+ # attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
752
+
753
+ if attn.group_norm is not None:
754
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
755
+
756
+ query = attn.to_q(hidden_states, scale=scale)
757
+
758
+ if encoder_hidden_states is None:
759
+ encoder_hidden_states = hidden_states
760
+ elif attn.norm_cross:
761
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
762
+
763
+ qlen = hidden_states.shape[1]
764
+ klen = encoder_hidden_states.shape[1]
765
+ # currently we only add a bias for self-attention; relative distance doesn't make sense for cross-attention.
766
+ # if qlen == klen:
767
+ # time_rel_pos_bias = attn.rotary_bias(qlen, klen, device=hidden_states.device, dtype=hidden_states.dtype)
768
+ # attention_mask = repeat(time_rel_pos_bias, "h d1 d2 -> b h d1 d2", b=batch_size)
769
+
770
+ key = attn.to_k(encoder_hidden_states, scale=scale)
771
+ value = attn.to_v(encoder_hidden_states, scale=scale)
772
+
773
+ query = attn.rotary_emb.rotate_queries_or_keys(query)
774
+ if qlen == klen:
775
+ key = attn.rotary_emb.rotate_queries_or_keys(key)
776
+ elif key_pos_idx is not None:
777
+ key = attn.rotary_emb.rotate_queries_or_keys(key, seq_pos=key_pos_idx)
778
+
779
+ inner_dim = key.shape[-1]
780
+ head_dim = inner_dim // attn.heads
781
+
782
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
783
+
784
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
785
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
786
+
787
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
788
+ # TODO: add support for attn.scale when we move to Torch 2.1
789
+ hidden_states = F.scaled_dot_product_attention(
790
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
791
+ )
792
+
793
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
794
+ hidden_states = hidden_states.to(query.dtype)
795
+
796
+ # linear proj
797
+ hidden_states = attn.to_out[0](hidden_states, scale=scale)
798
+ # dropout
799
+ hidden_states = attn.to_out[1](hidden_states)
800
+
801
+ if input_ndim == 4:
802
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
803
+
804
+ if attn.residual_connection:
805
+ hidden_states = hidden_states + residual
806
+
807
+ hidden_states = hidden_states / attn.rescale_output_factor
808
+
809
+ return hidden_states
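A hedged sketch of wiring `RotaryEmbAttnProcessor2_0` into a diffusers `Attention` layer. The `rotary_emb` attribute is an assumption here: the processor expects the attention module to carry an object with a `rotate_queries_or_keys` method (the commented-out code later in this upload points at `rotary_embedding_torch.RotaryEmbedding`), and the projection layers are assumed to be the LoRA-compatible linears used by the pinned diffusers version.

import torch
from diffusers.models.attention_processor import Attention
from rotary_embedding_torch import RotaryEmbedding  # assumed stand-in for attn.rotary_emb

attn = Attention(query_dim=320, heads=8, dim_head=40)
attn.rotary_emb = RotaryEmbedding(32)               # exposes rotate_queries_or_keys()
attn.set_processor(RotaryEmbAttnProcessor2_0())

hidden = torch.randn(2, 16, 320)                    # (batch * spatial, frames, channels) in the temporal layers
out = attn(hidden)                                  # self-attention path: queries and keys are both rotated
print(out.shape)                                    # torch.Size([2, 16, 320])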
src/videogen_hub/pipelines/consisti2v/consisti2v/models/videoldm_transformer_blocks.py ADDED
@@ -0,0 +1,564 @@
1
+ # Modified from https://github.com/huggingface/diffusers/blob/v0.21.0/src/diffusers/models/transformer_2d.py
2
+ from dataclasses import dataclass
3
+ from typing import Any, Dict, Optional
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+
9
+ from einops import rearrange, repeat
10
+
11
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
12
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
13
+ from diffusers.utils import BaseOutput, deprecate
14
+ from diffusers.models.attention import AdaLayerNorm, AdaLayerNormZero, FeedForward, GatedSelfAttentionDense
15
+ from diffusers.models.embeddings import PatchEmbed
16
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
17
+ from diffusers.models.modeling_utils import ModelMixin
18
+ from diffusers.models.transformer_2d import Transformer2DModelOutput
19
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
20
+ from diffusers.models.attention_processor import Attention
22
+
23
+ from .videoldm_attention import ConditionalAttention, TemporalConditionalAttention
24
+
25
+
26
+ class Transformer2DConditionModel(ModelMixin, ConfigMixin):
27
+ @register_to_config
28
+ def __init__(
29
+ self,
30
+ num_attention_heads: int = 16,
31
+ attention_head_dim: int = 88,
32
+ in_channels: Optional[int] = None,
33
+ out_channels: Optional[int] = None,
34
+ num_layers: int = 1,
35
+ dropout: float = 0.0,
36
+ norm_num_groups: int = 32,
37
+ cross_attention_dim: Optional[int] = None,
38
+ attention_bias: bool = False,
39
+ sample_size: Optional[int] = None,
40
+ num_vector_embeds: Optional[int] = None,
41
+ patch_size: Optional[int] = None,
42
+ activation_fn: str = "geglu",
43
+ num_embeds_ada_norm: Optional[int] = None,
44
+ use_linear_projection: bool = False,
45
+ only_cross_attention: bool = False,
46
+ double_self_attention: bool = False,
47
+ upcast_attention: bool = False,
48
+ norm_type: str = "layer_norm",
49
+ norm_elementwise_affine: bool = True,
50
+ attention_type: str = "default",
51
+ # additional
52
+ n_frames: int = 8,
53
+ is_temporal: bool = False,
54
+ augment_temporal_attention: bool = False,
55
+ rotary_emb=False,
56
+ ):
57
+ super().__init__()
58
+ self.use_linear_projection = use_linear_projection
59
+ self.num_attention_heads = num_attention_heads
60
+ self.attention_head_dim = attention_head_dim
61
+ inner_dim = num_attention_heads * attention_head_dim
62
+
63
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
64
+ # Define whether input is continuous or discrete depending on configuration
65
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
66
+ self.is_input_vectorized = num_vector_embeds is not None
67
+ self.is_input_patches = in_channels is not None and patch_size is not None
68
+
69
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
70
+ deprecation_message = (
71
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
72
+ " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
73
+ " Please make sure to update the config accordingly as leaving `norm_type` might lead to incorrect"
74
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
75
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
76
+ )
77
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
78
+ norm_type = "ada_norm"
79
+
80
+ if self.is_input_continuous and self.is_input_vectorized:
81
+ raise ValueError(
82
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
83
+ " sure that either `in_channels` or `num_vector_embeds` is None."
84
+ )
85
+ elif self.is_input_vectorized and self.is_input_patches:
86
+ raise ValueError(
87
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
88
+ " sure that either `num_vector_embeds` or `num_patches` is None."
89
+ )
90
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
91
+ raise ValueError(
92
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
93
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
94
+ )
95
+
96
+ # 2. Define input layers
97
+ if self.is_input_continuous:
98
+ self.in_channels = in_channels
99
+
100
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
101
+ if use_linear_projection:
102
+ self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
103
+ else:
104
+ self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
105
+ elif self.is_input_vectorized:
106
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
107
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
108
+
109
+ self.height = sample_size
110
+ self.width = sample_size
111
+ self.num_vector_embeds = num_vector_embeds
112
+ self.num_latent_pixels = self.height * self.width
113
+
114
+ self.latent_image_embedding = ImagePositionalEmbeddings(
115
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
116
+ )
117
+ elif self.is_input_patches:
118
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
119
+
120
+ self.height = sample_size
121
+ self.width = sample_size
122
+
123
+ self.patch_size = patch_size
124
+ self.pos_embed = PatchEmbed(
125
+ height=sample_size,
126
+ width=sample_size,
127
+ patch_size=patch_size,
128
+ in_channels=in_channels,
129
+ embed_dim=inner_dim,
130
+ )
131
+
132
+ # 3. Define transformers blocks
133
+ self.transformer_blocks = nn.ModuleList(
134
+ [
135
+ BasicConditionalTransformerBlock(
136
+ inner_dim,
137
+ num_attention_heads,
138
+ attention_head_dim,
139
+ dropout=dropout,
140
+ cross_attention_dim=cross_attention_dim,
141
+ activation_fn=activation_fn,
142
+ num_embeds_ada_norm=num_embeds_ada_norm,
143
+ attention_bias=attention_bias,
144
+ only_cross_attention=only_cross_attention,
145
+ double_self_attention=double_self_attention,
146
+ upcast_attention=upcast_attention,
147
+ norm_type=norm_type,
148
+ norm_elementwise_affine=norm_elementwise_affine,
149
+ attention_type=attention_type,
150
+ # additional
151
+ n_frames=n_frames,
152
+ is_temporal=is_temporal,
153
+ augment_temporal_attention=augment_temporal_attention,
154
+ rotary_emb=rotary_emb,
155
+ )
156
+ for d in range(num_layers)
157
+ ]
158
+ )
159
+
160
+ # 4. Define output layers
161
+ self.out_channels = in_channels if out_channels is None else out_channels
162
+ if self.is_input_continuous:
163
+ # TODO: should use out_channels for continuous projections
164
+ if use_linear_projection:
165
+ self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
166
+ else:
167
+ self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
168
+ elif self.is_input_vectorized:
169
+ self.norm_out = nn.LayerNorm(inner_dim)
170
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
171
+ elif self.is_input_patches:
172
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
173
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
174
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
175
+
176
+ self.alpha = None
177
+ if is_temporal:
178
+ self.alpha = nn.Parameter(torch.ones(1))
179
+
180
+ self.gradient_checkpointing = False
181
+
182
+ def forward(
183
+ self,
184
+ hidden_states: torch.Tensor,
185
+ encoder_hidden_states: Optional[torch.Tensor] = None,
186
+ timestep: Optional[torch.LongTensor] = None,
187
+ class_labels: Optional[torch.LongTensor] = None,
188
+ cross_attention_kwargs: Dict[str, Any] = None,
189
+ attention_mask: Optional[torch.Tensor] = None,
190
+ encoder_attention_mask: Optional[torch.Tensor] = None,
191
+ return_dict: bool = True,
192
+ condition_on_first_frame: bool = False,
193
+ ):
194
+ input_states = hidden_states
195
+ input_height, input_width = hidden_states.shape[-2:]
196
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
197
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
198
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
199
+ # expects mask of shape:
200
+ # [batch, key_tokens]
201
+ # adds singleton query_tokens dimension:
202
+ # [batch, 1, key_tokens]
203
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
204
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
205
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
206
+ if attention_mask is not None and attention_mask.ndim == 2:
207
+ # assume that mask is expressed as:
208
+ # (1 = keep, 0 = discard)
209
+ # convert mask into a bias that can be added to attention scores:
210
+ # (keep = +0, discard = -10000.0)
211
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
212
+ attention_mask = attention_mask.unsqueeze(1)
213
+
214
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
215
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
216
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
217
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
218
+
219
+ # Retrieve lora scale.
220
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
221
+
222
+ # 1. Input
223
+ if self.is_input_continuous:
224
+ batch, _, height, width = hidden_states.shape
225
+ residual = hidden_states
226
+
227
+ hidden_states = self.norm(hidden_states)
228
+ if not self.use_linear_projection:
229
+ hidden_states = self.proj_in(hidden_states, lora_scale)
230
+ inner_dim = hidden_states.shape[1]
231
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
232
+ else:
233
+ inner_dim = hidden_states.shape[1]
234
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
235
+ hidden_states = self.proj_in(hidden_states, scale=lora_scale)
236
+
237
+ elif self.is_input_vectorized:
238
+ hidden_states = self.latent_image_embedding(hidden_states)
239
+ elif self.is_input_patches:
240
+ hidden_states = self.pos_embed(hidden_states)
241
+
242
+ # 2. Blocks
243
+ for block in self.transformer_blocks:
244
+ if self.training and self.gradient_checkpointing:
245
+ hidden_states = torch.utils.checkpoint.checkpoint(
246
+ block,
247
+ hidden_states,
248
+ attention_mask,
249
+ encoder_hidden_states,
250
+ encoder_attention_mask,
251
+ timestep,
252
+ cross_attention_kwargs,
253
+ class_labels,
254
+ use_reentrant=False,
255
+ )
256
+ else:
257
+ hidden_states = block(
258
+ hidden_states,
259
+ attention_mask=attention_mask,
260
+ encoder_hidden_states=encoder_hidden_states,
261
+ encoder_attention_mask=encoder_attention_mask,
262
+ timestep=timestep,
263
+ cross_attention_kwargs=cross_attention_kwargs,
264
+ class_labels=class_labels,
265
+ # additional
266
+ condition_on_first_frame=condition_on_first_frame,
267
+ input_height=input_height,
268
+ input_width=input_width,
269
+ )
270
+
271
+ # 3. Output
272
+ if self.is_input_continuous:
273
+ if not self.use_linear_projection:
274
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
275
+ hidden_states = self.proj_out(hidden_states, scale=lora_scale)
276
+ else:
277
+ hidden_states = self.proj_out(hidden_states, scale=lora_scale)
278
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
279
+
280
+ output = hidden_states + residual
281
+ elif self.is_input_vectorized:
282
+ hidden_states = self.norm_out(hidden_states)
283
+ logits = self.out(hidden_states)
284
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
285
+ logits = logits.permute(0, 2, 1)
286
+
287
+ # log(p(x_0))
288
+ output = F.log_softmax(logits.double(), dim=1).float()
289
+ elif self.is_input_patches:
290
+ # TODO: cleanup!
291
+ conditioning = self.transformer_blocks[0].norm1.emb(
292
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
293
+ )
294
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
295
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
296
+ hidden_states = self.proj_out_2(hidden_states)
297
+
298
+ # unpatchify
299
+ height = width = int(hidden_states.shape[1] ** 0.5)
300
+ hidden_states = hidden_states.reshape(
301
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
302
+ )
303
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
304
+ output = hidden_states.reshape(
305
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
306
+ )
307
+
308
+ if self.alpha is not None:
309
+ with torch.no_grad():
310
+ self.alpha.clamp_(0, 1)
311
+
312
+ output = self.alpha * input_states + (1 - self.alpha) * output
313
+
314
+ if not return_dict:
315
+ return (output,)
316
+
317
+ return Transformer2DModelOutput(sample=output)
318
+
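A rough instantiation sketch for the temporal variant of this transformer (parameter values are illustrative, not the checkpoint's configuration). The notable design choice is the learnable `alpha` gate, clamped to [0, 1] at every forward pass: with `alpha` initialised to 1 the block starts out as an identity mapping over the pretrained spatial features and only gradually lets the temporal branch contribute.

# Illustrative configuration; the real values come from the UNet blocks.
temporal_transformer = Transformer2DConditionModel(
    num_attention_heads=8,
    attention_head_dim=40,
    in_channels=320,
    cross_attention_dim=768,
    n_frames=8,
    is_temporal=True,
    rotary_emb=True,
)
# In this codebase the batch axis of the input carries batch * frames, i.e. a
# tensor of shape (b * n_frames, in_channels, height, width); the temporal
# output is alpha * input + (1 - alpha) * transformer(input).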
319
+
320
+ @maybe_allow_in_graph
321
+ class BasicConditionalTransformerBlock(nn.Module):
322
+ """Transformer block with first-frame conditioning."""
323
+ def __init__(
324
+ self,
325
+ dim: int,
326
+ num_attention_heads: int,
327
+ attention_head_dim: int,
328
+ dropout=0.0,
329
+ cross_attention_dim: Optional[int] = None,
330
+ activation_fn: str = "geglu",
331
+ num_embeds_ada_norm: Optional[int] = None,
332
+ attention_bias: bool = False,
333
+ only_cross_attention: bool = False,
334
+ double_self_attention: bool = False,
335
+ upcast_attention: bool = False,
336
+ norm_elementwise_affine: bool = True,
337
+ norm_type: str = "layer_norm",
338
+ final_dropout: bool = False,
339
+ attention_type: str = "default",
340
+ # additional
341
+ n_frames: int = 8,
342
+ is_temporal: bool = False,
343
+ augment_temporal_attention: bool = False,
344
+ rotary_emb=False,
345
+ ):
346
+ super().__init__()
347
+ self.n_frames = n_frames
348
+ self.only_cross_attention = only_cross_attention
349
+ self.augment_temporal_attention = augment_temporal_attention
350
+ self.is_temporal = is_temporal
351
+
352
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
353
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
354
+
355
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
356
+ raise ValueError(
357
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
358
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
359
+ )
360
+
361
+ # Define 3 blocks. Each block has its own normalization layer.
362
+ # 1. Self-Attn
363
+ if self.use_ada_layer_norm:
364
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
365
+ elif self.use_ada_layer_norm_zero:
366
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
367
+ else:
368
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
369
+
370
+ if not is_temporal:
371
+ self.attn1 = ConditionalAttention(
372
+ query_dim=dim,
373
+ heads=num_attention_heads,
374
+ dim_head=attention_head_dim,
375
+ dropout=dropout,
376
+ bias=attention_bias,
377
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
378
+ upcast_attention=upcast_attention,
379
+ )
380
+ else:
381
+ self.attn1 = TemporalConditionalAttention(
382
+ query_dim=dim,
383
+ heads=num_attention_heads,
384
+ dim_head=attention_head_dim,
385
+ dropout=dropout,
386
+ bias=attention_bias,
387
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
388
+ upcast_attention=upcast_attention,
389
+ # additional
390
+ n_frames=n_frames,
391
+ rotary_emb=rotary_emb,
392
+ )
393
+
394
+ # 2. Cross-Attn
395
+ if cross_attention_dim is not None or double_self_attention:
396
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
397
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
398
+ # the second cross attention block.
399
+ self.norm2 = (
400
+ AdaLayerNorm(dim, num_embeds_ada_norm)
401
+ if self.use_ada_layer_norm
402
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
403
+ )
404
+ if not is_temporal:
405
+ self.attn2 = ConditionalAttention(
406
+ query_dim=dim,
407
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
408
+ heads=num_attention_heads,
409
+ dim_head=attention_head_dim,
410
+ dropout=dropout,
411
+ bias=attention_bias,
412
+ upcast_attention=upcast_attention,
413
+ ) # is self-attn if encoder_hidden_states is none
414
+ else:
415
+ self.attn2 = TemporalConditionalAttention(
416
+ query_dim=dim,
417
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
418
+ heads=num_attention_heads,
419
+ dim_head=attention_head_dim,
420
+ dropout=dropout,
421
+ bias=attention_bias,
422
+ upcast_attention=upcast_attention,
423
+ # additional
424
+ n_frames=n_frames,
425
+ rotary_emb=rotary_emb,
426
+ )
427
+ else:
428
+ self.norm2 = None
429
+ self.attn2 = None
430
+
431
+ # 3. Feed-forward
432
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
433
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
434
+
435
+ # 4. Fuser
436
+ if attention_type == "gated" or attention_type == "gated-text-image":
437
+ self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
438
+
439
+ # let chunk size default to None
440
+ self._chunk_size = None
441
+ self._chunk_dim = 0
442
+
443
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
444
+ # Sets chunk feed-forward
445
+ self._chunk_size = chunk_size
446
+ self._chunk_dim = dim
447
+
448
+ def forward(
449
+ self,
450
+ hidden_states: torch.FloatTensor,
451
+ attention_mask: Optional[torch.FloatTensor] = None,
452
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
453
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
454
+ timestep: Optional[torch.LongTensor] = None,
455
+ cross_attention_kwargs: Dict[str, Any] = None,
456
+ class_labels: Optional[torch.LongTensor] = None,
457
+ condition_on_first_frame: bool = False,
458
+ input_height: Optional[int] = None,
459
+ input_width: Optional[int] = None,
460
+ ):
461
+ # Notice that normalization is always applied before the real computation in the following blocks.
462
+ # 0. Self-Attention
463
+ if self.use_ada_layer_norm:
464
+ norm_hidden_states = self.norm1(hidden_states, timestep)
465
+ elif self.use_ada_layer_norm_zero:
466
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
467
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
468
+ )
469
+ else:
470
+ norm_hidden_states = self.norm1(hidden_states)
471
+
472
+ # 1. Retrieve lora scale.
473
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
474
+
475
+ # 2. Prepare GLIGEN inputs
476
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
477
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
478
+
479
+ if condition_on_first_frame:
480
+ first_frame_hidden_states = rearrange(norm_hidden_states, '(b f) d h -> b f d h', f=self.n_frames)[:, 0, :, :]
481
+ first_frame_hidden_states = repeat(first_frame_hidden_states, 'b d h -> b f d h', f=self.n_frames)
482
+ first_frame_hidden_states = rearrange(first_frame_hidden_states, 'b f d h -> (b f) d h')
483
+ first_frame_concat_hidden_states = torch.cat((norm_hidden_states, first_frame_hidden_states), dim=1)
484
+ attn_output = self.attn1(
485
+ norm_hidden_states,
486
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else first_frame_concat_hidden_states,
487
+ attention_mask=attention_mask,
488
+ **cross_attention_kwargs,
489
+ )
490
+ elif self.is_temporal and self.augment_temporal_attention:
491
+ first_frame_hidden_states = rearrange(norm_hidden_states, '(b f) d h -> b f d h', f=self.n_frames)[:, 0, :, :]
492
+ first_frame_hidden_states = rearrange(first_frame_hidden_states, 'b (h w) c -> b h w c', h=input_height, w=input_width)
493
+ first_frame_hidden_states = first_frame_hidden_states.permute(0, 3, 1, 2)
494
+ padded_first_frame = torch.nn.functional.pad(first_frame_hidden_states, (1, 1, 1, 1), "replicate")
495
+ first_frame_windows = padded_first_frame.unfold(2, 3, 1).unfold(3, 3, 1)
496
+ mask = torch.tensor([[1, 1, 1], [1, 0, 1], [1, 1, 1]], dtype=torch.bool)
497
+ adjacent_slices = first_frame_windows[:, :, :, :, mask]
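+ # adjacent_slices: for every spatial location, the 8 neighbouring first-frame
+ # features (its 3x3 window with the centre masked out); they are passed to
+ # attn1 below to augment temporal attention with local first-frame context.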
498
+ attn_output = self.attn1(
499
+ norm_hidden_states,
500
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
501
+ attention_mask=attention_mask,
502
+ adjacent_slices=adjacent_slices,
503
+ **cross_attention_kwargs,
504
+ )
505
+ else:
506
+ attn_output = self.attn1(
507
+ norm_hidden_states,
508
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
509
+ attention_mask=attention_mask,
510
+ **cross_attention_kwargs,
511
+ )
512
+ if self.use_ada_layer_norm_zero:
513
+ attn_output = gate_msa.unsqueeze(1) * attn_output
514
+ hidden_states = attn_output + hidden_states
515
+
516
+ # 2.5 GLIGEN Control
517
+ if gligen_kwargs is not None:
518
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
519
+ # 2.5 ends
520
+
521
+ # 3. Cross-Attention
522
+ if self.attn2 is not None:
523
+ norm_hidden_states = (
524
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
525
+ )
526
+
527
+ attn_output = self.attn2(
528
+ norm_hidden_states,
529
+ encoder_hidden_states=encoder_hidden_states,
530
+ attention_mask=encoder_attention_mask,
531
+ **cross_attention_kwargs,
532
+ )
533
+ hidden_states = attn_output + hidden_states
534
+
535
+ # 4. Feed-forward
536
+ norm_hidden_states = self.norm3(hidden_states)
537
+
538
+ if self.use_ada_layer_norm_zero:
539
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
540
+
541
+ if self._chunk_size is not None:
542
+ # "feed_forward_chunk_size" can be used to save memory
543
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
544
+ raise ValueError(
545
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
546
+ )
547
+
548
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
549
+ ff_output = torch.cat(
550
+ [
551
+ self.ff(hid_slice, scale=lora_scale)
552
+ for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
553
+ ],
554
+ dim=self._chunk_dim,
555
+ )
556
+ else:
557
+ ff_output = self.ff(norm_hidden_states, scale=lora_scale)
558
+
559
+ if self.use_ada_layer_norm_zero:
560
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
561
+
562
+ hidden_states = ff_output + hidden_states
563
+
564
+ return hidden_states
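To make the first-frame conditioning branch above concrete, here is a standalone sketch of its rearrange/repeat/concat pattern (shapes are illustrative):

import torch
from einops import rearrange, repeat

b, f, d, h = 2, 8, 64, 320                  # batch, frames, spatial tokens, hidden size
norm_hidden_states = torch.randn(b * f, d, h)

first = rearrange(norm_hidden_states, '(b f) d h -> b f d h', f=f)[:, 0]  # tokens of frame 0
first = repeat(first, 'b d h -> b f d h', f=f)                            # broadcast to every frame
first = rearrange(first, 'b f d h -> (b f) d h')

kv = torch.cat((norm_hidden_states, first), dim=1)   # each frame attends over itself + frame 0
print(kv.shape)                                      # torch.Size([16, 128, 320])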
src/videogen_hub/pipelines/consisti2v/consisti2v/models/videoldm_unet.py ADDED
@@ -0,0 +1,1371 @@
1
+ import os
2
+ import re
3
+ from typing import Optional, Tuple, Union, Dict, List, Any
4
+ from einops import rearrange, repeat
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from diffusers.loaders import UNet2DConditionLoadersMixin
9
+ from diffusers.models import ModelMixin
10
+ from diffusers.models.unet_2d_condition import UNet2DConditionOutput
11
+ from diffusers.models.unet_2d_blocks import UNetMidBlock2DCrossAttn, UNetMidBlock2DSimpleCrossAttn
12
+ from diffusers.models.embeddings import (
13
+ GaussianFourierProjection,
14
+ ImageHintTimeEmbedding,
15
+ ImageProjection,
16
+ ImageTimeEmbedding,
17
+ PositionNet,
18
+ TextImageProjection,
19
+ TextImageTimeEmbedding,
20
+ TextTimeEmbedding,
21
+ TimestepEmbedding,
22
+ Timesteps,
23
+ )
24
+ from diffusers.models.attention_processor import (
25
+ ADDED_KV_ATTENTION_PROCESSORS,
26
+ CROSS_ATTENTION_PROCESSORS,
27
+ AttentionProcessor,
28
+ AttnAddedKVProcessor,
29
+ AttnProcessor,
30
+ )
31
+ from diffusers.models.activations import get_activation
32
+ from diffusers.configuration_utils import register_to_config, ConfigMixin
33
+ from diffusers.models.modeling_utils import load_state_dict, load_model_dict_into_meta
34
+ from diffusers.utils import (
35
+ CONFIG_NAME,
36
+ DIFFUSERS_CACHE,
37
+ FLAX_WEIGHTS_NAME,
38
+ HF_HUB_OFFLINE,
39
+ SAFETENSORS_WEIGHTS_NAME,
40
+ WEIGHTS_NAME,
41
+ _add_variant,
42
+ _get_model_file,
43
+ deprecate,
44
+ is_accelerate_available,
45
+ is_torch_version,
46
+ logging,
47
+ )
48
+ from diffusers import __version__
49
+
50
+ if is_torch_version(">=", "1.9.0"):
51
+ _LOW_CPU_MEM_USAGE_DEFAULT = True
52
+ else:
53
+ _LOW_CPU_MEM_USAGE_DEFAULT = False
54
+
55
+
56
+ if is_accelerate_available():
57
+ import accelerate
58
+ from accelerate.utils import set_module_tensor_to_device
59
+ from accelerate.utils.versions import is_torch_version
60
+
61
+
62
+
63
+ from .videoldm_unet_blocks import get_down_block, get_up_block, VideoLDMUNetMidBlock2DCrossAttn
64
+
65
+ logger = logging.get_logger(__name__)
66
+
67
+
68
+ class VideoLDMUNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
69
+ _supports_gradient_checkpointing = True
70
+ @register_to_config
71
+ def __init__(
72
+ self,
73
+ sample_size: Optional[int] = None,
74
+ in_channels: int = 4,
75
+ out_channels: int = 4,
76
+ center_input_sample: bool = False,
77
+ flip_sin_to_cos: bool = True,
78
+ freq_shift: int = 0,
79
+ down_block_types: Tuple[str] = (
80
+ "CrossAttnDownBlock2D", # -> VideoLDMDownBlock
81
+ "CrossAttnDownBlock2D", # -> VideoLDMDownBlock
82
+ "CrossAttnDownBlock2D", # -> VideoLDMDownBlock
83
+ "DownBlock2D",
84
+ ),
85
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
86
+ up_block_types: Tuple[str] = (
87
+ "UpBlock2D",
88
+ "CrossAttnUpBlock2D", # -> VideoLDMUpBlock
89
+ "CrossAttnUpBlock2D", # -> VideoLDMUpBlock
90
+ "CrossAttnUpBlock2D", # -> VideoLDMUpBlock
91
+ ),
92
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
93
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
94
+ layers_per_block: Union[int, Tuple[int]] = 2,
95
+ downsample_padding: int = 1,
96
+ mid_block_scale_factor: float = 1,
97
+ dropout: float = 0.0,
98
+ act_fn: str = "silu",
99
+ norm_num_groups: Optional[int] = 32,
100
+ norm_eps: float = 1e-5,
101
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
102
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
103
+ encoder_hid_dim: Optional[int] = None,
104
+ encoder_hid_dim_type: Optional[str] = None,
105
+ attention_head_dim: Union[int, Tuple[int]] = 8,
106
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
107
+ dual_cross_attention: bool = False,
108
+ use_linear_projection: bool = False,
109
+ class_embed_type: Optional[str] = None,
110
+ addition_embed_type: Optional[str] = None,
111
+ addition_time_embed_dim: Optional[int] = None,
112
+ num_class_embeds: Optional[int] = None,
113
+ upcast_attention: bool = False,
114
+ resnet_time_scale_shift: str = "default",
115
+ resnet_skip_time_act: bool = False,
116
+ resnet_out_scale_factor: int = 1.0,
117
+ time_embedding_type: str = "positional",
118
+ time_embedding_dim: Optional[int] = None,
119
+ time_embedding_act_fn: Optional[str] = None,
120
+ timestep_post_act: Optional[str] = None,
121
+ time_cond_proj_dim: Optional[int] = None,
122
+ conv_in_kernel: int = 3,
123
+ conv_out_kernel: int = 3,
124
+ projection_class_embeddings_input_dim: Optional[int] = None,
125
+ attention_type: str = "default",
126
+ class_embeddings_concat: bool = False,
127
+ mid_block_only_cross_attention: Optional[bool] = None,
128
+ cross_attention_norm: Optional[str] = None,
129
+ addition_embed_type_num_heads=64,
130
+ # additional
131
+ use_temporal: bool = True,
132
+ n_frames: int = 8,
133
+ n_temp_heads: int = 8,
134
+ first_frame_condition_mode: str = "none",
135
+ augment_temporal_attention: bool = False,
136
+ temp_pos_embedding: str = "sinusoidal",
137
+ use_frame_stride_condition: bool = False,
138
+ ):
139
+ super().__init__()
140
+
141
+ rotary_emb = False
142
+ if temp_pos_embedding == "rotary":
143
+ # from rotary_embedding_torch import RotaryEmbedding
144
+ # rotary_emb = RotaryEmbedding(32)
145
+ # self.rotary_emb = rotary_emb
146
+ rotary_emb = True
147
+ self.rotary_emb = rotary_emb
148
+
149
+ self.use_temporal = use_temporal
150
+ self.augment_temporal_attention = augment_temporal_attention
151
+
152
+ assert first_frame_condition_mode in ["none", "concat", "conv2d", "input_only"], f"first_frame_condition_mode: {first_frame_condition_mode} must be one of ['none', 'concat', 'conv2d', 'input_only']"
153
+ self.first_frame_condition_mode = first_frame_condition_mode
154
+ latent_channels = in_channels
155
+
156
+ self.sample_size = sample_size
157
+
158
+ if num_attention_heads is not None:
159
+ raise ValueError(
160
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
161
+ )
162
+
163
+ num_attention_heads = num_attention_heads or attention_head_dim
164
+
165
+ # Check inputs
166
+ if len(down_block_types) != len(up_block_types):
167
+ raise ValueError(
168
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
169
+ )
170
+
171
+ if len(block_out_channels) != len(down_block_types):
172
+ raise ValueError(
173
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
174
+ )
175
+
176
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
177
+ raise ValueError(
178
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
179
+ )
180
+
181
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
182
+ raise ValueError(
183
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
184
+ )
185
+
186
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
187
+ raise ValueError(
188
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
189
+ )
190
+
191
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
192
+ raise ValueError(
193
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
194
+ )
195
+
196
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
197
+ raise ValueError(
198
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
199
+ )
200
+
201
+ # input
202
+ conv_in_padding = (conv_in_kernel - 1) // 2
203
+ self.conv_in = nn.Conv2d(
204
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
205
+ )
206
+
207
+ # time
208
+ if time_embedding_type == "fourier":
209
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
210
+ if time_embed_dim % 2 != 0:
211
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
212
+ self.time_proj = GaussianFourierProjection(
213
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
214
+ )
215
+ timestep_input_dim = time_embed_dim
216
+ elif time_embedding_type == "positional":
217
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
218
+
219
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
220
+ timestep_input_dim = block_out_channels[0]
221
+ else:
222
+ raise ValueError(
223
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
224
+ )
225
+
226
+ self.time_embedding = TimestepEmbedding(
227
+ timestep_input_dim,
228
+ time_embed_dim,
229
+ act_fn=act_fn,
230
+ post_act_fn=timestep_post_act,
231
+ cond_proj_dim=time_cond_proj_dim,
232
+ )
233
+
234
+ self.use_frame_stride_condition = use_frame_stride_condition
235
+ if self.use_frame_stride_condition:
236
+ self.frame_stride_embedding = TimestepEmbedding(
237
+ timestep_input_dim,
238
+ time_embed_dim,
239
+ act_fn=act_fn,
240
+ post_act_fn=timestep_post_act,
241
+ cond_proj_dim=time_cond_proj_dim,
242
+ )
243
+ # zero init
244
+ nn.init.zeros_(self.frame_stride_embedding.linear_2.weight)
245
+ nn.init.zeros_(self.frame_stride_embedding.linear_2.bias)
246
+
247
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
248
+ encoder_hid_dim_type = "text_proj"
249
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
250
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
251
+
252
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
253
+ raise ValueError(
254
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
255
+ )
256
+
257
+ if encoder_hid_dim_type == "text_proj":
258
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
259
+ elif encoder_hid_dim_type == "text_image_proj":
260
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
261
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
262
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
263
+ self.encoder_hid_proj = TextImageProjection(
264
+ text_embed_dim=encoder_hid_dim,
265
+ image_embed_dim=cross_attention_dim,
266
+ cross_attention_dim=cross_attention_dim,
267
+ )
268
+ elif encoder_hid_dim_type == "image_proj":
269
+ # Kandinsky 2.2
270
+ self.encoder_hid_proj = ImageProjection(
271
+ image_embed_dim=encoder_hid_dim,
272
+ cross_attention_dim=cross_attention_dim,
273
+ )
274
+ elif encoder_hid_dim_type is not None:
275
+ raise ValueError(
276
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj' or 'image_proj'."
277
+ )
278
+ else:
279
+ self.encoder_hid_proj = None
280
+
281
+ # class embedding
282
+ if class_embed_type is None and num_class_embeds is not None:
283
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
284
+ elif class_embed_type == "timestep":
285
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
286
+ elif class_embed_type == "identity":
287
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
288
+ elif class_embed_type == "projection":
289
+ if projection_class_embeddings_input_dim is None:
290
+ raise ValueError(
291
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
292
+ )
293
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
294
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
295
+ # 2. it projects from an arbitrary input dimension.
296
+ #
297
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
298
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
299
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
300
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
301
+ elif class_embed_type == "simple_projection":
302
+ if projection_class_embeddings_input_dim is None:
303
+ raise ValueError(
304
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
305
+ )
306
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
307
+ else:
308
+ self.class_embedding = None
309
+
310
+ if addition_embed_type == "text":
311
+ if encoder_hid_dim is not None:
312
+ text_time_embedding_from_dim = encoder_hid_dim
313
+ else:
314
+ text_time_embedding_from_dim = cross_attention_dim
315
+
316
+ self.add_embedding = TextTimeEmbedding(
317
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
318
+ )
319
+ elif addition_embed_type == "text_image":
320
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
321
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
322
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
323
+ self.add_embedding = TextImageTimeEmbedding(
324
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
325
+ )
326
+ elif addition_embed_type == "text_time":
327
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
328
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
329
+ elif addition_embed_type == "image":
330
+ # Kandinsky 2.2
331
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
332
+ elif addition_embed_type == "image_hint":
333
+ # Kandinsky 2.2 ControlNet
334
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
335
+ elif addition_embed_type is not None:
336
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image' or 'image_hint'.")
337
+
338
+ if time_embedding_act_fn is None:
339
+ self.time_embed_act = None
340
+ else:
341
+ self.time_embed_act = get_activation(time_embedding_act_fn)
342
+
343
+ self.down_blocks = nn.ModuleList([])
344
+ self.up_blocks = nn.ModuleList([])
345
+
346
+ if isinstance(only_cross_attention, bool):
347
+ if mid_block_only_cross_attention is None:
348
+ mid_block_only_cross_attention = only_cross_attention
349
+
350
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
351
+
352
+ if mid_block_only_cross_attention is None:
353
+ mid_block_only_cross_attention = False
354
+
355
+ if isinstance(num_attention_heads, int):
356
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
357
+
358
+ if isinstance(attention_head_dim, int):
359
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
360
+
361
+ if isinstance(cross_attention_dim, int):
362
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
363
+
364
+ if isinstance(layers_per_block, int):
365
+ layers_per_block = [layers_per_block] * len(down_block_types)
366
+
367
+ if isinstance(transformer_layers_per_block, int):
368
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
369
+
370
+ if class_embeddings_concat:
371
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
372
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
373
+ # regular time embeddings
374
+ blocks_time_embed_dim = time_embed_dim * 2
375
+ else:
376
+ blocks_time_embed_dim = time_embed_dim
377
+ # down
378
+ output_channel = block_out_channels[0]
379
+ for i, down_block_type in enumerate(down_block_types):
380
+ input_channel = output_channel
381
+ output_channel = block_out_channels[i]
382
+ is_final_block = i == len(block_out_channels) - 1
383
+
384
+ down_block = get_down_block(
385
+ down_block_type,
386
+ num_layers=layers_per_block[i],
387
+ transformer_layers_per_block=transformer_layers_per_block[i],
388
+ in_channels=input_channel,
389
+ out_channels=output_channel,
390
+ temb_channels=blocks_time_embed_dim,
391
+ add_downsample=not is_final_block,
392
+ resnet_eps=norm_eps,
393
+ resnet_act_fn=act_fn,
394
+ resnet_groups=norm_num_groups,
395
+ cross_attention_dim=cross_attention_dim[i],
396
+ num_attention_heads=num_attention_heads[i],
397
+ downsample_padding=downsample_padding,
398
+ dual_cross_attention=dual_cross_attention,
399
+ use_linear_projection=use_linear_projection,
400
+ only_cross_attention=only_cross_attention[i],
401
+ upcast_attention=upcast_attention,
402
+ resnet_time_scale_shift=resnet_time_scale_shift,
403
+ attention_type=attention_type,
404
+ resnet_skip_time_act=resnet_skip_time_act,
405
+ resnet_out_scale_factor=resnet_out_scale_factor,
406
+ cross_attention_norm=cross_attention_norm,
407
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
408
+ dropout=dropout,
409
+ # additional
410
+ use_temporal=use_temporal,
411
+ augment_temporal_attention=augment_temporal_attention,
412
+ n_frames=n_frames,
413
+ n_temp_heads=n_temp_heads,
414
+ first_frame_condition_mode=first_frame_condition_mode,
415
+ latent_channels=latent_channels,
416
+ rotary_emb=rotary_emb,
417
+ )
418
+ self.down_blocks.append(down_block)
419
+
420
+ # mid
421
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
422
+ self.mid_block = VideoLDMUNetMidBlock2DCrossAttn(
423
+ transformer_layers_per_block=transformer_layers_per_block[-1],
424
+ in_channels=block_out_channels[-1],
425
+ temb_channels=blocks_time_embed_dim,
426
+ dropout=dropout,
427
+ resnet_eps=norm_eps,
428
+ resnet_act_fn=act_fn,
429
+ output_scale_factor=mid_block_scale_factor,
430
+ resnet_time_scale_shift=resnet_time_scale_shift,
431
+ cross_attention_dim=cross_attention_dim[-1],
432
+ num_attention_heads=num_attention_heads[-1],
433
+ resnet_groups=norm_num_groups,
434
+ dual_cross_attention=dual_cross_attention,
435
+ use_linear_projection=use_linear_projection,
436
+ upcast_attention=upcast_attention,
437
+ attention_type=attention_type,
438
+ # additional
439
+ use_temporal=use_temporal,
440
+ n_frames=n_frames,
441
+ first_frame_condition_mode=first_frame_condition_mode,
442
+ latent_channels=latent_channels,
443
+ )
444
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
445
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
446
+ in_channels=block_out_channels[-1],
447
+ temb_channels=blocks_time_embed_dim,
448
+ dropout=dropout,
449
+ resnet_eps=norm_eps,
450
+ resnet_act_fn=act_fn,
451
+ output_scale_factor=mid_block_scale_factor,
452
+ cross_attention_dim=cross_attention_dim[-1],
453
+ attention_head_dim=attention_head_dim[-1],
454
+ resnet_groups=norm_num_groups,
455
+ resnet_time_scale_shift=resnet_time_scale_shift,
456
+ skip_time_act=resnet_skip_time_act,
457
+ only_cross_attention=mid_block_only_cross_attention,
458
+ cross_attention_norm=cross_attention_norm,
459
+ )
460
+ elif mid_block_type is None:
461
+ self.mid_block = None
462
+ else:
463
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
464
+
465
+ # count how many layers upsample the images
466
+ self.num_upsamplers = 0
467
+
468
+ # up
469
+ reversed_block_out_channels = list(reversed(block_out_channels))
470
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
471
+ reversed_layers_per_block = list(reversed(layers_per_block))
472
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
473
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
474
+ only_cross_attention = list(reversed(only_cross_attention))
475
+
476
+ output_channel = reversed_block_out_channels[0]
477
+ for i, up_block_type in enumerate(up_block_types):
478
+ is_final_block = i == len(block_out_channels) - 1
479
+
480
+ prev_output_channel = output_channel
481
+ output_channel = reversed_block_out_channels[i]
482
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
483
+
484
+ # add upsample block for all BUT final layer
485
+ if not is_final_block:
486
+ add_upsample = True
487
+ self.num_upsamplers += 1
488
+ else:
489
+ add_upsample = False
490
+
491
+ up_block = get_up_block(
492
+ up_block_type,
493
+ num_layers=reversed_layers_per_block[i] + 1,
494
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
495
+ in_channels=input_channel,
496
+ out_channels=output_channel,
497
+ prev_output_channel=prev_output_channel,
498
+ temb_channels=blocks_time_embed_dim,
499
+ add_upsample=add_upsample,
500
+ resnet_eps=norm_eps,
501
+ resnet_act_fn=act_fn,
502
+ resnet_groups=norm_num_groups,
503
+ cross_attention_dim=reversed_cross_attention_dim[i],
504
+ num_attention_heads=reversed_num_attention_heads[i],
505
+ dual_cross_attention=dual_cross_attention,
506
+ use_linear_projection=use_linear_projection,
507
+ only_cross_attention=only_cross_attention[i],
508
+ upcast_attention=upcast_attention,
509
+ resnet_time_scale_shift=resnet_time_scale_shift,
510
+ attention_type=attention_type,
511
+ resnet_skip_time_act=resnet_skip_time_act,
512
+ resnet_out_scale_factor=resnet_out_scale_factor,
513
+ cross_attention_norm=cross_attention_norm,
514
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
515
+ dropout=dropout,
516
+ # additional
517
+ use_temporal=use_temporal,
518
+ augment_temporal_attention=augment_temporal_attention,
519
+ n_frames=n_frames,
520
+ n_temp_heads=n_temp_heads,
521
+ first_frame_condition_mode=first_frame_condition_mode,
522
+ latent_channels=latent_channels,
523
+ rotary_emb=rotary_emb,
524
+ )
525
+ self.up_blocks.append(up_block)
526
+ prev_output_channel = output_channel
527
+
528
+ # out
529
+ if norm_num_groups is not None:
530
+ self.conv_norm_out = nn.GroupNorm(
531
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
532
+ )
533
+
534
+ self.conv_act = get_activation(act_fn)
535
+
536
+ else:
537
+ self.conv_norm_out = None
538
+ self.conv_act = None
539
+
540
+ conv_out_padding = (conv_out_kernel - 1) // 2
541
+ self.conv_out = nn.Conv2d(
542
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
543
+ )
544
+
545
+ @property
546
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
547
+ r"""
548
+ Returns:
549
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
550
+ indexed by its weight name.
551
+ """
552
+ # set recursively
553
+ processors = {}
554
+
555
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
556
+ if hasattr(module, "get_processor"):
557
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
558
+
559
+ for sub_name, child in module.named_children():
560
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
561
+
562
+ return processors
563
+
564
+ for name, module in self.named_children():
565
+ fn_recursive_add_processors(name, module, processors)
566
+
567
+ return processors
568
+
569
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
570
+ r"""
571
+ Sets the attention processor to use to compute attention.
572
+
573
+ Parameters:
574
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
575
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
576
+ for **all** `Attention` layers.
577
+
578
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
579
+ processor. This is strongly recommended when setting trainable attention processors.
580
+
581
+ """
582
+ count = len(self.attn_processors.keys())
583
+
584
+ if isinstance(processor, dict) and len(processor) != count:
585
+ raise ValueError(
586
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
587
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
588
+ )
589
+
590
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
591
+ if hasattr(module, "set_processor"):
592
+ if not isinstance(processor, dict):
593
+ module.set_processor(processor)
594
+ else:
595
+ module.set_processor(processor.pop(f"{name}.processor"))
596
+
597
+ for sub_name, child in module.named_children():
598
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
599
+
600
+ for name, module in self.named_children():
601
+ fn_recursive_attn_processor(name, module, processor)
602
+
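+ # NOTE: minimal usage sketch (hypothetical `unet` instance; mirrors the docstrings above):
+ #   procs = unet.attn_processors               # dict keyed by weight name, e.g. "...attn1.processor"
+ #   unet.set_attn_processor(AttnProcessor())   # a single processor applied to every attention layer
+ #   unet.set_attn_processor(procs)             # or a dict with exactly one entry per attention layer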
603
+ def set_default_attn_processor(self):
604
+ """
605
+ Disables custom attention processors and sets the default attention implementation.
606
+ """
607
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
608
+ processor = AttnAddedKVProcessor()
609
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
610
+ processor = AttnProcessor()
611
+ else:
612
+ raise ValueError(
613
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
614
+ )
615
+
616
+ self.set_attn_processor(processor)
617
+
618
+ def set_attention_slice(self, slice_size):
619
+ r"""
620
+ Enable sliced attention computation.
621
+
622
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
623
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
624
+
625
+ Args:
626
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
627
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
628
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
629
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
630
+ must be a multiple of `slice_size`.
631
+ """
632
+ sliceable_head_dims = []
633
+
634
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
635
+ if hasattr(module, "set_attention_slice"):
636
+ sliceable_head_dims.append(module.sliceable_head_dim)
637
+
638
+ for child in module.children():
639
+ fn_recursive_retrieve_sliceable_dims(child)
640
+
641
+ # retrieve number of attention layers
642
+ for module in self.children():
643
+ fn_recursive_retrieve_sliceable_dims(module)
644
+
645
+ num_sliceable_layers = len(sliceable_head_dims)
646
+
647
+ if slice_size == "auto":
648
+ # half the attention head size is usually a good trade-off between
649
+ # speed and memory
650
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
651
+ elif slice_size == "max":
652
+ # make smallest slice possible
653
+ slice_size = num_sliceable_layers * [1]
654
+
655
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
656
+
657
+ if len(slice_size) != len(sliceable_head_dims):
658
+ raise ValueError(
659
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
660
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
661
+ )
662
+
663
+ for i in range(len(slice_size)):
664
+ size = slice_size[i]
665
+ dim = sliceable_head_dims[i]
666
+ if size is not None and size > dim:
667
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
668
+
669
+ # Recursively walk through all the children.
670
+ # Any children which exposes the set_attention_slice method
671
+ # gets the message
672
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
673
+ if hasattr(module, "set_attention_slice"):
674
+ module.set_attention_slice(slice_size.pop())
675
+
676
+ for child in module.children():
677
+ fn_recursive_set_attention_slice(child, slice_size)
678
+
679
+ reversed_slice_size = list(reversed(slice_size))
680
+ for module in self.children():
681
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
682
+
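+ # NOTE: minimal usage sketch (hypothetical `unet` instance; semantics per the docstring above):
+ #   unet.set_attention_slice("auto")   # halve each sliceable head dim -> attention in two steps
+ #   unet.set_attention_slice("max")    # one slice at a time, lowest memory
+ #   unet.set_attention_slice(2)        # explicit slice size; must not exceed any sliceable head dim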
683
+ def _set_gradient_checkpointing(self, module, value=False):
684
+ if hasattr(module, "gradient_checkpointing"):
685
+ module.gradient_checkpointing = value
686
+
687
+ def forward(
688
+ self,
689
+ sample: torch.FloatTensor,
690
+ timestep: Union[torch.Tensor, float, int],
691
+ encoder_hidden_states: torch.Tensor,
692
+ class_labels: Optional[torch.Tensor] = None,
693
+ timestep_cond: Optional[torch.Tensor] = None,
694
+ attention_mask: Optional[torch.Tensor] = None,
695
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
696
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
697
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
698
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
699
+ encoder_attention_mask: Optional[torch.Tensor] = None,
700
+ return_dict: bool = True,
701
+ # additional
702
+ first_frame_latents: Optional[torch.Tensor] = None,
703
+ frame_stride: Optional[Union[torch.Tensor, float, int]] = None,
704
+ ) -> Union[UNet2DConditionOutput, Tuple]:
705
+ # reshape video data
706
+ assert sample.dim() == 5, f"Expected sample to have ndim=5, but got ndim={sample.dim()}."
707
+ video_length = sample.shape[2]
708
+
709
+ if first_frame_latents is not None:
710
+ assert self.config.first_frame_condition_mode != "none", "first_frame_latents is not None, but first_frame_condition_mode is 'none'."
711
+
712
+ if self.config.first_frame_condition_mode != "none":
713
+ sample = torch.cat([first_frame_latents, sample], dim=2)
714
+ video_length += 1
715
+
716
+ # copy conditioning embeddings for cross attention
717
+ if encoder_hidden_states is not None:
718
+ encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
719
+
720
+ sample = rearrange(sample, "b c f h w -> (b f) c h w")
721
+
722
+ # By default samples have to be at least a multiple of the overall upsampling factor.
723
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
724
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
725
+ # on the fly if necessary.
726
+ default_overall_up_factor = 2**self.num_upsamplers
727
+
728
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
729
+ forward_upsample_size = False
730
+ upsample_size = None
731
+
732
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
733
+ logger.info("Forward upsample size to force interpolation output size.")
734
+ forward_upsample_size = True
735
+
736
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
737
+ # expects mask of shape:
738
+ # [batch, key_tokens]
739
+ # adds singleton query_tokens dimension:
740
+ # [batch, 1, key_tokens]
741
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
742
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
743
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
744
+ if attention_mask is not None:
745
+ # assume that mask is expressed as:
746
+ # (1 = keep, 0 = discard)
747
+ # convert mask into a bias that can be added to attention scores:
748
+ # (keep = +0, discard = -10000.0)
749
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
750
+ attention_mask = attention_mask.unsqueeze(1)
751
+
752
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
753
+ if encoder_attention_mask is not None:
754
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
755
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
756
+
757
+ # 0. center input if necessary
758
+ if self.config.center_input_sample:
759
+ sample = 2 * sample - 1.0
760
+
761
+ # 1. time
762
+ timesteps = timestep
763
+ if not torch.is_tensor(timesteps):
764
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
765
+ # This would be a good case for the `match` statement (Python 3.10+)
766
+ is_mps = sample.device.type == "mps"
767
+ if isinstance(timestep, float):
768
+ dtype = torch.float32 if is_mps else torch.float64
769
+ else:
770
+ dtype = torch.int32 if is_mps else torch.int64
771
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
772
+ elif len(timesteps.shape) == 0:
773
+ timesteps = timesteps[None].to(sample.device)
774
+
775
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
776
+ timesteps = timesteps.expand(sample.shape[0])
777
+
778
+ t_emb = self.time_proj(timesteps)
779
+
780
+ # `Timesteps` does not contain any weights and will always return f32 tensors
781
+ # but time_embedding might actually be running in fp16. so we need to cast here.
782
+ # there might be better ways to encapsulate this.
783
+ t_emb = t_emb.to(dtype=sample.dtype)
784
+
785
+ emb = self.time_embedding(t_emb, timestep_cond)
786
+
787
+ if self.use_frame_stride_condition:
788
+ if not torch.is_tensor(frame_stride):
789
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
790
+ # This would be a good case for the `match` statement (Python 3.10+)
791
+ is_mps = sample.device.type == "mps"
792
+ if isinstance(timestep, float):
793
+ dtype = torch.float32 if is_mps else torch.float64
794
+ else:
795
+ dtype = torch.int32 if is_mps else torch.int64
796
+ frame_stride = torch.tensor([frame_stride], dtype=dtype, device=sample.device)
797
+ elif len(frame_stride.shape) == 0:
798
+ frame_stride = frame_stride[None].to(sample.device)
799
+
800
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
801
+ frame_stride = frame_stride.expand(sample.shape[0])
802
+
803
+ fs_emb = self.time_proj(frame_stride)
804
+
805
+ # `Timesteps` does not contain any weights and will always return f32 tensors
806
+ # but time_embedding might actually be running in fp16. so we need to cast here.
807
+ # there might be better ways to encapsulate this.
808
+ fs_emb = fs_emb.to(dtype=sample.dtype)
809
+
810
+ fs_emb = self.frame_stride_embedding(fs_emb, timestep_cond)
811
+ emb = emb + fs_emb
812
+
813
+ aug_emb = None
814
+
815
+ if self.class_embedding is not None:
816
+ if class_labels is None:
817
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
818
+
819
+ if self.config.class_embed_type == "timestep":
820
+ class_labels = self.time_proj(class_labels)
821
+
822
+ # `Timesteps` does not contain any weights and will always return f32 tensors
823
+ # there might be better ways to encapsulate this.
824
+ class_labels = class_labels.to(dtype=sample.dtype)
825
+
826
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
827
+
828
+ if self.config.class_embeddings_concat:
829
+ emb = torch.cat([emb, class_emb], dim=-1)
830
+ else:
831
+ emb = emb + class_emb
832
+
833
+ if self.config.addition_embed_type == "text":
834
+ aug_emb = self.add_embedding(encoder_hidden_states)
835
+ elif self.config.addition_embed_type == "text_image":
836
+ # Kandinsky 2.1 - style
837
+ if "image_embeds" not in added_cond_kwargs:
838
+ raise ValueError(
839
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
840
+ )
841
+
842
+ image_embs = added_cond_kwargs.get("image_embeds")
843
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
844
+ aug_emb = self.add_embedding(text_embs, image_embs)
845
+ elif self.config.addition_embed_type == "text_time":
846
+ # SDXL - style
847
+ if "text_embeds" not in added_cond_kwargs:
848
+ raise ValueError(
849
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
850
+ )
851
+ text_embeds = added_cond_kwargs.get("text_embeds")
852
+ if "time_ids" not in added_cond_kwargs:
853
+ raise ValueError(
854
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
855
+ )
856
+ time_ids = added_cond_kwargs.get("time_ids")
857
+ time_embeds = self.add_time_proj(time_ids.flatten())
858
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
859
+
860
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
861
+ add_embeds = add_embeds.to(emb.dtype)
862
+ aug_emb = self.add_embedding(add_embeds)
863
+ elif self.config.addition_embed_type == "image":
864
+ # Kandinsky 2.2 - style
865
+ if "image_embeds" not in added_cond_kwargs:
866
+ raise ValueError(
867
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
868
+ )
869
+ image_embs = added_cond_kwargs.get("image_embeds")
870
+ aug_emb = self.add_embedding(image_embs)
871
+ elif self.config.addition_embed_type == "image_hint":
872
+ # Kandinsky 2.2 - style
873
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
874
+ raise ValueError(
875
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
876
+ )
877
+ image_embs = added_cond_kwargs.get("image_embeds")
878
+ hint = added_cond_kwargs.get("hint")
879
+ aug_emb, hint = self.add_embedding(image_embs, hint)
880
+ sample = torch.cat([sample, hint], dim=1)
881
+
882
+ emb = emb + aug_emb if aug_emb is not None else emb
883
+
884
+ if self.time_embed_act is not None:
885
+ emb = self.time_embed_act(emb)
886
+
887
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
888
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
889
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
890
+ # Kandinsky 2.1 - style
891
+ if "image_embeds" not in added_cond_kwargs:
892
+ raise ValueError(
893
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
894
+ )
895
+
896
+ image_embeds = added_cond_kwargs.get("image_embeds")
897
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
898
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
899
+ # Kandinsky 2.2 - style
900
+ if "image_embeds" not in added_cond_kwargs:
901
+ raise ValueError(
902
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
903
+ )
904
+ image_embeds = added_cond_kwargs.get("image_embeds")
905
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
906
+ # 2. pre-process
907
+ sample = self.conv_in(sample)
908
+
909
+ # 2.5 GLIGEN position net
910
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
911
+ cross_attention_kwargs = cross_attention_kwargs.copy()
912
+ gligen_args = cross_attention_kwargs.pop("gligen")
913
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
914
+
915
+ # 3. down
916
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
917
+
918
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
919
+ is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
920
+
921
+ down_block_res_samples = (sample,)
922
+ for downsample_block in self.down_blocks:
923
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
924
+ # For t2i-adapter CrossAttnDownBlock2D
925
+ additional_residuals = {}
926
+ if is_adapter and len(down_block_additional_residuals) > 0:
927
+ additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
928
+
929
+ sample, res_samples = downsample_block(
930
+ hidden_states=sample,
931
+ temb=emb,
932
+ encoder_hidden_states=encoder_hidden_states,
933
+ attention_mask=attention_mask,
934
+ cross_attention_kwargs=cross_attention_kwargs,
935
+ encoder_attention_mask=encoder_attention_mask,
936
+ first_frame_latents=first_frame_latents,
937
+ **additional_residuals,
938
+ )
939
+ else:
940
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale, first_frame_latents=first_frame_latents,)
941
+
942
+ if is_adapter and len(down_block_additional_residuals) > 0:
943
+ sample += down_block_additional_residuals.pop(0)
944
+
945
+ down_block_res_samples += res_samples
946
+
947
+ if is_controlnet:
948
+ new_down_block_res_samples = ()
949
+
950
+ for down_block_res_sample, down_block_additional_residual in zip(
951
+ down_block_res_samples, down_block_additional_residuals
952
+ ):
953
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
954
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
955
+
956
+ down_block_res_samples = new_down_block_res_samples
957
+
958
+ # 4. mid
959
+ if self.mid_block is not None:
960
+ sample = self.mid_block(
961
+ sample,
962
+ emb,
963
+ encoder_hidden_states=encoder_hidden_states,
964
+ attention_mask=attention_mask,
965
+ cross_attention_kwargs=cross_attention_kwargs,
966
+ encoder_attention_mask=encoder_attention_mask,
967
+ # additional
968
+ first_frame_latents=first_frame_latents,
969
+ )
970
+ # To support T2I-Adapter-XL
971
+ if (
972
+ is_adapter
973
+ and len(down_block_additional_residuals) > 0
974
+ and sample.shape == down_block_additional_residuals[0].shape
975
+ ):
976
+ sample += down_block_additional_residuals.pop(0)
977
+
978
+ if is_controlnet:
979
+ sample = sample + mid_block_additional_residual
980
+
981
+ # 5. up
982
+ for i, upsample_block in enumerate(self.up_blocks):
983
+ is_final_block = i == len(self.up_blocks) - 1
984
+
985
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
986
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
987
+
988
+ # if we have not reached the final block and need to forward the
989
+ # upsample size, we do it here
990
+ if not is_final_block and forward_upsample_size:
991
+ upsample_size = down_block_res_samples[-1].shape[2:]
992
+
993
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
994
+ sample = upsample_block(
995
+ hidden_states=sample,
996
+ temb=emb,
997
+ res_hidden_states_tuple=res_samples,
998
+ encoder_hidden_states=encoder_hidden_states,
999
+ cross_attention_kwargs=cross_attention_kwargs,
1000
+ upsample_size=upsample_size,
1001
+ attention_mask=attention_mask,
1002
+ encoder_attention_mask=encoder_attention_mask,
1003
+ first_frame_latents=first_frame_latents,
1004
+ )
1005
+ else:
1006
+ sample = upsample_block(
1007
+ hidden_states=sample,
1008
+ temb=emb,
1009
+ res_hidden_states_tuple=res_samples,
1010
+ upsample_size=upsample_size,
1011
+ scale=lora_scale,
1012
+ first_frame_latents=first_frame_latents,
1013
+ )
1014
+
1015
+ # 6. post-process
1016
+ if self.conv_norm_out:
1017
+ sample = self.conv_norm_out(sample)
1018
+ sample = self.conv_act(sample)
1019
+ sample = self.conv_out(sample)
1020
+
1021
+ sample = rearrange(sample, "(b f) c h w -> b c f h w", f=video_length)
1022
+ if self.config.first_frame_condition_mode != "none":
1023
+ sample = sample[:, :, 1:, :, :]
1024
+
1025
+ if not return_dict:
1026
+ return (sample,)
1027
+
1028
+ return UNet2DConditionOutput(sample=sample)
1029
+
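+ # NOTE: minimal call sketch for forward() (hypothetical shapes; the asserts above require a 5D sample):
+ #   sample                : (b, 4, f, h, w) video latents
+ #   encoder_hidden_states : (b, 77, 768) text embeddings
+ #   first_frame_latents   : (b, 4, 1, h, w), only when first_frame_condition_mode != "none"
+ #   out = unet(sample, timestep, encoder_hidden_states, first_frame_latents=first_frame_latents).sample
+ #   # out is again (b, 4, f, h, w); the concatenated first frame is stripped before returning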
1030
+ @classmethod
1031
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
1032
+
1033
+ kwargs.pop("low_cpu_mem_usage", False)
1034
+ kwargs.pop("device_map", None)
1035
+
1036
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
1037
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
1038
+ force_download = kwargs.pop("force_download", False)
1039
+ from_flax = kwargs.pop("from_flax", False)
1040
+ resume_download = kwargs.pop("resume_download", False)
1041
+ proxies = kwargs.pop("proxies", None)
1042
+ output_loading_info = kwargs.pop("output_loading_info", False)
1043
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
1044
+ use_auth_token = kwargs.pop("use_auth_token", None)
1045
+ revision = kwargs.pop("revision", None)
1046
+ torch_dtype = kwargs.pop("torch_dtype", None)
1047
+ subfolder = kwargs.pop("subfolder", None)
1048
+ device_map = None
1049
+ max_memory = kwargs.pop("max_memory", None)
1050
+ offload_folder = kwargs.pop("offload_folder", None)
1051
+ offload_state_dict = kwargs.pop("offload_state_dict", False)
1052
+ low_cpu_mem_usage = False
1053
+ variant = kwargs.pop("variant", None)
1054
+ use_safetensors = kwargs.pop("use_safetensors", None)
1055
+
1056
+ allow_pickle = False
1057
+ if use_safetensors is None:
1058
+ use_safetensors = True
1059
+ allow_pickle = True
1060
+
1061
+ if low_cpu_mem_usage and not is_accelerate_available():
1062
+ low_cpu_mem_usage = False
1063
+ logger.warning(
1064
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
1065
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
1066
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
1067
+ " install accelerate\n```\n."
1068
+ )
1069
+
1070
+ if device_map is not None and not is_accelerate_available():
1071
+ raise NotImplementedError(
1072
+ "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
1073
+ " `device_map=None`. You can install accelerate with `pip install accelerate`."
1074
+ )
1075
+
1076
+ # Check if we can handle device_map and dispatching the weights
1077
+ if device_map is not None and not is_torch_version(">=", "1.9.0"):
1078
+ raise NotImplementedError(
1079
+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
1080
+ " `device_map=None`."
1081
+ )
1082
+
1083
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
1084
+ raise NotImplementedError(
1085
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
1086
+ " `low_cpu_mem_usage=False`."
1087
+ )
1088
+
1089
+ if low_cpu_mem_usage is False and device_map is not None:
1090
+ raise ValueError(
1091
+ f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
1092
+ " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
1093
+ )
1094
+
1095
+ # Load config if we don't provide a configuration
1096
+ config_path = pretrained_model_name_or_path
1097
+
1098
+ user_agent = {
1099
+ "diffusers": __version__,
1100
+ "file_type": "model",
1101
+ "framework": "pytorch",
1102
+ }
1103
+
1104
+ # load config
1105
+ config, unused_kwargs, commit_hash = cls.load_config(
1106
+ config_path,
1107
+ cache_dir=cache_dir,
1108
+ return_unused_kwargs=True,
1109
+ return_commit_hash=True,
1110
+ force_download=force_download,
1111
+ resume_download=resume_download,
1112
+ proxies=proxies,
1113
+ local_files_only=local_files_only,
1114
+ use_auth_token=use_auth_token,
1115
+ revision=revision,
1116
+ subfolder=subfolder,
1117
+ device_map=device_map,
1118
+ max_memory=max_memory,
1119
+ offload_folder=offload_folder,
1120
+ offload_state_dict=offload_state_dict,
1121
+ user_agent=user_agent,
1122
+ **kwargs,
1123
+ )
1124
+
1125
+ # load model
1126
+ model_file = None
1127
+ if from_flax:
1128
+ model_file = _get_model_file(
1129
+ pretrained_model_name_or_path,
1130
+ weights_name=FLAX_WEIGHTS_NAME,
1131
+ cache_dir=cache_dir,
1132
+ force_download=force_download,
1133
+ resume_download=resume_download,
1134
+ proxies=proxies,
1135
+ local_files_only=local_files_only,
1136
+ use_auth_token=use_auth_token,
1137
+ revision=revision,
1138
+ subfolder=subfolder,
1139
+ user_agent=user_agent,
1140
+ commit_hash=commit_hash,
1141
+ )
1142
+ model = cls.from_config(config, **unused_kwargs)
1143
+
1144
+ # Convert the weights
1145
+ from diffusers.models.modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
1146
+
1147
+ model = load_flax_checkpoint_in_pytorch_model(model, model_file)
1148
+ else:
1149
+ if use_safetensors:
1150
+ try:
1151
+ model_file = _get_model_file(
1152
+ pretrained_model_name_or_path,
1153
+ weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
1154
+ cache_dir=cache_dir,
1155
+ force_download=force_download,
1156
+ resume_download=resume_download,
1157
+ proxies=proxies,
1158
+ local_files_only=local_files_only,
1159
+ use_auth_token=use_auth_token,
1160
+ revision=revision,
1161
+ subfolder=subfolder,
1162
+ user_agent=user_agent,
1163
+ commit_hash=commit_hash,
1164
+ )
1165
+ except IOError as e:
1166
+ if not allow_pickle:
1167
+ raise e
1168
+ pass
1169
+ if model_file is None:
1170
+ model_file = _get_model_file(
1171
+ pretrained_model_name_or_path,
1172
+ weights_name=_add_variant(WEIGHTS_NAME, variant),
1173
+ cache_dir=cache_dir,
1174
+ force_download=force_download,
1175
+ resume_download=resume_download,
1176
+ proxies=proxies,
1177
+ local_files_only=local_files_only,
1178
+ use_auth_token=use_auth_token,
1179
+ revision=revision,
1180
+ subfolder=subfolder,
1181
+ user_agent=user_agent,
1182
+ commit_hash=commit_hash,
1183
+ )
1184
+
1185
+ if low_cpu_mem_usage:
1186
+ # Instantiate model with empty weights
1187
+ with accelerate.init_empty_weights():
1188
+ model = cls.from_config(config, **unused_kwargs)
1189
+
1190
+ # if device_map is None, load the state dict and move the params from meta device to the cpu
1191
+ if device_map is None:
1192
+ param_device = "cpu"
1193
+ state_dict = load_state_dict(model_file, variant=variant)
1194
+ model._convert_deprecated_attention_blocks(state_dict)
1195
+ # move the params from meta device to cpu
1196
+ missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
1197
+ if len(missing_keys) > 0:
1198
+ raise ValueError(
1199
+ f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
1200
+ f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
1201
+ " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
1202
+ " those weights or else make sure your checkpoint file is correct."
1203
+ )
1204
+
1205
+ unexpected_keys = load_model_dict_into_meta(
1206
+ model,
1207
+ state_dict,
1208
+ device=param_device,
1209
+ dtype=torch_dtype,
1210
+ model_name_or_path=pretrained_model_name_or_path,
1211
+ )
1212
+
1213
+ if cls._keys_to_ignore_on_load_unexpected is not None:
1214
+ for pat in cls._keys_to_ignore_on_load_unexpected:
1215
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
1216
+
1217
+ if len(unexpected_keys) > 0:
1218
+ logger.warn(
1219
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
1220
+ )
1221
+
1222
+ else: # else let accelerate handle loading and dispatching.
1223
+ # Load weights and dispatch according to the device_map
1224
+ # by default the device_map is None and the weights are loaded on the CPU
1225
+ try:
1226
+ accelerate.load_checkpoint_and_dispatch(
1227
+ model,
1228
+ model_file,
1229
+ device_map,
1230
+ max_memory=max_memory,
1231
+ offload_folder=offload_folder,
1232
+ offload_state_dict=offload_state_dict,
1233
+ dtype=torch_dtype,
1234
+ )
1235
+ except AttributeError as e:
1236
+ # When using accelerate loading, we do not have the ability to load the state
1237
+ # dict and rename the weight names manually. Additionally, accelerate skips
1238
+ # torch loading conventions and directly writes into `module.{_buffers, _parameters}`
1239
+ # (which look like they should be private variables?), so we can't use the standard hooks
1240
+ # to rename parameters on load. We need to mimic the original weight names so the correct
1241
+ # attributes are available. After we have loaded the weights, we convert the deprecated
1242
+ # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
1243
+ # the weights so we don't have to do this again.
1244
+
1245
+ if "'Attention' object has no attribute" in str(e):
1246
+ logger.warn(
1247
+ f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
1248
+ " was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
1249
+ " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
1250
+ " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
1251
+ " please also re-upload it or open a PR on the original repository."
1252
+ )
1253
+ model._temp_convert_self_to_deprecated_attention_blocks()
1254
+ accelerate.load_checkpoint_and_dispatch(
1255
+ model,
1256
+ model_file,
1257
+ device_map,
1258
+ max_memory=max_memory,
1259
+ offload_folder=offload_folder,
1260
+ offload_state_dict=offload_state_dict,
1261
+ dtype=torch_dtype,
1262
+ )
1263
+ model._undo_temp_convert_self_to_deprecated_attention_blocks()
1264
+ else:
1265
+ raise e
1266
+
1267
+ loading_info = {
1268
+ "missing_keys": [],
1269
+ "unexpected_keys": [],
1270
+ "mismatched_keys": [],
1271
+ "error_msgs": [],
1272
+ }
1273
+ else:
1274
+ model = cls.from_config(config, **unused_kwargs)
1275
+
1276
+ state_dict = load_state_dict(model_file, variant=variant)
1277
+ model._convert_deprecated_attention_blocks(state_dict)
1278
+
1279
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
1280
+ model,
1281
+ state_dict,
1282
+ model_file,
1283
+ pretrained_model_name_or_path,
1284
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
1285
+ )
1286
+
1287
+ loading_info = {
1288
+ "missing_keys": missing_keys,
1289
+ "unexpected_keys": unexpected_keys,
1290
+ "mismatched_keys": mismatched_keys,
1291
+ "error_msgs": error_msgs,
1292
+ }
1293
+
1294
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
1295
+ raise ValueError(
1296
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
1297
+ )
1298
+ elif torch_dtype is not None:
1299
+ model = model.to(torch_dtype)
1300
+
1301
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
1302
+
1303
+ m, u = loading_info["missing_keys"], loading_info["unexpected_keys"]
1304
+ logger.info(f"### missing keys: {len(m)}; unexpected keys: {len(u)};")
1305
+ # print(f"### missing keys:\n{m}\n### unexpected keys:\n{u}\n")
1306
+
1307
+ spatial_params = [p.numel() if "conv3ds" not in n and "tempo_attns" not in n else 0 for n, p in model.named_parameters()]
1308
+ tconv_params = [p.numel() if "conv3ds." in n else 0 for n, p in model.named_parameters()]
1309
+ tattn_params = [p.numel() if "tempo_attns." in n else 0 for n, p in model.named_parameters()]
1310
+ tffconv_params = [p.numel() if "first_frame_conv." in n else 0 for n, p in model.named_parameters()]
1311
+ logger.info(f"### First Frame Convolution Layer Parameters: {sum(tffconv_params) / 1e6} M")
1312
+ logger.info(f"### Spatial UNet Parameters: {sum(spatial_params) / 1e6} M")
1313
+ logger.info(f"### Temporal Convolution Module Parameters: {sum(tconv_params) / 1e6} M")
1314
+ logger.info(f"### Temporal Attention Module Parameters: {sum(tattn_params) / 1e6} M")
1315
+
1316
+ # Set model in evaluation mode to deactivate DropOut modules by default
1317
+ model.eval()
1318
+ if output_loading_info:
1319
+ return model, loading_info
1320
+
1321
+ return model
1322
+
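+ # NOTE (inferred from the loader above, see also the test below): when initialised from a 2D
+ # Stable Diffusion checkpoint, only the spatial weights are present in the state dict; the
+ # temporal conv3ds/tempo_attns parameters show up as missing keys and keep their fresh init,
+ # which is why the spatial and temporal parameter counts are logged separately.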
1323
+ if __name__ == "__main__":
1324
+ # test
1325
+ from diffusers import AutoencoderKL, DDIMScheduler
1326
+ from transformers import CLIPTextModel, CLIPTokenizer
1327
+ from consisti2v.pipelines.pipeline_animation import AnimationPipeline
1328
+ from consisti2v.pipelines.pipeline_conditional_animation import ConditionalAnimationPipeline
1329
+ from consisti2v.utils.util import save_videos_grid
1330
+
1331
+ pretrained_model_path = "models/StableDiffusion/stable-diffusion-v1-5"
1332
+ prompt = "apply eye makeup"
1333
+ first_frame_path = "/ML-A100/home/weiming/datasets/UCF/frames/v_ApplyEyeMakeup_g01_c01_frame_90.jpg"
1334
+
1335
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer", use_safetensors=True)
1336
+ text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
1337
+ vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae", use_safetensors=True)
1338
+ unet = VideoLDMUNet3DConditionModel.from_pretrained(
1339
+ pretrained_model_path,
1340
+ subfolder="unet",
1341
+ use_safetensors=True
1342
+ )
1343
+
1344
+ noise_scheduler_kwargs = {
1345
+ "num_train_timesteps": 1000,
1346
+ "beta_start": 0.00085,
1347
+ "beta_end": 0.012,
1348
+ "beta_schedule": "linear",
1349
+ "steps_offset": 1,
1350
+ "clip_sample": False,
1351
+ }
1352
+ noise_scheduler = DDIMScheduler(**noise_scheduler_kwargs)
1353
+ # latent = torch.randn(1, 4, 8, 64, 64).to("cuda")
1354
+ # text_embedding = torch.randn(1, 77, 768).to("cuda")
1355
+ # timestep = torch.randint(0, 1000, (1,)).to("cuda").squeeze(0)
1356
+ # output = unet(latent, timestep, text_embedding)
1357
+
1358
+ pipeline = ConditionalAnimationPipeline(
1359
+ unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler,
1360
+ ).to("cuda")
1361
+ sample = pipeline(
1362
+ prompt,
1363
+ num_inference_steps = 25,
1364
+ guidance_scale = 8.,
1365
+ video_length = 8,
1366
+ height = 256,
1367
+ width = 256,
1368
+ first_frame_paths = first_frame_path,
1369
+ ).videos
1370
+ print(sample.shape)
1371
+ save_videos_grid(sample, f"samples/videoldm.gif")
src/videogen_hub/pipelines/consisti2v/consisti2v/models/videoldm_unet_blocks.py ADDED
@@ -0,0 +1,1159 @@
1
+ from typing import Optional, Dict, Tuple, Any
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from einops import rearrange, repeat
7
+ from einops.layers.torch import Rearrange
8
+ from diffusers.utils import logging
9
+ from diffusers.models.unet_2d_blocks import (
10
+ DownBlock2D,
11
+ UpBlock2D
12
+ )
13
+ from diffusers.models.resnet import (
14
+ ResnetBlock2D,
15
+ Downsample2D,
16
+ Upsample2D,
17
+ )
18
+ from diffusers.models.transformer_2d import Transformer2DModelOutput
19
+ from diffusers.models.dual_transformer_2d import DualTransformer2DModel
20
+ from diffusers.models.activations import get_activation
21
+ from diffusers.utils import logging, is_torch_version
22
+ from diffusers.utils.import_utils import is_xformers_available
23
+ from .videoldm_transformer_blocks import Transformer2DConditionModel
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+ if is_xformers_available():
28
+ import xformers
29
+ import xformers.ops
30
+ else:
31
+ xformers = None
32
+
33
+
34
+ def get_down_block(
35
+ down_block_type,
36
+ num_layers,
37
+ in_channels,
38
+ out_channels,
39
+ temb_channels,
40
+ add_downsample,
41
+ resnet_eps,
42
+ resnet_act_fn,
43
+ transformer_layers_per_block=1,
44
+ num_attention_heads=None,
45
+ resnet_groups=None,
46
+ cross_attention_dim=None,
47
+ downsample_padding=None,
48
+ dual_cross_attention=False,
49
+ use_linear_projection=False,
50
+ only_cross_attention=False,
51
+ upcast_attention=False,
52
+ resnet_time_scale_shift="default",
53
+ attention_type="default",
54
+ resnet_skip_time_act=False,
55
+ resnet_out_scale_factor=1.0,
56
+ cross_attention_norm=None,
57
+ attention_head_dim=None,
58
+ downsample_type=None,
59
+ dropout=0.0,
60
+ # additional
61
+ use_temporal=True,
62
+ augment_temporal_attention=False,
63
+ n_frames=8,
64
+ n_temp_heads=8,
65
+ first_frame_condition_mode="none",
66
+ latent_channels=4,
67
+ rotary_emb=False,
68
+ ):
69
+ # If attn head dim is not defined, we default it to the number of heads
70
+ if attention_head_dim is None:
71
+ logger.warn(
72
+ f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
73
+ )
74
+ attention_head_dim = num_attention_heads
75
+
76
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
77
+ if down_block_type == "DownBlock2D":
78
+ return VideoLDMDownBlock(
79
+ num_layers=num_layers,
80
+ in_channels=in_channels,
81
+ out_channels=out_channels,
82
+ temb_channels=temb_channels,
83
+ dropout=dropout,
84
+ add_downsample=add_downsample,
85
+ resnet_eps=resnet_eps,
86
+ resnet_act_fn=resnet_act_fn,
87
+ resnet_groups=resnet_groups,
88
+ downsample_padding=downsample_padding,
89
+ resnet_time_scale_shift=resnet_time_scale_shift,
90
+ # additional
91
+ use_temporal=use_temporal,
92
+ n_frames=n_frames,
93
+ first_frame_condition_mode=first_frame_condition_mode,
94
+ latent_channels=latent_channels
95
+ )
96
+ elif down_block_type == "CrossAttnDownBlock2D":
97
+ return VideoLDMCrossAttnDownBlock(
98
+ num_layers=num_layers,
99
+ transformer_layers_per_block=transformer_layers_per_block,
100
+ in_channels=in_channels,
101
+ out_channels=out_channels,
102
+ temb_channels=temb_channels,
103
+ dropout=dropout,
104
+ add_downsample=add_downsample,
105
+ resnet_eps=resnet_eps,
106
+ resnet_act_fn=resnet_act_fn,
107
+ resnet_groups=resnet_groups,
108
+ downsample_padding=downsample_padding,
109
+ cross_attention_dim=cross_attention_dim,
110
+ num_attention_heads=num_attention_heads,
111
+ dual_cross_attention=dual_cross_attention,
112
+ use_linear_projection=use_linear_projection,
113
+ only_cross_attention=only_cross_attention,
114
+ upcast_attention=upcast_attention,
115
+ resnet_time_scale_shift=resnet_time_scale_shift,
116
+ attention_type=attention_type,
117
+ # additional
118
+ use_temporal=use_temporal,
119
+ augment_temporal_attention=augment_temporal_attention,
120
+ n_frames=n_frames,
121
+ n_temp_heads=n_temp_heads,
122
+ first_frame_condition_mode=first_frame_condition_mode,
123
+ latent_channels=latent_channels,
124
+ rotary_emb=rotary_emb,
125
+ )
126
+
127
+ raise ValueError(f'{down_block_type} does not exist.')
128
+
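+ # NOTE: minimal usage sketch (hypothetical SD-1.5-like values; mirrors how the UNet builds its blocks):
+ #   block = get_down_block(
+ #       "CrossAttnDownBlock2D", num_layers=2, in_channels=320, out_channels=320,
+ #       temb_channels=1280, add_downsample=True, resnet_eps=1e-5, resnet_act_fn="silu",
+ #       num_attention_heads=8, cross_attention_dim=768, attention_head_dim=40,
+ #       n_frames=8, first_frame_condition_mode="none",
+ #   )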
129
+
130
+ def get_up_block(
131
+ up_block_type,
132
+ num_layers,
133
+ in_channels,
134
+ out_channels,
135
+ prev_output_channel,
136
+ temb_channels,
137
+ add_upsample,
138
+ resnet_eps,
139
+ resnet_act_fn,
140
+ transformer_layers_per_block=1,
141
+ num_attention_heads=None,
142
+ resnet_groups=None,
143
+ cross_attention_dim=None,
144
+ dual_cross_attention=False,
145
+ use_linear_projection=False,
146
+ only_cross_attention=False,
147
+ upcast_attention=False,
148
+ resnet_time_scale_shift="default",
149
+ attention_type="default",
150
+ resnet_skip_time_act=False,
151
+ resnet_out_scale_factor=1.0,
152
+ cross_attention_norm=None,
153
+ attention_head_dim=None,
154
+ upsample_type=None,
155
+ dropout=0.0,
156
+ # additional
157
+ use_temporal=True,
158
+ augment_temporal_attention=False,
159
+ n_frames=8,
160
+ n_temp_heads=8,
161
+ first_frame_condition_mode="none",
162
+ latent_channels=4,
163
+ rotary_emb=None,
164
+ ):
165
+ if attention_head_dim is None:
166
+ logger.warn(
167
+ f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
168
+ )
169
+ attention_head_dim = num_attention_heads
170
+
171
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
172
+ if up_block_type == "UpBlock2D":
173
+ return VideoLDMUpBlock(
174
+ num_layers=num_layers,
175
+ in_channels=in_channels,
176
+ out_channels=out_channels,
177
+ prev_output_channel=prev_output_channel,
178
+ temb_channels=temb_channels,
179
+ dropout=dropout,
180
+ add_upsample=add_upsample,
181
+ resnet_eps=resnet_eps,
182
+ resnet_act_fn=resnet_act_fn,
183
+ resnet_groups=resnet_groups,
184
+ resnet_time_scale_shift=resnet_time_scale_shift,
185
+ # additional
186
+ use_temporal=use_temporal,
187
+ n_frames=n_frames,
188
+ first_frame_condition_mode=first_frame_condition_mode,
189
+ latent_channels=latent_channels
190
+ )
191
+ elif up_block_type == 'CrossAttnUpBlock2D':
192
+ return VideoLDMCrossAttnUpBlock(
193
+ num_layers=num_layers,
194
+ transformer_layers_per_block=transformer_layers_per_block,
195
+ in_channels=in_channels,
196
+ out_channels=out_channels,
197
+ prev_output_channel=prev_output_channel,
198
+ temb_channels=temb_channels,
199
+ dropout=dropout,
200
+ add_upsample=add_upsample,
201
+ resnet_eps=resnet_eps,
202
+ resnet_act_fn=resnet_act_fn,
203
+ resnet_groups=resnet_groups,
204
+ cross_attention_dim=cross_attention_dim,
205
+ num_attention_heads=num_attention_heads,
206
+ dual_cross_attention=dual_cross_attention,
207
+ use_linear_projection=use_linear_projection,
208
+ only_cross_attention=only_cross_attention,
209
+ upcast_attention=upcast_attention,
210
+ resnet_time_scale_shift=resnet_time_scale_shift,
211
+ attention_type=attention_type,
212
+ # additional
213
+ use_temporal=use_temporal,
214
+ augment_temporal_attention=augment_temporal_attention,
215
+ n_frames=n_frames,
216
+ n_temp_heads=n_temp_heads,
217
+ first_frame_condition_mode=first_frame_condition_mode,
218
+ latent_channels=latent_channels,
219
+ rotary_emb=rotary_emb,
220
+ )
221
+
222
+ raise ValueError(f'{up_block_type} does not exist.')
223
+
224
+
225
+ class TemporalResnetBlock(nn.Module):
226
+ def __init__(
227
+ self,
228
+ *,
229
+ in_channels,
230
+ out_channels=None,
231
+ dropout=0.0,
232
+ temb_channels=512,
233
+ groups=32,
234
+ groups_out=None,
235
+ pre_norm=True,
236
+ eps=1e-6,
237
+ non_linearity="swish",
238
+ time_embedding_norm="default",
239
+ output_scale_factor=1.0,
240
+ # additional
241
+ n_frames=8,
242
+ ):
243
+ super().__init__()
244
+ self.pre_norm = pre_norm
245
+ self.pre_norm = True
246
+ self.in_channels = in_channels
247
+ out_channels = in_channels if out_channels is None else out_channels
248
+ self.out_channels = out_channels
249
+ self.time_embedding_norm = time_embedding_norm
250
+ self.output_scale_factor = output_scale_factor
251
+
252
+ if groups_out is None:
253
+ groups_out = groups
254
+
255
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
256
+
257
+ self.conv1 = Conv3DLayer(in_channels, out_channels, n_frames=n_frames)
258
+
259
+ if temb_channels is not None:
260
+ if self.time_embedding_norm == "default":
261
+ time_emb_proj_out_channels = out_channels
262
+ elif self.time_embedding_norm == "scale_shift":
263
+ time_emb_proj_out_channels = out_channels * 2
264
+ else:
265
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
266
+
267
+ self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
268
+ else:
269
+ self.time_emb_proj = None
270
+
271
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
272
+
273
+ self.dropout = torch.nn.Dropout(dropout)
274
+ self.conv2 = Conv3DLayer(out_channels, out_channels, n_frames=n_frames)
275
+
276
+ self.nonlinearity = get_activation(non_linearity)
277
+
278
+ self.alpha = nn.Parameter(torch.ones(1))
279
+
280
+ def forward(self, input_tensor, temb=None):
281
+ hidden_states = input_tensor
282
+
283
+ hidden_states = self.norm1(hidden_states)
284
+ hidden_states = self.nonlinearity(hidden_states)
285
+
286
+ hidden_states = self.conv1(hidden_states)
287
+
288
+ if temb is not None:
289
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
290
+
291
+ if temb is not None and self.time_embedding_norm == "default":
292
+ hidden_states = hidden_states + temb
293
+
294
+ hidden_states = self.norm2(hidden_states)
295
+
296
+ if temb is not None and self.time_embedding_norm == "scale_shift":
297
+ scale, shift = torch.chunk(temb, 2, dim=1)
298
+ hidden_states = hidden_states * (1 + scale) + shift
299
+
300
+ hidden_states = self.nonlinearity(hidden_states)
301
+
302
+ hidden_states = self.dropout(hidden_states)
303
+ hidden_states = self.conv2(hidden_states)
304
+
305
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
306
+
307
+ # weighted sum between spatial and temporal features
308
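+ # alpha is clamped to [0, 1]; since it is initialised to 1, the block starts as an identity
+ # mapping over the spatial input and the temporal path is blended in only as alpha is learned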
+ with torch.no_grad():
309
+ self.alpha.clamp_(0, 1)
310
+
311
+ output_tensor = self.alpha * input_tensor + (1 - self.alpha) * output_tensor
312
+
313
+ return output_tensor
314
+
315
+
316
+ class Conv3DLayer(nn.Conv3d):
317
+ def __init__(self, in_dim, out_dim, n_frames):
318
+ k, p = (3, 1, 1), (1, 0, 0)
319
+ super().__init__(in_channels=in_dim, out_channels=out_dim, kernel_size=k, stride=1, padding=p)
320
+
321
+ self.to_3d = Rearrange('(b t) c h w -> b c t h w', t=n_frames)
322
+ self.to_2d = Rearrange('b c t h w -> (b t) c h w')
323
+
324
+ def forward(self, x):
325
+ h = self.to_3d(x)
326
+ h = super().forward(h)
327
+ out = self.to_2d(h)
328
+ return out
329
+
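+ # NOTE: minimal shape sketch (hypothetical sizes): with n_frames=8, an input of shape (b*8, c, h, w)
+ # is reshaped to (b, c, 8, h, w), passed through the (3, 1, 1) kernel with (1, 0, 0) padding
+ # (temporal-only mixing, spatial dims untouched), and flattened back to (b*8, c, h, w).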
330
+
331
+ class IdentityLayer(nn.Identity):
332
+ def __init__(self, return_trans2d_output, *args, **kwargs):
333
+ super().__init__()
334
+ self.return_trans2d_output = return_trans2d_output
335
+
336
+ def forward(self, x, *args, **kwargs):
337
+ if self.return_trans2d_output:
338
+ return Transformer2DModelOutput(sample=x)
339
+ else:
340
+ return x
341
+
342
+
343
+ class VideoLDMCrossAttnDownBlock(nn.Module):
344
+ def __init__(
345
+ self,
346
+ in_channels: int,
347
+ out_channels: int,
348
+ temb_channels: int,
349
+ dropout: float = 0.0,
350
+ num_layers: int = 1,
351
+ transformer_layers_per_block: int = 1,
352
+ resnet_eps: float = 1e-6,
353
+ resnet_time_scale_shift: str = "default",
354
+ resnet_act_fn: str = "swish",
355
+ resnet_groups: int = 32,
356
+ resnet_pre_norm: bool = True,
357
+ num_attention_heads=1,
358
+ cross_attention_dim=1280,
359
+ output_scale_factor=1.0,
360
+ downsample_padding=1,
361
+ add_downsample=True,
362
+ dual_cross_attention=False,
363
+ use_linear_projection=False,
364
+ only_cross_attention=False,
365
+ upcast_attention=False,
366
+ attention_type="default",
367
+ # additional
368
+ use_temporal=True,
369
+ augment_temporal_attention=False,
370
+ n_frames=8,
371
+ n_temp_heads=8,
372
+ first_frame_condition_mode="none",
373
+ latent_channels=4,
374
+ rotary_emb=False,
375
+ ):
376
+ super().__init__()
377
+
378
+ self.use_temporal = use_temporal
379
+
380
+ self.n_frames = n_frames
381
+ self.first_frame_condition_mode = first_frame_condition_mode
382
+ if self.first_frame_condition_mode == "conv2d":
383
+ self.first_frame_conv = nn.Conv2d(latent_channels, in_channels, kernel_size=1)
384
+
385
+ resnets = []
386
+ attentions = []
387
+
388
+ self.n_frames = n_frames
389
+ self.n_temp_heads = n_temp_heads
390
+
391
+ self.has_cross_attention = True
392
+ self.num_attention_heads = num_attention_heads
393
+
394
+ for i in range(num_layers):
395
+ in_channels = in_channels if i == 0 else out_channels
396
+ resnets.append(
397
+ ResnetBlock2D(
398
+ in_channels=in_channels,
399
+ out_channels=out_channels,
400
+ temb_channels=temb_channels,
401
+ eps=resnet_eps,
402
+ groups=resnet_groups,
403
+ dropout=dropout,
404
+ time_embedding_norm=resnet_time_scale_shift,
405
+ non_linearity=resnet_act_fn,
406
+ output_scale_factor=output_scale_factor,
407
+ pre_norm=resnet_pre_norm,
408
+ )
409
+ )
410
+ if not dual_cross_attention:
411
+ attentions.append(
412
+ Transformer2DConditionModel(
413
+ num_attention_heads,
414
+ out_channels // num_attention_heads,
415
+ in_channels=out_channels,
416
+ num_layers=transformer_layers_per_block,
417
+ cross_attention_dim=cross_attention_dim,
418
+ norm_num_groups=resnet_groups,
419
+ use_linear_projection=use_linear_projection,
420
+ only_cross_attention=only_cross_attention,
421
+ upcast_attention=upcast_attention,
422
+ attention_type=attention_type,
423
+ # additional
424
+ n_frames=n_frames,
425
+ )
426
+ )
427
+ else:
428
+ attentions.append(
429
+ DualTransformer2DModel(
430
+ num_attention_heads,
431
+ out_channels // num_attention_heads,
432
+ in_channels=out_channels,
433
+ num_layers=1,
434
+ cross_attention_dim=cross_attention_dim,
435
+ norm_num_groups=resnet_groups,
436
+ )
437
+ )
438
+ self.attentions = nn.ModuleList(attentions)
439
+ self.resnets = nn.ModuleList(resnets)
440
+
441
+ if add_downsample:
442
+ self.downsamplers = nn.ModuleList(
443
+ [
444
+ Downsample2D(
445
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
446
+ )
447
+ ]
448
+ )
449
+ else:
450
+ self.downsamplers = None
451
+
452
+ self.gradient_checkpointing = False
453
+
454
+ # >>> Temporal Layers >>>
455
+ conv3ds = []
456
+ tempo_attns = []
457
+
458
+ for i in range(num_layers):
459
+ if self.use_temporal:
460
+ conv3ds.append(
461
+ TemporalResnetBlock(
462
+ in_channels=out_channels,
463
+ out_channels=out_channels,
464
+ n_frames=n_frames,
465
+ )
466
+ )
467
+
468
+ tempo_attns.append(
469
+ Transformer2DConditionModel(
470
+ n_temp_heads,
471
+ out_channels // n_temp_heads,
472
+ in_channels=out_channels,
473
+ num_layers=transformer_layers_per_block,
474
+ cross_attention_dim=cross_attention_dim,
475
+ norm_num_groups=resnet_groups,
476
+ use_linear_projection=use_linear_projection,
477
+ only_cross_attention=only_cross_attention,
478
+ upcast_attention=upcast_attention,
479
+ attention_type=attention_type,
480
+ # additional
481
+ n_frames=n_frames,
482
+ is_temporal=True,
483
+ augment_temporal_attention=augment_temporal_attention,
484
+ rotary_emb=rotary_emb
485
+ )
486
+ )
487
+ else:
488
+ conv3ds.append(IdentityLayer(return_trans2d_output=False))
489
+ tempo_attns.append(IdentityLayer(return_trans2d_output=True))
490
+
491
+ self.conv3ds = nn.ModuleList(conv3ds)
492
+ self.tempo_attns = nn.ModuleList(tempo_attns)
493
+ # <<< Temporal Layers <<<
494
+
495
+ def forward(
496
+ self,
497
+ hidden_states: torch.FloatTensor,
498
+ temb: Optional[torch.FloatTensor] = None,
499
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
500
+ attention_mask: Optional[torch.FloatTensor] = None,
501
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
502
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
503
+ # additional
504
+ first_frame_latents=None,
505
+ ):
506
+ condition_on_first_frame = (self.first_frame_condition_mode != "none" and self.first_frame_condition_mode != "input_only")
507
+ # input shape: hidden_states = (b f) c h w, first_frame_latents = b c 1 h w
508
+ if self.first_frame_condition_mode == "conv2d":
509
+ hidden_states = rearrange(hidden_states, '(b t) c h w -> b c t h w', t=self.n_frames)
510
+ hidden_height = hidden_states.shape[3]
511
+ first_frame_height = first_frame_latents.shape[3]
512
+ downsample_ratio = hidden_height / first_frame_height
513
+ first_frame_latents = F.interpolate(first_frame_latents.squeeze(2), scale_factor=downsample_ratio, mode="nearest")
514
+ first_frame_latents = self.first_frame_conv(first_frame_latents).unsqueeze(2)
515
+ hidden_states[:, :, 0:1, :, :] = first_frame_latents
516
+ hidden_states = rearrange(hidden_states, 'b c t h w -> (b t) c h w', t=self.n_frames)
517
+
518
+ output_states = ()
519
+
520
+ for resnet, conv3d, attn, tempo_attn in zip(self.resnets, self.conv3ds, self.attentions, self.tempo_attns):
521
+
522
+ hidden_states = resnet(hidden_states, temb)
523
+ hidden_states = conv3d(hidden_states)
524
+ hidden_states = attn(
525
+ hidden_states,
526
+ encoder_hidden_states=encoder_hidden_states,
527
+ cross_attention_kwargs=cross_attention_kwargs,
528
+ condition_on_first_frame=condition_on_first_frame,
529
+ ).sample
530
+ hidden_states = tempo_attn(
531
+ hidden_states,
532
+ encoder_hidden_states=encoder_hidden_states,
533
+ cross_attention_kwargs=cross_attention_kwargs,
534
+ condition_on_first_frame=False,
535
+ ).sample
536
+
537
+ output_states += (hidden_states,)
538
+
539
+ if self.downsamplers is not None:
540
+ for downsampler in self.downsamplers:
541
+ hidden_states = downsampler(hidden_states)
542
+
543
+ output_states += (hidden_states,)
544
+
545
+ return hidden_states, output_states
546
+
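+ # NOTE: per layer the block applies spatial ResNet -> temporal ResNet (conv3d) -> spatial
+ # cross-attention (optionally conditioned on the first frame) -> temporal attention, as in the
+ # loop above; `output_states` collects the residuals consumed by the corresponding up block.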
547
+
548
+ class VideoLDMCrossAttnUpBlock(nn.Module):
549
+ def __init__(
550
+ self,
551
+ in_channels: int,
552
+ out_channels: int,
553
+ prev_output_channel: int,
554
+ temb_channels: int,
555
+ dropout: float = 0.0,
556
+ num_layers: int = 1,
557
+ transformer_layers_per_block: int = 1,
558
+ resnet_eps: float = 1e-6,
559
+ resnet_time_scale_shift: str = "default",
560
+ resnet_act_fn: str = "swish",
561
+ resnet_groups: int = 32,
562
+ resnet_pre_norm: bool = True,
563
+ num_attention_heads=1,
564
+ cross_attention_dim=1280,
565
+ output_scale_factor=1.0,
566
+ add_upsample=True,
567
+ dual_cross_attention=False,
568
+ use_linear_projection=False,
569
+ only_cross_attention=False,
570
+ upcast_attention=False,
571
+ attention_type="default",
572
+ # additional
573
+ use_temporal=True,
574
+ augment_temporal_attention=False,
575
+ n_frames=8,
576
+ n_temp_heads=8,
577
+ first_frame_condition_mode="none",
578
+ latent_channels=4,
579
+ rotary_emb=False,
580
+ ):
581
+ super().__init__()
582
+
583
+ self.use_temporal = use_temporal
584
+
585
+ self.n_frames = n_frames
586
+ self.first_frame_condition_mode = first_frame_condition_mode
587
+ if self.first_frame_condition_mode == "conv2d":
588
+ self.first_frame_conv = nn.Conv2d(latent_channels, prev_output_channel, kernel_size=1)
589
+
590
+ resnets = []
591
+ attentions = []
592
+
593
+ self.n_frames = n_frames
594
+ self.n_temp_heads = n_temp_heads
595
+
596
+ self.has_cross_attention = True
597
+ self.num_attention_heads = num_attention_heads
598
+
599
+ for i in range(num_layers):
600
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
601
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
602
+
603
+ resnets.append(
604
+ ResnetBlock2D(
605
+ in_channels=resnet_in_channels + res_skip_channels,
606
+ out_channels=out_channels,
607
+ temb_channels=temb_channels,
608
+ eps=resnet_eps,
609
+ groups=resnet_groups,
610
+ dropout=dropout,
611
+ time_embedding_norm=resnet_time_scale_shift,
612
+ non_linearity=resnet_act_fn,
613
+ output_scale_factor=output_scale_factor,
614
+ pre_norm=resnet_pre_norm,
615
+ )
616
+ )
617
+ if not dual_cross_attention:
618
+ attentions.append(
619
+ Transformer2DConditionModel(
620
+ num_attention_heads,
621
+ out_channels // num_attention_heads,
622
+ in_channels=out_channels,
623
+ num_layers=transformer_layers_per_block,
624
+ cross_attention_dim=cross_attention_dim,
625
+ norm_num_groups=resnet_groups,
626
+ use_linear_projection=use_linear_projection,
627
+ only_cross_attention=only_cross_attention,
628
+ upcast_attention=upcast_attention,
629
+ attention_type=attention_type,
630
+ # additional
631
+ n_frames=n_frames,
632
+ )
633
+ )
634
+ else:
635
+ attentions.append(
636
+ DualTransformer2DModel(
637
+ num_attention_heads,
638
+ out_channels // num_attention_heads,
639
+ in_channels=out_channels,
640
+ num_layers=1,
641
+ cross_attention_dim=cross_attention_dim,
642
+ norm_num_groups=resnet_groups,
643
+ )
644
+ )
645
+ self.attentions = nn.ModuleList(attentions)
646
+ self.resnets = nn.ModuleList(resnets)
647
+
648
+ if add_upsample:
649
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
650
+ else:
651
+ self.upsamplers = None
652
+
653
+ self.gradient_checkpointing = False
654
+
655
+ # >>> Temporal Layers >>>
656
+ conv3ds = []
657
+ tempo_attns = []
658
+
659
+ for i in range(num_layers):
660
+ if self.use_temporal:
661
+ conv3ds.append(
662
+ TemporalResnetBlock(
663
+ in_channels=out_channels,
664
+ out_channels=out_channels,
665
+ n_frames=n_frames,
666
+ )
667
+ )
668
+
669
+ tempo_attns.append(
670
+ Transformer2DConditionModel(
671
+ n_temp_heads,
672
+ out_channels // n_temp_heads,
673
+ in_channels=out_channels,
674
+ num_layers=transformer_layers_per_block,
675
+ cross_attention_dim=cross_attention_dim,
676
+ norm_num_groups=resnet_groups,
677
+ use_linear_projection=use_linear_projection,
678
+ only_cross_attention=only_cross_attention,
679
+ upcast_attention=upcast_attention,
680
+ attention_type=attention_type,
681
+ # additional
682
+ n_frames=n_frames,
683
+ augment_temporal_attention=augment_temporal_attention,
684
+ is_temporal=True,
685
+ rotary_emb=rotary_emb,
686
+ )
687
+ )
688
+ else:
689
+ conv3ds.append(IdentityLayer(return_trans2d_output=False))
690
+ tempo_attns.append(IdentityLayer(return_trans2d_output=True))
691
+
692
+ self.conv3ds = nn.ModuleList(conv3ds)
693
+ self.tempo_attns = nn.ModuleList(tempo_attns)
694
+ # <<< Temporal Layers <<<
695
+
696
+ def forward(
697
+ self,
698
+ hidden_states: torch.FloatTensor,
699
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
700
+ temb: Optional[torch.FloatTensor] = None,
701
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
702
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
703
+ upsample_size: Optional[int] = None,
704
+ attention_mask: Optional[torch.FloatTensor] = None,
705
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
706
+ # additional
707
+ first_frame_latents=None,
708
+ ):
709
+ condition_on_first_frame = (self.first_frame_condition_mode != "none" and self.first_frame_condition_mode != "input_only")
710
+ # input shape: hidden_states = (b f) c h w, first_frame_latents = b c 1 h w
711
+ if self.first_frame_condition_mode == "conv2d":
712
+ hidden_states = rearrange(hidden_states, '(b t) c h w -> b c t h w', t=self.n_frames)
713
+ hidden_height = hidden_states.shape[3]
714
+ first_frame_height = first_frame_latents.shape[3]
715
+ downsample_ratio = hidden_height / first_frame_height
716
+ first_frame_latents = F.interpolate(first_frame_latents.squeeze(2), scale_factor=downsample_ratio, mode="nearest")
717
+ first_frame_latents = self.first_frame_conv(first_frame_latents).unsqueeze(2)
718
+ hidden_states[:, :, 0:1, :, :] = first_frame_latents
719
+ hidden_states = rearrange(hidden_states, 'b c t h w -> (b t) c h w', t=self.n_frames)
720
+
721
+ for resnet, conv3d, attn, tempo_attn in zip(self.resnets, self.conv3ds, self.attentions, self.tempo_attns):
722
+
723
+ res_hidden_states = res_hidden_states_tuple[-1]
724
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
725
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
726
+
727
+ hidden_states = resnet(hidden_states, temb)
728
+ hidden_states = conv3d(hidden_states)
729
+ hidden_states = attn(
730
+ hidden_states,
731
+ encoder_hidden_states=encoder_hidden_states,
732
+ cross_attention_kwargs=cross_attention_kwargs,
733
+ condition_on_first_frame=condition_on_first_frame,
734
+ ).sample
735
+ hidden_states = tempo_attn(
736
+ hidden_states,
737
+ encoder_hidden_states=encoder_hidden_states,
738
+ cross_attention_kwargs=cross_attention_kwargs,
739
+ condition_on_first_frame=False,
740
+ ).sample
741
+
742
+ if self.upsamplers is not None:
743
+ for upsampler in self.upsamplers:
744
+ hidden_states = upsampler(hidden_states, upsample_size)
745
+ return hidden_states
746
+
747
+
748
+ class VideoLDMUNetMidBlock2DCrossAttn(nn.Module):
749
+ def __init__(
750
+ self,
751
+ in_channels: int,
752
+ temb_channels: int,
753
+ dropout: float = 0.0,
754
+ num_layers: int = 1,
755
+ transformer_layers_per_block: int = 1,
756
+ resnet_eps: float = 1e-6,
757
+ resnet_time_scale_shift: str = "default",
758
+ resnet_act_fn: str = "swish",
759
+ resnet_groups: int = 32,
760
+ resnet_pre_norm: bool = True,
761
+ num_attention_heads=1,
762
+ output_scale_factor=1.0,
763
+ cross_attention_dim=1280,
764
+ dual_cross_attention=False,
765
+ use_linear_projection=False,
766
+ upcast_attention=False,
767
+ attention_type="default",
768
+ # additional
769
+ use_temporal=True,
770
+ n_frames: int = 8,
771
+ first_frame_condition_mode="none",
772
+ latent_channels=4,
773
+ ):
774
+ super().__init__()
775
+
776
+ self.use_temporal = use_temporal
777
+
778
+ self.n_frames = n_frames
779
+ self.first_frame_condition_mode = first_frame_condition_mode
780
+ if self.first_frame_condition_mode == "conv2d":
781
+ self.first_frame_conv = nn.Conv2d(latent_channels, in_channels, kernel_size=1)
782
+
783
+ self.has_cross_attention = True
784
+ self.num_attention_heads = num_attention_heads
785
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
786
+
787
+ # there is always at least one resnet
788
+ resnets = [
789
+ ResnetBlock2D(
790
+ in_channels=in_channels,
791
+ out_channels=in_channels,
792
+ temb_channels=temb_channels,
793
+ eps=resnet_eps,
794
+ groups=resnet_groups,
795
+ dropout=dropout,
796
+ time_embedding_norm=resnet_time_scale_shift,
797
+ non_linearity=resnet_act_fn,
798
+ output_scale_factor=output_scale_factor,
799
+ pre_norm=resnet_pre_norm,
800
+ )
801
+ ]
802
+ if self.use_temporal:
803
+ conv3ds = [
804
+ TemporalResnetBlock(
805
+ in_channels=in_channels,
806
+ out_channels=in_channels,
807
+ n_frames=n_frames,
808
+ )
809
+ ]
810
+ else:
811
+ conv3ds = [IdentityLayer(return_trans2d_output=False)]
812
+
813
+ attentions = []
814
+
815
+ for _ in range(num_layers):
816
+ if not dual_cross_attention:
817
+ attentions.append(
818
+ Transformer2DConditionModel(
819
+ num_attention_heads,
820
+ in_channels // num_attention_heads,
821
+ in_channels=in_channels,
822
+ num_layers=transformer_layers_per_block,
823
+ cross_attention_dim=cross_attention_dim,
824
+ norm_num_groups=resnet_groups,
825
+ use_linear_projection=use_linear_projection,
826
+ upcast_attention=upcast_attention,
827
+ attention_type=attention_type,
828
+ # additional
829
+ n_frames=n_frames,
830
+ )
831
+ )
832
+ else:
833
+ attentions.append(
834
+ DualTransformer2DModel(
835
+ num_attention_heads,
836
+ in_channels // num_attention_heads,
837
+ in_channels=in_channels,
838
+ num_layers=1,
839
+ cross_attention_dim=cross_attention_dim,
840
+ norm_num_groups=resnet_groups,
841
+ )
842
+ )
843
+ resnets.append(
844
+ ResnetBlock2D(
845
+ in_channels=in_channels,
846
+ out_channels=in_channels,
847
+ temb_channels=temb_channels,
848
+ eps=resnet_eps,
849
+ groups=resnet_groups,
850
+ dropout=dropout,
851
+ time_embedding_norm=resnet_time_scale_shift,
852
+ non_linearity=resnet_act_fn,
853
+ output_scale_factor=output_scale_factor,
854
+ pre_norm=resnet_pre_norm,
855
+ )
856
+ )
857
+ if self.use_temporal:
858
+ conv3ds.append(
859
+ TemporalResnetBlock(
860
+ in_channels=in_channels,
861
+ out_channels=in_channels,
862
+ n_frames=n_frames,
863
+ )
864
+ )
865
+ else:
866
+ conv3ds.append(IdentityLayer(return_trans2d_output=False))
867
+
868
+ self.attentions = nn.ModuleList(attentions)
869
+ self.resnets = nn.ModuleList(resnets)
870
+ self.conv3ds = nn.ModuleList(conv3ds)
871
+
872
+ self.gradient_checkpointing = False
873
+
874
+ def forward(
875
+ self,
876
+ hidden_states: torch.FloatTensor,
877
+ temb: Optional[torch.FloatTensor] = None,
878
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
879
+ attention_mask: Optional[torch.FloatTensor] = None,
880
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
881
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
882
+ # additional
883
+ first_frame_latents=None,
884
+ ) -> torch.FloatTensor:
885
+ condition_on_first_frame = (self.first_frame_condition_mode != "none" and self.first_frame_condition_mode != "input_only")
886
+ # input shape: hidden_states = (b f) c h w, first_frame_latents = b c 1 h w
887
+ if self.first_frame_condition_mode == "conv2d":
888
+ hidden_states = rearrange(hidden_states, '(b t) c h w -> b c t h w', t=self.n_frames)
889
+ hidden_height = hidden_states.shape[3]
890
+ first_frame_height = first_frame_latents.shape[3]
891
+ downsample_ratio = hidden_height / first_frame_height
892
+ first_frame_latents = F.interpolate(first_frame_latents.squeeze(2), scale_factor=downsample_ratio, mode="nearest")
893
+ first_frame_latents = self.first_frame_conv(first_frame_latents).unsqueeze(2)
894
+ hidden_states[:, :, 0:1, :, :] = first_frame_latents
895
+ hidden_states = rearrange(hidden_states, 'b c t h w -> (b t) c h w', t=self.n_frames)
896
+
897
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
898
+ hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
899
+ hidden_states = self.conv3ds[0](hidden_states)
900
+ for attn, resnet, conv3d in zip(self.attentions, self.resnets[1:], self.conv3ds[1:]):
901
+ if self.training and self.gradient_checkpointing:
902
+
903
+ def create_custom_forward(module, return_dict=None):
904
+ def custom_forward(*inputs):
905
+ if return_dict is not None:
906
+ return module(*inputs, return_dict=return_dict)
907
+ else:
908
+ return module(*inputs)
909
+
910
+ return custom_forward
911
+
912
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
913
+ hidden_states = attn(
914
+ hidden_states,
915
+ encoder_hidden_states=encoder_hidden_states,
916
+ cross_attention_kwargs=cross_attention_kwargs,
917
+ attention_mask=attention_mask,
918
+ encoder_attention_mask=encoder_attention_mask,
919
+ return_dict=False,
920
+ # additional
921
+ condition_on_first_frame=condition_on_first_frame,
922
+ )[0]
923
+ hidden_states = torch.utils.checkpoint.checkpoint(
924
+ create_custom_forward(resnet),
925
+ hidden_states,
926
+ temb,
927
+ **ckpt_kwargs,
928
+ )
929
+ hidden_states = conv3d(hidden_states)
930
+ else:
931
+ hidden_states = attn(
932
+ hidden_states,
933
+ encoder_hidden_states=encoder_hidden_states,
934
+ cross_attention_kwargs=cross_attention_kwargs,
935
+ attention_mask=attention_mask,
936
+ encoder_attention_mask=encoder_attention_mask,
937
+ return_dict=False,
938
+ # additional
939
+ condition_on_first_frame=condition_on_first_frame,
940
+ )[0]
941
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
942
+ hidden_states = conv3d(hidden_states)
943
+
944
+ return hidden_states
945
+
946
+
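The training path above wraps each resnet in torch.utils.checkpoint and only passes `use_reentrant=False` on torch >= 1.11. A stripped-down, self-contained sketch of that pattern, using a plain nn.Sequential stand-in for ResnetBlock2D and `packaging.version` in place of diffusers' is_torch_version helper:

import torch
import torch.nn as nn
import torch.utils.checkpoint
from packaging import version

def create_custom_forward(module):
    # checkpoint() re-runs this closure during backward to recompute activations
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward

block = nn.Sequential(nn.Linear(16, 16), nn.SiLU(), nn.Linear(16, 16))
x = torch.randn(4, 16, requires_grad=True)

# newer torch versions expose the non-reentrant checkpoint implementation
ckpt_kwargs = {"use_reentrant": False} if version.parse(torch.__version__) >= version.parse("1.11.0") else {}
y = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, **ckpt_kwargs)
y.sum().backward()
print(x.grad.shape)  # torch.Size([4, 16]); activations were recomputed, not stored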
947
+ class VideoLDMDownBlock(DownBlock2D):
948
+ def __init__(
949
+ self,
950
+ in_channels: int,
951
+ out_channels: int,
952
+ temb_channels: int,
953
+ dropout: float = 0.0,
954
+ num_layers: int = 1,
955
+ resnet_eps: float = 1e-6,
956
+ resnet_time_scale_shift: str = "default",
957
+ resnet_act_fn: str = "swish",
958
+ resnet_groups: int = 32,
959
+ resnet_pre_norm: bool = True,
960
+ output_scale_factor=1.0,
961
+ add_downsample=True,
962
+ downsample_padding=1,
963
+ # additional
964
+ use_temporal=True,
965
+ n_frames: int = 8,
966
+ first_frame_condition_mode="none",
967
+ latent_channels=4,
968
+ ):
969
+ super().__init__(
970
+ in_channels,
971
+ out_channels,
972
+ temb_channels,
973
+ dropout,
974
+ num_layers,
975
+ resnet_eps,
976
+ resnet_time_scale_shift,
977
+ resnet_act_fn,
978
+ resnet_groups,
979
+ resnet_pre_norm,
980
+ output_scale_factor,
981
+ add_downsample,
982
+ downsample_padding,)
983
+
984
+ self.use_temporal = use_temporal
985
+
986
+ self.n_frames = n_frames
987
+ self.first_frame_condition_mode = first_frame_condition_mode
988
+ if self.first_frame_condition_mode == "conv2d":
989
+ self.first_frame_conv = nn.Conv2d(latent_channels, in_channels, kernel_size=1)
990
+
991
+ # >>> Temporal Layers >>>
992
+ conv3ds = []
993
+ for i in range(num_layers):
994
+ if self.use_temporal:
995
+ conv3ds.append(
996
+ TemporalResnetBlock(
997
+ in_channels=out_channels,
998
+ out_channels=out_channels,
999
+ n_frames=n_frames,
1000
+ )
1001
+ )
1002
+ else:
1003
+ conv3ds.append(IdentityLayer(return_trans2d_output=False))
1004
+ self.conv3ds = nn.ModuleList(conv3ds)
1005
+ # <<< Temporal Layers <<<
1006
+
1007
+ def forward(self, hidden_states, temb=None, scale: float = 1, first_frame_latents=None):
1008
+ # input shape: hidden_states = (b f) c h w, first_frame_latents = b c 1 h w
1009
+ if self.first_frame_condition_mode == "conv2d":
1010
+ hidden_states = rearrange(hidden_states, '(b t) c h w -> b c t h w', t=self.n_frames)
1011
+ hidden_height = hidden_states.shape[3]
1012
+ first_frame_height = first_frame_latents.shape[3]
1013
+ downsample_ratio = hidden_height / first_frame_height
1014
+ first_frame_latents = F.interpolate(first_frame_latents.squeeze(2), scale_factor=downsample_ratio, mode="nearest")
1015
+ first_frame_latents = self.first_frame_conv(first_frame_latents).unsqueeze(2)
1016
+ hidden_states[:, :, 0:1, :, :] = first_frame_latents
1017
+ hidden_states = rearrange(hidden_states, 'b c t h w -> (b t) c h w', t=self.n_frames)
1018
+
1019
+ output_states = ()
1020
+
1021
+ for resnet, conv3d in zip(self.resnets, self.conv3ds):
1022
+ if self.training and self.gradient_checkpointing:
1023
+
1024
+ def create_custom_forward(module):
1025
+ def custom_forward(*inputs):
1026
+ return module(*inputs)
1027
+
1028
+ return custom_forward
1029
+
1030
+ if is_torch_version(">=", "1.11.0"):
1031
+ hidden_states = torch.utils.checkpoint.checkpoint(
1032
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
1033
+ )
1034
+ else:
1035
+ hidden_states = torch.utils.checkpoint.checkpoint(
1036
+ create_custom_forward(resnet), hidden_states, temb
1037
+ )
1038
+ else:
1039
+ hidden_states = resnet(hidden_states, temb, scale=scale)
1040
+
1041
+ hidden_states = conv3d(hidden_states)
1042
+
1043
+ output_states = output_states + (hidden_states,)
1044
+
1045
+ if self.downsamplers is not None:
1046
+ for downsampler in self.downsamplers:
1047
+ hidden_states = downsampler(hidden_states, scale=scale)
1048
+
1049
+ output_states = output_states + (hidden_states,)
1050
+
1051
+ return hidden_states, output_states
1052
+
1053
+
1054
+ class VideoLDMUpBlock(UpBlock2D):
1055
+ def __init__(
1056
+ self,
1057
+ in_channels: int,
1058
+ prev_output_channel: int,
1059
+ out_channels: int,
1060
+ temb_channels: int,
1061
+ dropout: float = 0.0,
1062
+ num_layers: int = 1,
1063
+ resnet_eps: float = 1e-6,
1064
+ resnet_time_scale_shift: str = "default",
1065
+ resnet_act_fn: str = "swish",
1066
+ resnet_groups: int = 32,
1067
+ resnet_pre_norm: bool = True,
1068
+ output_scale_factor=1.0,
1069
+ add_upsample=True,
1070
+ # additional
1071
+ use_temporal=True,
1072
+ n_frames: int = 8,
1073
+ first_frame_condition_mode="none",
1074
+ latent_channels=4,
1075
+ ):
1076
+ super().__init__(
1077
+ in_channels,
1078
+ prev_output_channel,
1079
+ out_channels,
1080
+ temb_channels,
1081
+ dropout,
1082
+ num_layers,
1083
+ resnet_eps,
1084
+ resnet_time_scale_shift,
1085
+ resnet_act_fn,
1086
+ resnet_groups,
1087
+ resnet_pre_norm,
1088
+ output_scale_factor,
1089
+ add_upsample,
1090
+ )
1091
+
1092
+ self.use_temporal = use_temporal
1093
+
1094
+ self.n_frames = n_frames
1095
+ self.first_frame_condition_mode = first_frame_condition_mode
1096
+ if self.first_frame_condition_mode == "conv2d":
1097
+ self.first_frame_conv = nn.Conv2d(latent_channels, prev_output_channel, kernel_size=1)
1098
+
1099
+ # >>> Temporal Layers >>>
1100
+ conv3ds = []
1101
+ for i in range(num_layers):
1102
+ if self.use_temporal:
1103
+ conv3ds.append(
1104
+ TemporalResnetBlock(
1105
+ in_channels=out_channels,
1106
+ out_channels=out_channels,
1107
+ n_frames=n_frames,
1108
+ )
1109
+ )
1110
+ else:
1111
+ conv3ds.append(IdentityLayer(return_trans2d_output=False))
1112
+
1113
+ self.conv3ds = nn.ModuleList(conv3ds)
1114
+ # <<< Temporal Layers <<<
1115
+
1116
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1, first_frame_latents=None):
1117
+ # input shape: hidden_states = (b f) c h w, first_frame_latents = b c 1 h w
1118
+ if self.first_frame_condition_mode == "conv2d":
1119
+ hidden_states = rearrange(hidden_states, '(b t) c h w -> b c t h w', t=self.n_frames)
1120
+ hidden_height = hidden_states.shape[3]
1121
+ first_frame_height = first_frame_latents.shape[3]
1122
+ downsample_ratio = hidden_height / first_frame_height
1123
+ first_frame_latents = F.interpolate(first_frame_latents.squeeze(2), scale_factor=downsample_ratio, mode="nearest")
1124
+ first_frame_latents = self.first_frame_conv(first_frame_latents).unsqueeze(2)
1125
+ hidden_states[:, :, 0:1, :, :] = first_frame_latents
1126
+ hidden_states = rearrange(hidden_states, 'b c t h w -> (b t) c h w', t=self.n_frames)
1127
+
1128
+ for resnet, conv3d in zip(self.resnets, self.conv3ds):
1129
+ # pop res hidden states
1130
+ res_hidden_states = res_hidden_states_tuple[-1]
1131
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1132
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1133
+
1134
+ if self.training and self.gradient_checkpointing:
1135
+
1136
+ def create_custom_forward(module):
1137
+ def custom_forward(*inputs):
1138
+ return module(*inputs)
1139
+
1140
+ return custom_forward
1141
+
1142
+ if is_torch_version(">=", "1.11.0"):
1143
+ hidden_states = torch.utils.checkpoint.checkpoint(
1144
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
1145
+ )
1146
+ else:
1147
+ hidden_states = torch.utils.checkpoint.checkpoint(
1148
+ create_custom_forward(resnet), hidden_states, temb
1149
+ )
1150
+ else:
1151
+ hidden_states = resnet(hidden_states, temb, scale=scale)
1152
+
1153
+ hidden_states = conv3d(hidden_states)
1154
+
1155
+ if self.upsamplers is not None:
1156
+ for upsampler in self.upsamplers:
1157
+ hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
1158
+
1159
+ return hidden_states
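As a toy illustration (plain tensors, no real blocks) of how the `output_states` tuples returned by the down blocks above pair with the `res_hidden_states_tuple` popping in the up blocks: the down path appends one skip entry per resnet (plus one per downsampler), and the up path consumes them last-in-first-out, concatenating along the channel dimension before each resnet.

import torch

# down path: collect a skip activation after every layer
skips = ()
h = torch.randn(2, 8, 16, 16)
for _ in range(3):              # stand-in for the resnet/conv3d/attention loop
    h = h + 1                   # placeholder computation
    skips += (h,)

# up path: pop skips in reverse and concatenate along channels before each resnet
for _ in range(3):
    res = skips[-1]
    skips = skips[:-1]
    h = torch.cat([h, res], dim=1)   # channels grow to resnet_in_channels + res_skip_channels
    h = h[:, :8]                     # placeholder "resnet" restoring the channel count
print(h.shape)                        # torch.Size([2, 8, 16, 16])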
src/videogen_hub/pipelines/consisti2v/consisti2v/pipelines/__init__.py ADDED
File without changes
src/videogen_hub/pipelines/consisti2v/consisti2v/pipelines/pipeline_autoregress_animation.py ADDED
@@ -0,0 +1,615 @@
1
+ # Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py
2
+
3
+ import inspect
4
+ from typing import Callable, List, Optional, Union
5
+ from dataclasses import dataclass
6
+
7
+ import math
8
+ import numpy as np
9
+ import torch
10
+ from tqdm import tqdm
11
+
12
+ from torchvision import transforms as T
13
+ from PIL import Image
14
+
15
+ from diffusers.utils import is_accelerate_available
16
+ from packaging import version
17
+ from transformers import CLIPTextModel, CLIPTokenizer
18
+
19
+ from diffusers.configuration_utils import FrozenDict
20
+ from diffusers.models import AutoencoderKL
21
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
22
+ from diffusers.schedulers import (
23
+ DDIMScheduler,
24
+ DPMSolverMultistepScheduler,
25
+ EulerAncestralDiscreteScheduler,
26
+ EulerDiscreteScheduler,
27
+ LMSDiscreteScheduler,
28
+ PNDMScheduler,
29
+ )
30
+ from diffusers.utils import deprecate, logging, BaseOutput
31
+
32
+ from einops import rearrange, repeat
33
+
34
+ from ..models.unet import UNet3DConditionModel
35
+ from ..utils.frameinit_utils import freq_mix_3d, get_freq_filter
36
+
37
+
38
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
+
40
+ # copied from https://github.com/huggingface/diffusers/blob/v0.23.0/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L59C1-L70C21
41
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
42
+ """
43
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
44
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
45
+ """
46
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
47
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
48
+ # rescale the results from guidance (fixes overexposure)
49
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
50
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
51
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
52
+ return noise_cfg
53
+
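A quick sanity check of rescale_noise_cfg on random tensors (assuming the function above is in scope): with guidance_rescale=1.0 the guided prediction is rescaled so its per-sample std matches the text branch, and with 0.0 it is returned unchanged.

import torch

torch.manual_seed(0)
noise_pred_text = torch.randn(2, 4, 8, 32, 32)
noise_cfg = noise_pred_text * 3.0                     # pretend CFG inflated the scale

out_full = rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=1.0)
out_none = rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0)

print(round(out_full.std().item(), 2))                # ~1.0, i.e. back to the text branch's std
print(torch.allclose(out_none, noise_cfg))            # True: rescale factor 0 is a no-op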
54
+
55
+ @dataclass
56
+ class AnimationPipelineOutput(BaseOutput):
57
+ videos: Union[torch.Tensor, np.ndarray]
58
+
59
+
60
+ class AutoregressiveAnimationPipeline(DiffusionPipeline):
61
+ _optional_components = []
62
+
63
+ def __init__(
64
+ self,
65
+ vae: AutoencoderKL,
66
+ text_encoder: CLIPTextModel,
67
+ tokenizer: CLIPTokenizer,
68
+ unet: UNet3DConditionModel,
69
+ scheduler: Union[
70
+ DDIMScheduler,
71
+ PNDMScheduler,
72
+ LMSDiscreteScheduler,
73
+ EulerDiscreteScheduler,
74
+ EulerAncestralDiscreteScheduler,
75
+ DPMSolverMultistepScheduler,
76
+ ],
77
+ ):
78
+ super().__init__()
79
+
80
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
81
+ deprecation_message = (
82
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
83
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
84
+ "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
85
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
86
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
87
+ " file"
88
+ )
89
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
90
+ new_config = dict(scheduler.config)
91
+ new_config["steps_offset"] = 1
92
+ scheduler._internal_dict = FrozenDict(new_config)
93
+
94
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
95
+ deprecation_message = (
96
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
97
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
98
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
99
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
100
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
101
+ )
102
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
103
+ new_config = dict(scheduler.config)
104
+ new_config["clip_sample"] = False
105
+ scheduler._internal_dict = FrozenDict(new_config)
106
+
107
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
108
+ version.parse(unet.config._diffusers_version).base_version
109
+ ) < version.parse("0.9.0.dev0")
110
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
111
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
112
+ deprecation_message = (
113
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
114
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
115
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
116
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
117
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
118
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
119
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
120
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
121
+ " the `unet/config.json` file"
122
+ )
123
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
124
+ new_config = dict(unet.config)
125
+ new_config["sample_size"] = 64
126
+ unet._internal_dict = FrozenDict(new_config)
127
+
128
+ self.register_modules(
129
+ vae=vae,
130
+ text_encoder=text_encoder,
131
+ tokenizer=tokenizer,
132
+ unet=unet,
133
+ scheduler=scheduler,
134
+ )
135
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
136
+
137
+ self.freq_filter = None
138
+
139
+ @torch.no_grad()
140
+ def init_filter(self, video_length, height, width, filter_params):
141
+ # initialize frequency filter for noise reinitialization
142
+ batch_size = 1
143
+ num_channels_latents = self.unet.config.in_channels
144
+ filter_shape = [
145
+ batch_size,
146
+ num_channels_latents,
147
+ video_length,
148
+ height // self.vae_scale_factor,
149
+ width // self.vae_scale_factor
150
+ ]
151
+ # self.freq_filter = get_freq_filter(filter_shape, device=self._execution_device, params=filter_params)
152
+ self.freq_filter = get_freq_filter(
153
+ filter_shape,
154
+ device=self._execution_device,
155
+ filter_type=filter_params.method,
156
+ n=filter_params.n if filter_params.method=="butterworth" else None,
157
+ d_s=filter_params.d_s,
158
+ d_t=filter_params.d_t
159
+ )
160
+
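The filter built by init_filter is consumed further down in __call__: the clean first-frame latent is tiled into a static video, diffused to frameinit_noise_level with scheduler.add_noise, and its low-frequency band is merged with freshly sampled noise via freq_mix_3d (FreeInit-style noise reinitialization). The actual helpers live in ..utils.frameinit_utils and are not reproduced here; below is a self-contained, simplified stand-in using torch.fft and a hard box low-pass mask, purely to illustrate the idea.

import torch

def toy_freq_mix_3d(low_src, high_src, d_s=0.25, d_t=0.25):
    # simplified stand-in for freq_mix_3d: box low-pass over the (t, h, w) frequencies
    b, c, t, h, w = low_src.shape
    ft = torch.fft.fftshift(torch.fft.fftn(low_src, dim=(-3, -2, -1)), dim=(-3, -2, -1))
    fh = torch.fft.fftshift(torch.fft.fftn(high_src, dim=(-3, -2, -1)), dim=(-3, -2, -1))
    tt = torch.linspace(-1, 1, t).abs().view(t, 1, 1)
    hh = torch.linspace(-1, 1, h).abs().view(1, h, 1)
    ww = torch.linspace(-1, 1, w).abs().view(1, 1, w)
    lpf = ((tt <= d_t) & (hh <= d_s) & (ww <= d_s)).to(low_src.dtype)   # 1 inside the low band
    mixed = ft * lpf + fh * (1 - lpf)                                    # low freqs from low_src, high freqs from high_src
    mixed = torch.fft.ifftn(torch.fft.ifftshift(mixed, dim=(-3, -2, -1)), dim=(-3, -2, -1))
    return mixed.real

static_video = torch.randn(1, 4, 1, 32, 32).expand(-1, -1, 8, -1, -1)   # first frame tiled over time
fresh_noise = torch.randn(1, 4, 8, 32, 32)
mixed = toy_freq_mix_3d(static_video, fresh_noise)
print(mixed.shape)                                                       # torch.Size([1, 4, 8, 32, 32])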
161
+ def enable_vae_slicing(self):
162
+ self.vae.enable_slicing()
163
+
164
+ def disable_vae_slicing(self):
165
+ self.vae.disable_slicing()
166
+
167
+ def enable_sequential_cpu_offload(self, gpu_id=0):
168
+ if is_accelerate_available():
169
+ from accelerate import cpu_offload
170
+ else:
171
+ raise ImportError("Please install accelerate via `pip install accelerate`")
172
+
173
+ device = torch.device(f"cuda:{gpu_id}")
174
+
175
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
176
+ if cpu_offloaded_model is not None:
177
+ cpu_offload(cpu_offloaded_model, device)
178
+
179
+
180
+ @property
181
+ def _execution_device(self):
182
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
183
+ return self.device
184
+ for module in self.unet.modules():
185
+ if (
186
+ hasattr(module, "_hf_hook")
187
+ and hasattr(module._hf_hook, "execution_device")
188
+ and module._hf_hook.execution_device is not None
189
+ ):
190
+ return torch.device(module._hf_hook.execution_device)
191
+ return self.device
192
+
193
+ def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
194
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
195
+
196
+ text_inputs = self.tokenizer(
197
+ prompt,
198
+ padding="max_length",
199
+ max_length=self.tokenizer.model_max_length,
200
+ truncation=True,
201
+ return_tensors="pt",
202
+ )
203
+ text_input_ids = text_inputs.input_ids
204
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
205
+
206
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
207
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
208
+ logger.warning(
209
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
210
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
211
+ )
212
+
213
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
214
+ attention_mask = text_inputs.attention_mask.to(device)
215
+ else:
216
+ attention_mask = None
217
+
218
+ text_embeddings = self.text_encoder(
219
+ text_input_ids.to(device),
220
+ attention_mask=attention_mask,
221
+ )
222
+ text_embeddings = text_embeddings[0]
223
+
224
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
225
+ bs_embed, seq_len, _ = text_embeddings.shape
226
+ text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
227
+ text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
228
+
229
+ # get unconditional embeddings for classifier free guidance
230
+ if do_classifier_free_guidance is not None:
231
+ uncond_tokens: List[str]
232
+ if negative_prompt is None:
233
+ uncond_tokens = [""] * batch_size
234
+ elif type(prompt) is not type(negative_prompt):
235
+ raise TypeError(
236
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
237
+ f" {type(prompt)}."
238
+ )
239
+ elif isinstance(negative_prompt, str):
240
+ uncond_tokens = [negative_prompt]
241
+ elif batch_size != len(negative_prompt):
242
+ raise ValueError(
243
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
244
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
245
+ " the batch size of `prompt`."
246
+ )
247
+ else:
248
+ uncond_tokens = negative_prompt
249
+
250
+ max_length = text_input_ids.shape[-1]
251
+ uncond_input = self.tokenizer(
252
+ uncond_tokens,
253
+ padding="max_length",
254
+ max_length=max_length,
255
+ truncation=True,
256
+ return_tensors="pt",
257
+ )
258
+
259
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
260
+ attention_mask = uncond_input.attention_mask.to(device)
261
+ else:
262
+ attention_mask = None
263
+
264
+ uncond_embeddings = self.text_encoder(
265
+ uncond_input.input_ids.to(device),
266
+ attention_mask=attention_mask,
267
+ )
268
+ uncond_embeddings = uncond_embeddings[0]
269
+
270
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
271
+ seq_len = uncond_embeddings.shape[1]
272
+ uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
273
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)
274
+
275
+ # For classifier free guidance, we need to do two forward passes.
276
+ # Here we concatenate the unconditional and text embeddings into a single batch
277
+ # to avoid doing two forward passes
278
+ if do_classifier_free_guidance == "text":
279
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
280
+ elif do_classifier_free_guidance == "both":
281
+ text_embeddings = torch.cat([uncond_embeddings, uncond_embeddings, text_embeddings])
282
+
283
+ return text_embeddings
284
+
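Under the "both" guidance mode above, the embeddings are stacked as [uncond, uncond, text]; later in __call__ the UNet output is chunked into (unconditional, image-conditioned, text+image-conditioned) predictions and combined with two guidance scales. A small numeric sketch of that combination rule, with made-up scalars standing in for the noise tensors:

import torch

guidance_scale_img, guidance_scale_txt = 2.0, 7.5
# pretend these are the three chunks of the UNet output under "both" guidance
noise_pred_uncond = torch.tensor(0.10)
noise_pred_img = torch.tensor(0.30)      # conditioned on the first frame only
noise_pred_both = torch.tensor(0.50)     # conditioned on the first frame and the text

noise_pred = (noise_pred_uncond
              + guidance_scale_img * (noise_pred_img - noise_pred_uncond)
              + guidance_scale_txt * (noise_pred_both - noise_pred_img))
print(noise_pred.item())                  # 0.1 + 2.0*0.2 + 7.5*0.2 = 2.0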
285
+ def decode_latents(self, latents, first_frames=None):
286
+ video_length = latents.shape[2]
287
+ latents = 1 / self.vae.config.scaling_factor * latents
288
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
289
+ # video = self.vae.decode(latents).sample
290
+ video = []
291
+ for frame_idx in tqdm(range(latents.shape[0]), **self._progress_bar_config):
292
+ video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
293
+ video = torch.cat(video)
294
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
295
+
296
+ if first_frames is not None:
297
+ first_frames = first_frames.unsqueeze(2)
298
+ video = torch.cat([first_frames, video], dim=2)
299
+
300
+ video = (video / 2 + 0.5).clamp(0, 1)
301
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
302
+ video = video.cpu().float().numpy()
303
+ return video
304
+
305
+ def prepare_extra_step_kwargs(self, generator, eta):
306
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
307
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
308
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
309
+ # and should be between [0, 1]
310
+
311
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
312
+ extra_step_kwargs = {}
313
+ if accepts_eta:
314
+ extra_step_kwargs["eta"] = eta
315
+
316
+ # check if the scheduler accepts generator
317
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
318
+ if accepts_generator:
319
+ extra_step_kwargs["generator"] = generator
320
+ return extra_step_kwargs
321
+
322
+ def check_inputs(self, prompt, height, width, callback_steps, first_frame_paths=None):
323
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
324
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
325
+
326
+ if first_frame_paths is not None and (not isinstance(first_frame_paths, str) and not isinstance(first_frame_paths, list)):
327
+ raise ValueError(f"`first_frame_paths` has to be of type `str` or `list` but is {type(first_frame_paths)}")
328
+
329
+ if height % 8 != 0 or width % 8 != 0:
330
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
331
+
332
+ if (callback_steps is None) or (
333
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
334
+ ):
335
+ raise ValueError(
336
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
337
+ f" {type(callback_steps)}."
338
+ )
339
+
340
+ def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None, noise_sampling_method="vanilla", noise_alpha=1.0):
341
+ shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
342
+ if isinstance(generator, list) and len(generator) != batch_size:
343
+ raise ValueError(
344
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
345
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
346
+ )
347
+ if latents is None:
348
+ rand_device = "cpu" if device.type == "mps" else device
349
+
350
+ if isinstance(generator, list):
351
+ # shape = shape
352
+ shape = (1,) + shape[1:]
353
+ if noise_sampling_method == "vanilla":
354
+ latents = [
355
+ torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
356
+ for i in range(batch_size)
357
+ ]
358
+ elif noise_sampling_method == "pyoco_mixed":
359
+ base_shape = (batch_size, num_channels_latents, 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
360
+ latents = []
361
+ noise_alpha_squared = noise_alpha ** 2
362
+ for i in range(batch_size):
363
+ base_latent = torch.randn(base_shape, generator=generator[i], device=rand_device, dtype=dtype) * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared))
364
+ ind_latent = torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared))
365
+ latents.append(base_latent + ind_latent)
366
+ elif noise_sampling_method == "pyoco_progressive":
367
+ latents = []
368
+ noise_alpha_squared = noise_alpha ** 2
369
+ for i in range(batch_size):
370
+ latent = torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
371
+ ind_latent = torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared))
372
+ for j in range(1, video_length):
373
+ latent[:, :, j, :, :] = latent[:, :, j - 1, :, :] * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared)) + ind_latent[:, :, j, :, :]
374
+ latents.append(latent)
375
+ latents = torch.cat(latents, dim=0).to(device)
376
+ else:
377
+ if noise_sampling_method == "vanilla":
378
+ latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
379
+ elif noise_sampling_method == "pyoco_mixed":
380
+ noise_alpha_squared = noise_alpha ** 2
381
+ base_shape = (batch_size, num_channels_latents, 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
382
+ base_latents = torch.randn(base_shape, generator=generator, device=rand_device, dtype=dtype) * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared))
383
+ ind_latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared))
384
+ latents = base_latents + ind_latents
385
+ elif noise_sampling_method == "pyoco_progressive":
386
+ noise_alpha_squared = noise_alpha ** 2
387
+ latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype)
388
+ ind_latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared))
389
+ for j in range(1, video_length):
390
+ latents[:, :, j, :, :] = latents[:, :, j - 1, :, :] * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared)) + ind_latents[:, :, j, :, :]
391
+ else:
392
+ if latents.shape != shape:
393
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
394
+ latents = latents.to(device)
395
+
396
+ # scale the initial noise by the standard deviation required by the scheduler
397
+ latents = latents * self.scheduler.init_noise_sigma
398
+ return latents
399
+
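A self-contained sketch of the "pyoco_mixed" branch of prepare_latents above: every frame shares one base noise tensor plus an independent per-frame component, and noise_alpha controls how much of the total (unit) variance is shared across frames. Shapes below are arbitrary.

import math
import torch

b, c, t, h, w = 1, 4, 16, 40, 64
noise_alpha = 1.0
a2 = noise_alpha ** 2

base = torch.randn(b, c, 1, h, w) * math.sqrt(a2 / (1 + a2))   # shared part, broadcast over time
ind = torch.randn(b, c, t, h, w) * math.sqrt(1 / (1 + a2))     # independent per-frame part
latents = base + ind                                            # unit variance overall

# the correlation between any two frames is roughly a2 / (1 + a2)
f0, f1 = latents[:, :, 0].flatten(), latents[:, :, 1].flatten()
print(torch.corrcoef(torch.stack([f0, f1]))[0, 1].item())       # ~0.5 for noise_alpha = 1.0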
400
+ @torch.no_grad()
401
+ def __call__(
402
+ self,
403
+ prompt: Union[str, List[str]],
404
+ video_length: Optional[int],
405
+ height: Optional[int] = None,
406
+ width: Optional[int] = None,
407
+ num_inference_steps: int = 50,
408
+ guidance_scale_txt: float = 7.5,
409
+ guidance_scale_img: float = 2.0,
410
+ negative_prompt: Optional[Union[str, List[str]]] = None,
411
+ num_videos_per_prompt: Optional[int] = 1,
412
+ eta: float = 0.0,
413
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
414
+ latents: Optional[torch.FloatTensor] = None,
415
+ output_type: Optional[str] = "tensor",
416
+ return_dict: bool = True,
417
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
418
+ callback_steps: Optional[int] = 1,
419
+ # additional
420
+ first_frame_paths: Optional[Union[str, List[str]]] = None,
421
+ first_frames: Optional[torch.FloatTensor] = None,
422
+ noise_sampling_method: str = "vanilla",
423
+ noise_alpha: float = 1.0,
424
+ guidance_rescale: float = 0.0,
425
+ frame_stride: Optional[int] = None,
426
+ autoregress_steps: int = 3,
427
+ use_frameinit: bool = False,
428
+ frameinit_noise_level: int = 999,
429
+ **kwargs,
430
+ ):
431
+ if first_frame_paths is not None and first_frames is not None:
432
+ raise ValueError("Only one of `first_frame_paths` and `first_frames` can be passed.")
433
+ # Default height and width to unet
434
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
435
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
436
+
437
+ # Check inputs. Raise error if not correct
438
+ self.check_inputs(prompt, height, width, callback_steps, first_frame_paths)
439
+
440
+ # Define call parameters
441
+ # batch_size = 1 if isinstance(prompt, str) else len(prompt)
442
+ batch_size = 1
443
+ if latents is not None:
444
+ batch_size = latents.shape[0]
445
+ if isinstance(prompt, list):
446
+ batch_size = len(prompt)
447
+ first_frame_input = first_frame_paths if first_frame_paths is not None else first_frames
448
+ if first_frame_input is not None:
449
+ assert len(prompt) == len(first_frame_input), "prompt and first_frame_paths should have the same length"
450
+
451
+ device = self._execution_device
452
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
453
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
454
+ # corresponds to doing no classifier free guidance.
455
+ do_classifier_free_guidance = None
456
+ # two guidance mode: text and text+image
457
+ if guidance_scale_txt > 1.0:
458
+ do_classifier_free_guidance = "text"
459
+ if guidance_scale_img > 1.0:
460
+ do_classifier_free_guidance = "both"
461
+
462
+ # Encode input prompt
463
+ prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size
464
+ if negative_prompt is not None:
465
+ negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size
466
+ text_embeddings = self._encode_prompt(
467
+ prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
468
+ )
469
+
470
+ # Encode input first frame
471
+ first_frame_latents = None
472
+ if first_frame_paths is not None:
473
+ first_frame_paths = first_frame_paths if isinstance(first_frame_paths, list) else [first_frame_paths] * batch_size
474
+ img_transform = T.Compose([
475
+ T.ToTensor(),
476
+ T.Resize(height, antialias=None),
477
+ T.CenterCrop((height, width)),
478
+ T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
479
+ ])
480
+ first_frames = []
481
+ for first_frame_path in first_frame_paths:
482
+ first_frame = Image.open(first_frame_path).convert('RGB')
483
+ first_frame = img_transform(first_frame).unsqueeze(0)
484
+ first_frames.append(first_frame)
485
+ first_frames = torch.cat(first_frames, dim=0)
486
+ if first_frames is not None:
487
+ first_frames = first_frames.to(device, dtype=self.vae.dtype)
488
+ first_frame_latents = self.vae.encode(first_frames).latent_dist
489
+ first_frame_latents = first_frame_latents.sample()
490
+ first_frame_latents = first_frame_latents * self.vae.config.scaling_factor # b, c, h, w
491
+ first_frame_latents = repeat(first_frame_latents, "b c h w -> (b n) c h w", n=num_videos_per_prompt)
492
+ first_frames = repeat(first_frames, "b c h w -> (b n) c h w", n=num_videos_per_prompt)
493
+
494
+ full_video_latent = torch.zeros(batch_size * num_videos_per_prompt, self.unet.config.in_channels, video_length * autoregress_steps - autoregress_steps + 1, height // self.vae_scale_factor, width // self.vae_scale_factor, device=device, dtype=self.vae.dtype)
495
+
496
+ start_idx = 0
497
+ for ar_step in range(autoregress_steps):
498
+ # Prepare timesteps
499
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
500
+ timesteps = self.scheduler.timesteps
501
+
502
+ # Prepare latent variables
503
+ num_channels_latents = self.unet.config.in_channels
504
+ latents = self.prepare_latents(
505
+ batch_size * num_videos_per_prompt,
506
+ num_channels_latents,
507
+ video_length,
508
+ height,
509
+ width,
510
+ text_embeddings.dtype,
511
+ device,
512
+ generator,
513
+ latents,
514
+ noise_sampling_method,
515
+ noise_alpha,
516
+ )
517
+ latents_dtype = latents.dtype
518
+
519
+ if use_frameinit:
520
+ current_diffuse_timestep = frameinit_noise_level # diffuse to noise level
521
+ diffuse_timesteps = torch.full((batch_size,),int(current_diffuse_timestep))
522
+ diffuse_timesteps = diffuse_timesteps.long()
523
+ first_frames_static_vid = repeat(first_frame_latents, "b c h w -> b c t h w", t=video_length)
524
+ z_T = self.scheduler.add_noise(
525
+ original_samples=first_frames_static_vid.to(device),
526
+ noise=latents.to(device),
527
+ timesteps=diffuse_timesteps.to(device)
528
+ )
529
+ latents = freq_mix_3d(z_T.to(dtype=torch.float32), latents, LPF=self.freq_filter)
530
+ latents = latents.to(dtype=latents_dtype)
531
+
532
+ if first_frame_latents is not None:
533
+ first_frame_noisy_latent = latents[:, :, 0, :, :]
534
+ latents = latents[:, :, 1:, :, :]
535
+
536
+ # Prepare extra step kwargs.
537
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
538
+
539
+ # Denoising loop
540
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
541
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
542
+ for i, t in enumerate(timesteps):
543
+ # expand the latents if we are doing classifier free guidance
544
+ if do_classifier_free_guidance is None:
545
+ latent_model_input = latents
546
+ elif do_classifier_free_guidance == "text":
547
+ latent_model_input = torch.cat([latents] * 2)
548
+ elif do_classifier_free_guidance == "both":
549
+ latent_model_input = torch.cat([latents] * 3)
550
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
551
+ if first_frame_latents is not None:
552
+ if do_classifier_free_guidance is None:
553
+ first_frame_latents_input = first_frame_latents
554
+ elif do_classifier_free_guidance == "text":
555
+ first_frame_latents_input = torch.cat([first_frame_latents] * 2)
556
+ elif do_classifier_free_guidance == "both":
557
+ first_frame_latents_input = torch.cat([first_frame_noisy_latent, first_frame_latents, first_frame_latents])
558
+
559
+ first_frame_latents_input = first_frame_latents_input.unsqueeze(2)
560
+
561
+ # predict the noise residual
562
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, first_frame_latents=first_frame_latents_input, frame_stride=frame_stride).sample.to(dtype=latents_dtype)
563
+ else:
564
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample.to(dtype=latents_dtype)
565
+ # noise_pred = []
566
+ # import pdb
567
+ # pdb.set_trace()
568
+ # for batch_idx in range(latent_model_input.shape[0]):
569
+ # noise_pred_single = self.unet(latent_model_input[batch_idx:batch_idx+1], t, encoder_hidden_states=text_embeddings[batch_idx:batch_idx+1]).sample.to(dtype=latents_dtype)
570
+ # noise_pred.append(noise_pred_single)
571
+ # noise_pred = torch.cat(noise_pred)
572
+
573
+ # perform guidance
574
+ if do_classifier_free_guidance:
575
+ if do_classifier_free_guidance == "text":
576
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
577
+ noise_pred = noise_pred_uncond + guidance_scale_txt * (noise_pred_text - noise_pred_uncond)
578
+ elif do_classifier_free_guidance == "both":
579
+ noise_pred_uncond, noise_pred_img, noise_pred_both = noise_pred.chunk(3)
580
+ noise_pred = noise_pred_uncond + guidance_scale_img * (noise_pred_img - noise_pred_uncond) + guidance_scale_txt * (noise_pred_both - noise_pred_img)
581
+
582
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
583
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
584
+ # currently only support text guidance
585
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
586
+
587
+ # compute the previous noisy sample x_t -> x_t-1
588
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
589
+
590
+ # call the callback, if provided
591
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
592
+ progress_bar.update()
593
+ if callback is not None and i % callback_steps == 0:
594
+ callback(i, t, latents)
595
+
596
+ # Post-processing
597
+
598
+ latents = torch.cat([first_frame_latents.unsqueeze(2), latents], dim=2)
599
+ first_frame_latents = latents[:, :, -1, :, :]
600
+ full_video_latent[:, :, start_idx:start_idx + video_length, :, :] = latents
601
+
602
+ latents = None
603
+ start_idx += (video_length - 1)
604
+
605
+ # video = self.decode_latents(latents, first_frames)
606
+ video = self.decode_latents(full_video_latent)
607
+
608
+ # Convert to tensor
609
+ if output_type == "tensor":
610
+ video = torch.from_numpy(video)
611
+
612
+ if not return_dict:
613
+ return video
614
+
615
+ return AnimationPipelineOutput(videos=video)
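A hedged usage sketch for the pipeline above, assuming a locally saved pipeline in Diffusers format; the checkpoint path, image path, and argument values are placeholders, but the keyword names follow __call__ as defined in this file.

import torch

pipe = AutoregressiveAnimationPipeline.from_pretrained(
    "path/to/consisti2v_checkpoint",              # placeholder, not a real repo id
    torch_dtype=torch.float16,
).to("cuda")

out = pipe(
    prompt=["a timelapse of clouds over a mountain lake"],
    first_frame_paths=["first_frame.png"],        # conditioning image on disk (placeholder path)
    video_length=16,                              # frames per autoregressive chunk
    height=256, width=256,
    num_inference_steps=50,
    guidance_scale_txt=7.5,
    guidance_scale_img=1.0,
    autoregress_steps=3,                          # chunks chained by reusing the last generated frame
    generator=torch.Generator("cuda").manual_seed(42),
)
video = out.videos                                # (b, c, f, h, w) tensor in [0, 1]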
src/videogen_hub/pipelines/consisti2v/consisti2v/pipelines/pipeline_conditional_animation.py ADDED
@@ -0,0 +1,695 @@
1
+ # Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py
2
+
3
+ import inspect
4
+ from typing import Callable, List, Optional, Union
5
+ from dataclasses import dataclass
6
+
7
+ import math
8
+ import numpy as np
9
+ import torch
10
+ from tqdm import tqdm
11
+
12
+ from torchvision import transforms as T
13
+ from torchvision.transforms import functional as F
14
+ from PIL import Image
15
+
16
+ from diffusers.utils import is_accelerate_available
17
+ from packaging import version
18
+ from transformers import CLIPTextModel, CLIPTokenizer
19
+
20
+ from diffusers.configuration_utils import FrozenDict
21
+ from diffusers.models import AutoencoderKL
22
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
23
+ from diffusers.schedulers import (
24
+ DDIMScheduler,
25
+ DPMSolverMultistepScheduler,
26
+ EulerAncestralDiscreteScheduler,
27
+ EulerDiscreteScheduler,
28
+ LMSDiscreteScheduler,
29
+ PNDMScheduler,
30
+ )
31
+ from diffusers.utils import deprecate, logging, BaseOutput
32
+
33
+ from einops import rearrange, repeat
34
+
35
+ from ..models.videoldm_unet import VideoLDMUNet3DConditionModel
36
+
37
+ from ..utils.frameinit_utils import get_freq_filter, freq_mix_3d
38
+
39
+
40
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
+
42
+ # copied from https://github.com/huggingface/diffusers/blob/v0.23.0/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L59C1-L70C21
43
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
44
+ """
45
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
46
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
47
+ """
48
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
49
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
50
+ # rescale the results from guidance (fixes overexposure)
51
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
52
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
53
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
54
+ return noise_cfg
55
+
56
+ def pan_right(image, num_frames=16, crop_width=256):
57
+ frames = []
58
+ height, width = image.shape[-2:]
59
+
60
+ for i in range(num_frames):
61
+ # Calculate the start position of the crop
62
+ start_x = int((width - crop_width) * (i / num_frames))
63
+ crop = F.crop(image, 0, start_x, height, crop_width)
64
+ frames.append(crop.unsqueeze(0))
65
+
66
+ return torch.cat(frames, dim=0)
67
+
68
+
69
+ def pan_left(image, num_frames=16, crop_width=256):
70
+ frames = []
71
+ height, width = image.shape[-2:]
72
+
73
+ for i in range(num_frames):
74
+ # Start position moves from right to left
75
+ start_x = int((width - crop_width) * (1 - (i / num_frames)))
76
+ crop = F.crop(image, 0, start_x, height, crop_width)
77
+ frames.append(crop.unsqueeze(0))
78
+
79
+ return torch.cat(frames, dim=0)
80
+
81
+
82
+ def zoom_in(image, num_frames=16, crop_width=256, ratio=1.5):
83
+ frames = []
84
+ height, width = image.shape[-2:]
85
+ max_crop_size = min(width, height)
86
+
87
+ for i in range(num_frames):
88
+ # Calculate the size of the crop
89
+ crop_size = max_crop_size - int((max_crop_size - max_crop_size // ratio) * (i / num_frames))
90
+ start_x = (width - crop_size) // 2
91
+ start_y = (height - crop_size) // 2
92
+ crop = F.crop(image, start_y, start_x, crop_size, crop_size)
93
+ resized_crop = F.resize(crop, (crop_width, crop_width), antialias=None) # Resize back to original size
94
+ frames.append(resized_crop.unsqueeze(0))
95
+
96
+ return torch.cat(frames, dim=0)
97
+
98
+
99
+ def zoom_out(image, num_frames=16, crop_width=256, ratio=1.5):
100
+ frames = []
101
+ height, width = image.shape[-2:]
102
+ min_crop_size = int(min(width, height) // ratio) # Start from 1/ratio of the full size (int so F.crop gets integer sizes)
103
+
104
+ for i in range(num_frames):
105
+ # Calculate the size of the crop
106
+ crop_size = min_crop_size + int((min(width, height) - min_crop_size) * (i / num_frames))
107
+ start_x = (width - crop_size) // 2
108
+ start_y = (height - crop_size) // 2
109
+ crop = F.crop(image, start_y, start_x, crop_size, crop_size)
110
+ resized_crop = F.resize(crop, (crop_width, crop_width), antialias=None) # Resize back to original size
111
+ frames.append(resized_crop.unsqueeze(0))
112
+
113
+ return torch.cat(frames, dim=0)
114
+
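A quick usage sketch for the camera-motion helpers above (assuming they are in scope): the input is any C x H x W tensor, and each helper returns a stack of num_frames crops at crop_width x crop_width, simulating a pan or zoom from a single still image.

import torch

image = torch.rand(3, 256, 512)                   # C, H, W; wider than tall so panning has room

pan_r = pan_right(image, num_frames=16, crop_width=256)
pan_l = pan_left(image, num_frames=16, crop_width=256)
zoomed = zoom_in(image, num_frames=16, crop_width=256, ratio=1.5)

print(pan_r.shape, zoomed.shape)                  # torch.Size([16, 3, 256, 256]) for both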
115
+
116
+ @dataclass
117
+ class AnimationPipelineOutput(BaseOutput):
118
+ videos: Union[torch.Tensor, np.ndarray]
119
+
120
+
121
+ class ConditionalAnimationPipeline(DiffusionPipeline):
122
+ _optional_components = []
123
+
124
+ def __init__(
125
+ self,
126
+ vae: AutoencoderKL,
127
+ text_encoder: CLIPTextModel,
128
+ tokenizer: CLIPTokenizer,
129
+ unet: VideoLDMUNet3DConditionModel,
130
+ scheduler: Union[
131
+ DDIMScheduler,
132
+ PNDMScheduler,
133
+ LMSDiscreteScheduler,
134
+ EulerDiscreteScheduler,
135
+ EulerAncestralDiscreteScheduler,
136
+ DPMSolverMultistepScheduler,
137
+ ],
138
+ ):
139
+ super().__init__()
140
+
141
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
142
+ deprecation_message = (
143
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
144
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
145
+ "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
146
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
147
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
148
+ " file"
149
+ )
150
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
151
+ new_config = dict(scheduler.config)
152
+ new_config["steps_offset"] = 1
153
+ scheduler._internal_dict = FrozenDict(new_config)
154
+
155
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
156
+ deprecation_message = (
157
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
158
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
159
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
160
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
161
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
162
+ )
163
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
164
+ new_config = dict(scheduler.config)
165
+ new_config["clip_sample"] = False
166
+ scheduler._internal_dict = FrozenDict(new_config)
167
+
168
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
169
+ version.parse(unet.config._diffusers_version).base_version
170
+ ) < version.parse("0.9.0.dev0")
171
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
172
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
173
+ deprecation_message = (
174
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
175
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
176
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
177
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
178
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
179
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
180
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
181
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
182
+ " the `unet/config.json` file"
183
+ )
184
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
185
+ new_config = dict(unet.config)
186
+ new_config["sample_size"] = 64
187
+ unet._internal_dict = FrozenDict(new_config)
188
+
189
+ self.register_modules(
190
+ vae=vae,
191
+ text_encoder=text_encoder,
192
+ tokenizer=tokenizer,
193
+ unet=unet,
194
+ scheduler=scheduler,
195
+ )
196
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
197
+
198
+ self.freq_filter = None
199
+
200
+ @torch.no_grad()
201
+ def init_filter(self, video_length, height, width, filter_params):
202
+ # initialize frequency filter for noise reinitialization
203
+ batch_size = 1
204
+ num_channels_latents = self.unet.config.in_channels
205
+ filter_shape = [
206
+ batch_size,
207
+ num_channels_latents,
208
+ video_length,
209
+ height // self.vae_scale_factor,
210
+ width // self.vae_scale_factor
211
+ ]
212
+ # self.freq_filter = get_freq_filter(filter_shape, device=self._execution_device, params=filter_params)
213
+ self.freq_filter = get_freq_filter(
214
+ filter_shape,
215
+ device=self._execution_device,
216
+ filter_type=filter_params.method,
217
+ n=filter_params.n if filter_params.method=="butterworth" else None,
218
+ d_s=filter_params.d_s,
219
+ d_t=filter_params.d_t
220
+ )
221
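+ # Note: `filter_params` is expected to expose `method` ("gaussian" | "ideal" | "box" |
+ # "butterworth"), `n` (Butterworth order), and `d_s` / `d_t` (normalized spatial / temporal
+ # cutoffs). Illustrative sketch with an OmegaConf node (values are placeholders):
+ #   filter_params = OmegaConf.create({"method": "gaussian", "n": 4, "d_s": 0.25, "d_t": 0.25})
+ #   pipeline.init_filter(video_length=16, height=256, width=256, filter_params=filter_params)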
+
222
+ def enable_vae_slicing(self):
223
+ self.vae.enable_slicing()
224
+
225
+ def disable_vae_slicing(self):
226
+ self.vae.disable_slicing()
227
+
228
+ def enable_sequential_cpu_offload(self, gpu_id=0):
229
+ if is_accelerate_available():
230
+ from accelerate import cpu_offload
231
+ else:
232
+ raise ImportError("Please install accelerate via `pip install accelerate`")
233
+
234
+ device = torch.device(f"cuda:{gpu_id}")
235
+
236
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
237
+ if cpu_offloaded_model is not None:
238
+ cpu_offload(cpu_offloaded_model, device)
239
+
240
+
241
+ @property
242
+ def _execution_device(self):
243
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
244
+ return self.device
245
+ for module in self.unet.modules():
246
+ if (
247
+ hasattr(module, "_hf_hook")
248
+ and hasattr(module._hf_hook, "execution_device")
249
+ and module._hf_hook.execution_device is not None
250
+ ):
251
+ return torch.device(module._hf_hook.execution_device)
252
+ return self.device
253
+
254
+ def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
255
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
256
+
257
+ text_inputs = self.tokenizer(
258
+ prompt,
259
+ padding="max_length",
260
+ max_length=self.tokenizer.model_max_length,
261
+ truncation=True,
262
+ return_tensors="pt",
263
+ )
264
+ text_input_ids = text_inputs.input_ids
265
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
266
+
267
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
268
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
269
+ logger.warning(
270
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
271
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
272
+ )
273
+
274
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
275
+ attention_mask = text_inputs.attention_mask.to(device)
276
+ else:
277
+ attention_mask = None
278
+
279
+ text_embeddings = self.text_encoder(
280
+ text_input_ids.to(device),
281
+ attention_mask=attention_mask,
282
+ )
283
+ text_embeddings = text_embeddings[0]
284
+
285
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
286
+ bs_embed, seq_len, _ = text_embeddings.shape
287
+ text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
288
+ text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
289
+
290
+ # get unconditional embeddings for classifier free guidance
291
+ if do_classifier_free_guidance is not None:
292
+ uncond_tokens: List[str]
293
+ if negative_prompt is None:
294
+ uncond_tokens = [""] * batch_size
295
+ elif type(prompt) is not type(negative_prompt):
296
+ raise TypeError(
297
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
298
+ f" {type(prompt)}."
299
+ )
300
+ elif isinstance(negative_prompt, str):
301
+ uncond_tokens = [negative_prompt]
302
+ elif batch_size != len(negative_prompt):
303
+ raise ValueError(
304
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
305
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
306
+ " the batch size of `prompt`."
307
+ )
308
+ else:
309
+ uncond_tokens = negative_prompt
310
+
311
+ max_length = text_input_ids.shape[-1]
312
+ uncond_input = self.tokenizer(
313
+ uncond_tokens,
314
+ padding="max_length",
315
+ max_length=max_length,
316
+ truncation=True,
317
+ return_tensors="pt",
318
+ )
319
+
320
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
321
+ attention_mask = uncond_input.attention_mask.to(device)
322
+ else:
323
+ attention_mask = None
324
+
325
+ uncond_embeddings = self.text_encoder(
326
+ uncond_input.input_ids.to(device),
327
+ attention_mask=attention_mask,
328
+ )
329
+ uncond_embeddings = uncond_embeddings[0]
330
+
331
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
332
+ seq_len = uncond_embeddings.shape[1]
333
+ uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
334
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)
335
+
336
+ # For classifier free guidance, we need to do two forward passes.
337
+ # Here we concatenate the unconditional and text embeddings into a single batch
338
+ # to avoid doing two forward passes
339
+ if do_classifier_free_guidance == "text":
340
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
341
+ elif do_classifier_free_guidance == "both":
342
+ text_embeddings = torch.cat([uncond_embeddings, uncond_embeddings, text_embeddings])
343
+
344
+ return text_embeddings
345
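+ # Note: depending on the guidance mode, the returned embedding batch is laid out as
+ #   None   -> [text]
+ #   "text" -> [uncond, text]          (2x batch, a single UNet forward pass)
+ #   "both" -> [uncond, uncond, text]  (3x batch; the first-frame condition differs per chunk)
+ # and is chunked in the same order inside the denoising loop below.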
+
346
+ def decode_latents(self, latents, first_frames=None):
347
+ video_length = latents.shape[2]
348
+ latents = 1 / self.vae.config.scaling_factor * latents
349
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
350
+ # video = self.vae.decode(latents).sample
351
+ video = []
352
+ for frame_idx in tqdm(range(latents.shape[0]), **self._progress_bar_config):
353
+ video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
354
+ video = torch.cat(video)
355
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
356
+
357
+ if first_frames is not None:
358
+ first_frames = first_frames.unsqueeze(2)
359
+ video = torch.cat([first_frames, video], dim=2)
360
+
361
+ video = (video / 2 + 0.5).clamp(0, 1)
362
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
363
+ video = video.cpu().float().numpy()
364
+ return video
365
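+ # Note: frames are pushed through the VAE decoder one at a time to bound peak memory; the
+ # returned array has shape (batch, channels, frames, height, width) with values in [0, 1].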
+
366
+ def prepare_extra_step_kwargs(self, generator, eta):
367
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
368
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
369
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
370
+ # and should be between [0, 1]
371
+
372
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
373
+ extra_step_kwargs = {}
374
+ if accepts_eta:
375
+ extra_step_kwargs["eta"] = eta
376
+
377
+ # check if the scheduler accepts generator
378
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
379
+ if accepts_generator:
380
+ extra_step_kwargs["generator"] = generator
381
+ return extra_step_kwargs
382
+
383
+ def check_inputs(self, prompt, height, width, callback_steps, first_frame_paths=None):
384
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
385
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
386
+
387
+ if first_frame_paths is not None and (not isinstance(first_frame_paths, str) and not isinstance(first_frame_paths, list)):
388
+ raise ValueError(f"`first_frame_paths` has to be of type `str` or `list` but is {type(first_frame_paths)}")
389
+
390
+ if height % 8 != 0 or width % 8 != 0:
391
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
392
+
393
+ if (callback_steps is None) or (
394
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
395
+ ):
396
+ raise ValueError(
397
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
398
+ f" {type(callback_steps)}."
399
+ )
400
+
401
+ def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None, noise_sampling_method="vanilla", noise_alpha=1.0):
402
+ shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
403
+ if isinstance(generator, list) and len(generator) != batch_size:
404
+ raise ValueError(
405
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
406
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
407
+ )
408
+ if latents is None:
409
+ rand_device = "cpu" if device.type == "mps" else device
410
+
411
+ if isinstance(generator, list):
412
+ # shape = shape
413
+ shape = (1,) + shape[1:]
414
+ if noise_sampling_method == "vanilla":
415
+ latents = [
416
+ torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
417
+ for i in range(batch_size)
418
+ ]
419
+ elif noise_sampling_method == "pyoco_mixed":
420
+ base_shape = (1, num_channels_latents, 1, height // self.vae_scale_factor, width // self.vae_scale_factor)  # one sample at a time; the loop below assembles the batch
421
+ latents = []
422
+ noise_alpha_squared = noise_alpha ** 2
423
+ for i in range(batch_size):
424
+ base_latent = torch.randn(base_shape, generator=generator[i], device=rand_device, dtype=dtype) * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared))
425
+ ind_latent = torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared))
426
+ latents.append(base_latent + ind_latent)
427
+ elif noise_sampling_method == "pyoco_progressive":
428
+ latents = []
429
+ noise_alpha_squared = noise_alpha ** 2
430
+ for i in range(batch_size):
431
+ latent = torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
432
+ ind_latent = torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared))
433
+ for j in range(1, video_length):
434
+ latent[:, :, j, :, :] = latent[:, :, j - 1, :, :] * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared)) + ind_latent[:, :, j, :, :]
435
+ latents.append(latent)
436
+ latents = torch.cat(latents, dim=0).to(device)
437
+ else:
438
+ if noise_sampling_method == "vanilla":
439
+ latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
440
+ elif noise_sampling_method == "pyoco_mixed":
441
+ noise_alpha_squared = noise_alpha ** 2
442
+ base_shape = (batch_size, num_channels_latents, 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
443
+ base_latents = torch.randn(base_shape, generator=generator, device=rand_device, dtype=dtype) * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared))
444
+ ind_latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared))
445
+ latents = base_latents + ind_latents
446
+ elif noise_sampling_method == "pyoco_progressive":
447
+ noise_alpha_squared = noise_alpha ** 2
448
+ latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype)
449
+ ind_latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared))
450
+ for j in range(1, video_length):
451
+ latents[:, :, j, :, :] = latents[:, :, j - 1, :, :] * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared)) + ind_latents[:, :, j, :, :]
452
+ else:
453
+ if latents.shape != shape:
454
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
455
+ latents = latents.to(device)
456
+
457
+ # scale the initial noise by the standard deviation required by the scheduler
458
+ latents = latents * self.scheduler.init_noise_sigma
459
+ return latents
460
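+ # Note on the "pyoco" noise options above (sketch of the underlying formula): with
+ # alpha = noise_alpha and eps_base, eps_ind_t ~ N(0, I),
+ #   pyoco_mixed:       eps_t = sqrt(alpha^2 / (1 + alpha^2)) * eps_base  + sqrt(1 / (1 + alpha^2)) * eps_ind_t
+ #   pyoco_progressive: eps_t = sqrt(alpha^2 / (1 + alpha^2)) * eps_{t-1} + sqrt(1 / (1 + alpha^2)) * eps_ind_t
+ # Both keep each frame's noise at unit variance (alpha^2/(1+alpha^2) + 1/(1+alpha^2) = 1)
+ # while correlating the noise across frames.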
+
461
+ @torch.no_grad()
462
+ def __call__(
463
+ self,
464
+ prompt: Union[str, List[str]],
465
+ video_length: Optional[int],
466
+ height: Optional[int] = None,
467
+ width: Optional[int] = None,
468
+ num_inference_steps: int = 50,
469
+ guidance_scale_txt: float = 7.5,
470
+ guidance_scale_img: float = 2.0,
471
+ negative_prompt: Optional[Union[str, List[str]]] = None,
472
+ num_videos_per_prompt: Optional[int] = 1,
473
+ eta: float = 0.0,
474
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
475
+ latents: Optional[torch.FloatTensor] = None,
476
+ output_type: Optional[str] = "tensor",
477
+ return_dict: bool = True,
478
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
479
+ callback_steps: Optional[int] = 1,
480
+ # additional
481
+ first_frame_paths: Optional[Union[str, List[str]]] = None,
482
+ first_frames: Optional[torch.FloatTensor] = None,
483
+ noise_sampling_method: str = "vanilla",
484
+ noise_alpha: float = 1.0,
485
+ guidance_rescale: float = 0.0,
486
+ frame_stride: Optional[int] = None,
487
+ use_frameinit: bool = False,
488
+ frameinit_noise_level: int = 999,
489
+ camera_motion: str = None,
490
+ **kwargs,
491
+ ):
492
+ if first_frame_paths is not None and first_frames is not None:
493
+ raise ValueError("Only one of `first_frame_paths` and `first_frames` can be passed.")
494
+ # Default height and width to unet
495
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
496
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
497
+
498
+ # Check inputs. Raise error if not correct
499
+ self.check_inputs(prompt, height, width, callback_steps, first_frame_paths)
500
+
501
+ # Define call parameters
502
+ # batch_size = 1 if isinstance(prompt, str) else len(prompt)
503
+ batch_size = 1
504
+ if latents is not None:
505
+ batch_size = latents.shape[0]
506
+ if isinstance(prompt, list):
507
+ batch_size = len(prompt)
508
+ first_frame_input = first_frame_paths if first_frame_paths is not None else first_frames
509
+ if first_frame_input is not None:
510
+ assert len(prompt) == len(first_frame_input), "prompt and first_frame_paths should have the same length"
511
+
512
+ device = self._execution_device
513
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
514
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
515
+ # corresponds to doing no classifier free guidance.
516
+ do_classifier_free_guidance = None
517
+ # two guidance mode: text and text+image
518
+ if guidance_scale_txt > 1.0:
519
+ do_classifier_free_guidance = "text"
520
+ if guidance_scale_img > 1.0:
521
+ do_classifier_free_guidance = "both"
522
+
523
+ # Encode input prompt
524
+ prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size
525
+ if negative_prompt is not None:
526
+ negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size
527
+ text_embeddings = self._encode_prompt(
528
+ prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
529
+ )
530
+
531
+ # Encode input first frame
532
+ first_frame_latents = None
533
+ if first_frame_paths is not None:
534
+ first_frame_paths = first_frame_paths if isinstance(first_frame_paths, list) else [first_frame_paths] * batch_size
535
+ if camera_motion is None:
536
+ img_transform = T.Compose([
537
+ T.ToTensor(),
538
+ T.Resize(height, antialias=None),
539
+ T.CenterCrop((height, width)),
540
+ T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
541
+ ])
542
+ elif camera_motion == "pan_left" or camera_motion == "pan_right":
543
+ img_transform = T.Compose([
544
+ T.ToTensor(),
545
+ T.Resize(height, antialias=None),
546
+ T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
547
+ ])
548
+ elif camera_motion == "zoom_out" or camera_motion == "zoom_in":
549
+ img_transform = T.Compose([
550
+ T.ToTensor(),
551
+ T.Resize(height * 2, antialias=None),
552
+ T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
553
+ ])
554
+
555
+ first_frames = []
556
+ for first_frame_path in first_frame_paths:
557
+ first_frame = Image.open(first_frame_path).convert('RGB')
558
+ first_frame = img_transform(first_frame)
559
+ if camera_motion is not None:
560
+ if camera_motion == "pan_left":
561
+ first_frame = pan_left(first_frame, num_frames=video_length, crop_width=width)
562
+ elif camera_motion == "pan_right":
563
+ first_frame = pan_right(first_frame, num_frames=video_length, crop_width=width)
564
+ elif camera_motion == "zoom_in":
565
+ first_frame = zoom_in(first_frame, num_frames=video_length, crop_width=width)
566
+ elif camera_motion == "zoom_out":
567
+ first_frame = zoom_out(first_frame, num_frames=video_length, crop_width=width)
568
+ else:
569
+ raise NotImplementedError(f"camera_motion: {camera_motion} is not implemented.")
570
+ first_frames.append(first_frame.unsqueeze(0))
571
+ first_frames = torch.cat(first_frames, dim=0)
572
+ if first_frames is not None:
573
+ first_frames = first_frames.to(device, dtype=self.vae.dtype)
574
+ if camera_motion is not None:
575
+ first_frames = rearrange(first_frames, "b f c h w -> (b f) c h w")
576
+ first_frame_latents = self.vae.encode(first_frames).latent_dist
577
+ first_frame_latents = first_frame_latents.sample()
578
+ first_frame_latents = first_frame_latents * self.vae.config.scaling_factor # b, c, h, w
579
+ first_frame_static_vid = rearrange(first_frame_latents, "(b f) c h w -> b c f h w", f=video_length if camera_motion is not None else 1)
580
+ first_frame_latents = first_frame_static_vid[:, :, 0, :, :]
581
+ first_frame_latents = repeat(first_frame_latents, "b c h w -> (b n) c h w", n=num_videos_per_prompt)
582
+ first_frames = repeat(first_frames, "b c h w -> (b n) c h w", n=num_videos_per_prompt)
583
+
584
+ if use_frameinit and camera_motion is None:
585
+ first_frame_static_vid = repeat(first_frame_static_vid, "b c 1 h w -> b c t h w", t=video_length)
586
+
587
+ # self._progress_bar_config = {}
588
+ # vid = self.decode_latents(first_frame_static_vid)
589
+ # vid = torch.from_numpy(vid)
590
+ # from ..utils.util import save_videos_grid
591
+ # save_videos_grid(vid, "samples/debug/camera_motion/first_frame_static_vid.mp4", fps=8)
592
+
593
+ # Prepare timesteps
594
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
595
+ timesteps = self.scheduler.timesteps
596
+
597
+ # Prepare latent variables
598
+ num_channels_latents = self.unet.config.in_channels
599
+ latents = self.prepare_latents(
600
+ batch_size * num_videos_per_prompt,
601
+ num_channels_latents,
602
+ video_length,
603
+ height,
604
+ width,
605
+ text_embeddings.dtype,
606
+ device,
607
+ generator,
608
+ latents,
609
+ noise_sampling_method,
610
+ noise_alpha,
611
+ )
612
+ latents_dtype = latents.dtype
613
+
614
+ if use_frameinit:
615
+ current_diffuse_timestep = frameinit_noise_level # diffuse to t noise level
616
+ diffuse_timesteps = torch.full((batch_size,),int(current_diffuse_timestep))
617
+ diffuse_timesteps = diffuse_timesteps.long()
618
+ z_T = self.scheduler.add_noise(
619
+ original_samples=first_frame_static_vid.to(device),
620
+ noise=latents.to(device),
621
+ timesteps=diffuse_timesteps.to(device)
622
+ )
623
+ latents = freq_mix_3d(z_T.to(dtype=torch.float32), latents.to(dtype=torch.float32), LPF=self.freq_filter)
624
+ latents = latents.to(dtype=latents_dtype)
625
+
626
+ if first_frame_latents is not None:
627
+ first_frame_noisy_latent = latents[:, :, 0, :, :]
628
+ latents = latents[:, :, 1:, :, :]
629
+
630
+ # Prepare extra step kwargs.
631
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
632
+
633
+ # Denoising loop
634
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
635
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
636
+ for i, t in enumerate(timesteps):
637
+ # expand the latents if we are doing classifier free guidance
638
+ if do_classifier_free_guidance is None:
639
+ latent_model_input = latents
640
+ elif do_classifier_free_guidance == "text":
641
+ latent_model_input = torch.cat([latents] * 2)
642
+ elif do_classifier_free_guidance == "both":
643
+ latent_model_input = torch.cat([latents] * 3)
644
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
645
+ if first_frame_latents is not None:
646
+ if do_classifier_free_guidance is None:
647
+ first_frame_latents_input = first_frame_latents
648
+ elif do_classifier_free_guidance == "text":
649
+ first_frame_latents_input = torch.cat([first_frame_latents] * 2)
650
+ elif do_classifier_free_guidance == "both":
651
+ first_frame_latents_input = torch.cat([first_frame_noisy_latent, first_frame_latents, first_frame_latents])
652
+
653
+ first_frame_latents_input = first_frame_latents_input.unsqueeze(2)
654
+
655
+ # predict the noise residual
656
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, first_frame_latents=first_frame_latents_input, frame_stride=frame_stride).sample.to(dtype=latents_dtype)
657
+ else:
658
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample.to(dtype=latents_dtype)
659
+
660
+ # perform guidance
661
+ if do_classifier_free_guidance:
662
+ if do_classifier_free_guidance == "text":
663
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
664
+ noise_pred = noise_pred_uncond + guidance_scale_txt * (noise_pred_text - noise_pred_uncond)
665
+ elif do_classifier_free_guidance == "both":
666
+ noise_pred_uncond, noise_pred_img, noise_pred_both = noise_pred.chunk(3)
667
+ noise_pred = noise_pred_uncond + guidance_scale_img * (noise_pred_img - noise_pred_uncond) + guidance_scale_txt * (noise_pred_both - noise_pred_img)
668
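+ # Note: in the "both" mode the two scales compose sequentially, i.e.
+ #   eps = eps_uncond + w_img * (eps_img - eps_uncond) + w_txt * (eps_txt_img - eps_img),
+ # so image guidance is applied first and text guidance is applied on top of the
+ # image-conditioned prediction (analogous to InstructPix2Pix-style two-condition CFG).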
+
669
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
670
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
671
+ # currently only support text guidance
672
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
673
+
674
+ # compute the previous noisy sample x_t -> x_t-1
675
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
676
+
677
+ # call the callback, if provided
678
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
679
+ progress_bar.update()
680
+ if callback is not None and i % callback_steps == 0:
681
+ callback(i, t, latents)
682
+
683
+ # Post-processing
684
+ latents = torch.cat([first_frame_latents.unsqueeze(2), latents], dim=2)
685
+ # video = self.decode_latents(latents, first_frames)
686
+ video = self.decode_latents(latents)
687
+
688
+ # Convert to tensor
689
+ if output_type == "tensor":
690
+ video = torch.from_numpy(video)
691
+
692
+ if not return_dict:
693
+ return video
694
+
695
+ return AnimationPipelineOutput(videos=video)
src/videogen_hub/pipelines/consisti2v/consisti2v/utils/__init__.py ADDED
File without changes
src/videogen_hub/pipelines/consisti2v/consisti2v/utils/frameinit_utils.py ADDED
@@ -0,0 +1,142 @@
1
+ # modified from https://github.com/TianxingWu/FreeInit/blob/master/freeinit_utils.py
2
+ import torch
3
+ import torch.fft as fft
4
+ import math
5
+
6
+
7
+ def freq_mix_3d(x, noise, LPF):
8
+ """
9
+ Noise reinitialization.
10
+
11
+ Args:
12
+ x: diffused latent
13
+ noise: randomly sampled noise
14
+ LPF: low pass filter
15
+ """
16
+ # FFT
17
+ x_freq = fft.fftn(x, dim=(-3, -2, -1))
18
+ x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1))
19
+ noise_freq = fft.fftn(noise, dim=(-3, -2, -1))
20
+ noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1))
21
+
22
+ # frequency mix
23
+ HPF = 1 - LPF
24
+ x_freq_low = x_freq * LPF
25
+ noise_freq_high = noise_freq * HPF
26
+ x_freq_mixed = x_freq_low + noise_freq_high # mix in freq domain
27
+
28
+ # IFFT
29
+ x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1))
30
+ x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real
31
+
32
+ return x_mixed
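+ # Usage sketch (illustrative, shapes are placeholders): the pipeline builds a low-pass filter
+ # once, diffuses the static first-frame video to a high noise level, and keeps its
+ # low-frequency band while taking the high-frequency band from fresh noise:
+ #   LPF = get_freq_filter((1, 4, 16, 32, 32), device, filter_type="gaussian", n=None, d_s=0.25, d_t=0.25)
+ #   z_T = scheduler.add_noise(first_frame_static_video, noise, timesteps)
+ #   latents = freq_mix_3d(z_T.float(), noise.float(), LPF)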
33
+
34
+
35
+ def get_freq_filter(shape, device, filter_type, n, d_s, d_t):
36
+ """
37
+ Form the frequency filter for noise reinitialization.
38
+
39
+ Args:
40
+ shape: shape of latent (B, C, T, H, W)
41
+ filter_type: type of the freq filter
42
+ n: (only for butterworth) order of the filter, larger n ~ ideal, smaller n ~ gaussian
43
+ d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
44
+ d_t: normalized stop frequency for temporal dimension (0.0-1.0)
45
+ """
46
+ if filter_type == "gaussian":
47
+ return gaussian_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device)
48
+ elif filter_type == "ideal":
49
+ return ideal_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device)
50
+ elif filter_type == "box":
51
+ return box_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device)
52
+ elif filter_type == "butterworth":
53
+ return butterworth_low_pass_filter(shape=shape, n=n, d_s=d_s, d_t=d_t).to(device)
54
+ else:
55
+ raise NotImplementedError
56
+
57
+
58
+ def gaussian_low_pass_filter(shape, d_s=0.25, d_t=0.25):
59
+ """
60
+ Compute the gaussian low pass filter mask.
61
+
62
+ Args:
63
+ shape: shape of the filter (volume)
64
+ d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
65
+ d_t: normalized stop frequency for temporal dimension (0.0-1.0)
66
+ """
67
+ T, H, W = shape[-3], shape[-2], shape[-1]
68
+ mask = torch.zeros(shape)
69
+ if d_s==0 or d_t==0:
70
+ return mask
71
+ for t in range(T):
72
+ for h in range(H):
73
+ for w in range(W):
74
+ d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2)
75
+ mask[..., t,h,w] = math.exp(-1/(2*d_s**2) * d_square)
76
+ return mask
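+ # Note: the triple Python loop above is O(T*H*W); a vectorized sketch that produces the same
+ # mask (same normalized-distance definition, kept here only as a comment):
+ #   ts = (torch.arange(T) * 2 / T - 1).view(T, 1, 1) * (d_s / d_t)
+ #   hs = (torch.arange(H) * 2 / H - 1).view(1, H, 1)
+ #   ws = (torch.arange(W) * 2 / W - 1).view(1, 1, W)
+ #   mask[...] = torch.exp(-(ts**2 + hs**2 + ws**2) / (2 * d_s**2))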
77
+
78
+
79
+ def butterworth_low_pass_filter(shape, n=4, d_s=0.25, d_t=0.25):
80
+ """
81
+ Compute the butterworth low pass filter mask.
82
+
83
+ Args:
84
+ shape: shape of the filter (volume)
85
+ n: order of the filter, larger n ~ ideal, smaller n ~ gaussian
86
+ d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
87
+ d_t: normalized stop frequency for temporal dimension (0.0-1.0)
88
+ """
89
+ T, H, W = shape[-3], shape[-2], shape[-1]
90
+ mask = torch.zeros(shape)
91
+ if d_s==0 or d_t==0:
92
+ return mask
93
+ for t in range(T):
94
+ for h in range(H):
95
+ for w in range(W):
96
+ d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2)
97
+ mask[..., t,h,w] = 1 / (1 + (d_square / d_s**2)**n)
98
+ return mask
99
+
100
+
101
+ def ideal_low_pass_filter(shape, d_s=0.25, d_t=0.25):
102
+ """
103
+ Compute the ideal low pass filter mask.
104
+
105
+ Args:
106
+ shape: shape of the filter (volume)
107
+ d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
108
+ d_t: normalized stop frequency for temporal dimension (0.0-1.0)
109
+ """
110
+ T, H, W = shape[-3], shape[-2], shape[-1]
111
+ mask = torch.zeros(shape)
112
+ if d_s==0 or d_t==0:
113
+ return mask
114
+ for t in range(T):
115
+ for h in range(H):
116
+ for w in range(W):
117
+ d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2)
118
+ mask[..., t,h,w] = 1 if d_square <= d_s*2 else 0
119
+ return mask
120
+
121
+
122
+ def box_low_pass_filter(shape, d_s=0.25, d_t=0.25):
123
+ """
124
+ Compute the ideal low pass filter mask (approximated version).
125
+
126
+ Args:
127
+ shape: shape of the filter (volume)
128
+ d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
129
+ d_t: normalized stop frequency for temporal dimension (0.0-1.0)
130
+ """
131
+ T, H, W = shape[-3], shape[-2], shape[-1]
132
+ mask = torch.zeros(shape)
133
+ if d_s==0 or d_t==0:
134
+ return mask
135
+
136
+ threshold_s = round(int(H // 2) * d_s)
137
+ threshold_t = round(T // 2 * d_t)
138
+
139
+ cframe, crow, ccol = T // 2, H // 2, W //2
140
+ mask[..., cframe - threshold_t:cframe + threshold_t, crow - threshold_s:crow + threshold_s, ccol - threshold_s:ccol + threshold_s] = 1.0
141
+
142
+ return mask
src/videogen_hub/pipelines/consisti2v/consisti2v/utils/util.py ADDED
@@ -0,0 +1,165 @@
1
+ import os
2
+ import imageio
3
+ import numpy as np
4
+ from typing import Union
5
+
6
+ import torch
7
+ import torchvision
8
+ import torch.distributed as dist
9
+ import wandb as wandb_lib  # aliased so the boolean `wandb` argument of save_videos_grid does not shadow the module
10
+
11
+ from tqdm import tqdm
12
+ from einops import rearrange
13
+
14
+ from torchmetrics.image.fid import _compute_fid
15
+
16
+
17
+ def zero_rank_print(s):
18
+ if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0): print("### " + s)
19
+
20
+
21
+ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8, wandb=False, global_step=0, format="gif"):
22
+ videos = rearrange(videos, "b c t h w -> t b c h w")
23
+ outputs = []
24
+ for x in videos:
25
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
26
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
27
+ if rescale:
28
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
29
+ x = (x * 255).numpy().astype(np.uint8)
30
+ outputs.append(x)
31
+
32
+ if wandb:
33
+ wandb_video = wandb_lib.Video(outputs, fps=fps)
34
+ wandb_lib.log({"val_videos": wandb_video}, step=global_step)
35
+
36
+ os.makedirs(os.path.dirname(path), exist_ok=True)
37
+ if format == "gif":
38
+ imageio.mimsave(path, outputs, fps=fps)
39
+ elif format == "mp4":
40
+ torchvision.io.write_video(path, np.array(outputs), fps=fps, video_codec='h264', options={'crf': '10'})
41
+
42
+ # DDIM Inversion
43
+ @torch.no_grad()
44
+ def init_prompt(prompt, pipeline):
45
+ uncond_input = pipeline.tokenizer(
46
+ [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
47
+ return_tensors="pt"
48
+ )
49
+ uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
50
+ text_input = pipeline.tokenizer(
51
+ [prompt],
52
+ padding="max_length",
53
+ max_length=pipeline.tokenizer.model_max_length,
54
+ truncation=True,
55
+ return_tensors="pt",
56
+ )
57
+ text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
58
+ context = torch.cat([uncond_embeddings, text_embeddings])
59
+
60
+ return context
61
+
62
+
63
+ def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
64
+ sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
65
+ timestep, next_timestep = min(
66
+ timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
67
+ alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
68
+ alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
69
+ beta_prod_t = 1 - alpha_prod_t
70
+ next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
71
+ next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
72
+ next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
73
+ return next_sample
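+ # Note: this is one reverse-DDIM (inversion) step; with a_t = alphas_cumprod[t],
+ #   x0_hat  = (x_t - sqrt(1 - a_t) * eps) / sqrt(a_t)
+ #   x_{t+1} = sqrt(a_{t+1}) * x0_hat + sqrt(1 - a_{t+1}) * eps
+ # `ddim_loop` below applies it repeatedly to map clean latents back towards noise.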
74
+
75
+
76
+ def get_noise_pred_single(latents, t, context, first_frame_latents, frame_stride, unet):
77
+ noise_pred = unet(latents, t, encoder_hidden_states=context, first_frame_latents=first_frame_latents, frame_stride=frame_stride).sample
78
+ return noise_pred
79
+
80
+
81
+ @torch.no_grad()
82
+ def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt, first_frame_latents, frame_stride):
83
+ context = init_prompt(prompt, pipeline)
84
+ uncond_embeddings, cond_embeddings = context.chunk(2)
85
+ all_latent = [latent]
86
+ latent = latent.clone().detach()
87
+ for i in tqdm(range(num_inv_steps)):
88
+ t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
89
+ noise_pred = get_noise_pred_single(latent, t, cond_embeddings, first_frame_latents, frame_stride, pipeline.unet)
90
+ latent = next_step(noise_pred, t, latent, ddim_scheduler)
91
+ all_latent.append(latent)
92
+ return all_latent
93
+
94
+
95
+ @torch.no_grad()
96
+ def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt="", first_frame_latents=None, frame_stride=3):
97
+ ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt, first_frame_latents, frame_stride)
98
+ return ddim_latents
99
+
100
+
101
+ def compute_fid(real_features, fake_features, num_features, device):
102
+ orig_dtype = real_features.dtype
103
+
104
+ mx_num_feats = (num_features, num_features)
105
+ real_features_sum = torch.zeros(num_features).double().to(device)
106
+ real_features_cov_sum = torch.zeros(mx_num_feats).double().to(device)
107
+ real_features_num_samples = torch.tensor(0).long().to(device)
108
+
109
+ fake_features_sum = torch.zeros(num_features).double().to(device)
110
+ fake_features_cov_sum = torch.zeros(mx_num_feats).double().to(device)
111
+ fake_features_num_samples = torch.tensor(0).long().to(device)
112
+
113
+ real_features = real_features.double()
114
+ fake_features = fake_features.double()
115
+
116
+ real_features_sum += real_features.sum(dim=0)
117
+ real_features_cov_sum += real_features.t().mm(real_features)
118
+ real_features_num_samples += real_features.shape[0]
119
+
120
+ fake_features_sum += fake_features.sum(dim=0)
121
+ fake_features_cov_sum += fake_features.t().mm(fake_features)
122
+ fake_features_num_samples += fake_features.shape[0]
123
+
124
+ """Calculate FID score based on accumulated extracted features from the two distributions."""
125
+ if real_features_num_samples < 2 or fake_features_num_samples < 2:
126
+ raise RuntimeError("More than one sample is required for both the real and fake distributions to compute FID")
127
+ mean_real = (real_features_sum / real_features_num_samples).unsqueeze(0)
128
+ mean_fake = (fake_features_sum / fake_features_num_samples).unsqueeze(0)
129
+
130
+ cov_real_num = real_features_cov_sum - real_features_num_samples * mean_real.t().mm(mean_real)
131
+ cov_real = cov_real_num / (real_features_num_samples - 1)
132
+ cov_fake_num = fake_features_cov_sum - fake_features_num_samples * mean_fake.t().mm(mean_fake)
133
+ cov_fake = cov_fake_num / (fake_features_num_samples - 1)
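+ # _compute_fid (torchmetrics' internal helper) evaluates the Frechet distance
+ #   ||mu_r - mu_f||^2 + Tr(cov_r + cov_f - 2 * (cov_r @ cov_f)^(1/2))
+ # from the means and covariances accumulated above.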
134
+ return _compute_fid(mean_real.squeeze(0), cov_real, mean_fake.squeeze(0), cov_fake).to(orig_dtype)
135
+
136
+
137
+ def compute_inception_score(gen_probs, num_splits=10):
138
+ num_gen = gen_probs.shape[0]
139
+ gen_probs = gen_probs.detach().cpu().numpy()
140
+ scores = []
141
+ np.random.RandomState(42).shuffle(gen_probs)
142
+ for i in range(num_splits):
143
+ part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits]
144
+ kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
145
+ kl = np.mean(np.sum(kl, axis=1))
146
+ scores.append(np.exp(kl))
147
+ return float(np.mean(scores)), float(np.std(scores))
148
+ # idx = torch.randperm(features.shape[0])
149
+ # features = features[idx]
150
+ # # calculate probs and logits
151
+ # prob = features.softmax(dim=1)
152
+ # log_prob = features.log_softmax(dim=1)
153
+
154
+ # # split into groups
155
+ # prob = prob.chunk(splits, dim=0)
156
+ # log_prob = log_prob.chunk(splits, dim=0)
157
+
158
+ # # calculate score per split
159
+ # mean_prob = [p.mean(dim=0, keepdim=True) for p in prob]
160
+ # kl_ = [p * (log_p - m_p.log()) for p, log_p, m_p in zip(prob, log_prob, mean_prob)]
161
+ # kl_ = [k.sum(dim=1).mean().exp() for k in kl_]
162
+ # kl = torch.stack(kl_)
163
+
164
+ # return mean and std
165
+ # return kl.mean(), kl.std()
src/videogen_hub/pipelines/consisti2v/scripts/__init__.py ADDED
File without changes
src/videogen_hub/pipelines/consisti2v/scripts/animate.py ADDED
@@ -0,0 +1,247 @@
1
+ import argparse
2
+ import datetime
3
+ import random
4
+ import os
5
+ import logging
6
+ from omegaconf import OmegaConf
7
+
8
+ import torch
9
+
10
+ import diffusers
11
+ from diffusers import AutoencoderKL, DDIMScheduler
12
+
13
+ from transformers import CLIPTextModel, CLIPTokenizer
14
+
15
+ from consisti2v.models.videoldm_unet import VideoLDMUNet3DConditionModel
16
+ from consisti2v.pipelines.pipeline_conditional_animation import (
17
+ ConditionalAnimationPipeline,
18
+ )
19
+ from consisti2v.utils.util import save_videos_grid
20
+ from diffusers.utils.import_utils import is_xformers_available
21
+
22
+
23
+ def main(args, config):
24
+ logging.basicConfig(
25
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
26
+ datefmt="%m/%d/%Y %H:%M:%S",
27
+ level=logging.INFO,
28
+ )
29
+ diffusers.utils.logging.set_verbosity_info()
30
+
31
+ time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
32
+ savedir = f"{config.output_dir}/{config.output_name}-{time_str}"
33
+ os.makedirs(savedir)
34
+
35
+ samples = []
36
+ sample_idx = 0
37
+
38
+ ### >>> create validation pipeline >>> ###
39
+ if config.pipeline_pretrained_path is None:
40
+ noise_scheduler = DDIMScheduler(
41
+ **OmegaConf.to_container(config.noise_scheduler_kwargs)
42
+ )
43
+ tokenizer = CLIPTokenizer.from_pretrained(
44
+ config.pretrained_model_path, subfolder="tokenizer", use_safetensors=True
45
+ )
46
+ text_encoder = CLIPTextModel.from_pretrained(
47
+ config.pretrained_model_path, subfolder="text_encoder"
48
+ )
49
+ vae = AutoencoderKL.from_pretrained(
50
+ config.pretrained_model_path, subfolder="vae", use_safetensors=True
51
+ )
52
+ unet = VideoLDMUNet3DConditionModel.from_pretrained(
53
+ config.pretrained_model_path,
54
+ subfolder="unet",
55
+ variant=config.unet_additional_kwargs["variant"],
56
+ temp_pos_embedding=config.unet_additional_kwargs["temp_pos_embedding"],
57
+ augment_temporal_attention=config.unet_additional_kwargs[
58
+ "augment_temporal_attention"
59
+ ],
60
+ use_temporal=True,
61
+ n_frames=config.sampling_kwargs["n_frames"],
62
+ n_temp_heads=config.unet_additional_kwargs["n_temp_heads"],
63
+ first_frame_condition_mode=config.unet_additional_kwargs[
64
+ "first_frame_condition_mode"
65
+ ],
66
+ use_frame_stride_condition=config.unet_additional_kwargs[
67
+ "use_frame_stride_condition"
68
+ ],
69
+ use_safetensors=True,
70
+ )
71
+
72
+ # 1. unet ckpt
73
+ if config.unet_path is not None:
74
+ if os.path.isdir(config.unet_path):
75
+ unet_dict = VideoLDMUNet3DConditionModel.from_pretrained(
76
+ config.unet_path
77
+ )
78
+ m, u = unet.load_state_dict(unet_dict.state_dict(), strict=False)
79
+ assert len(u) == 0
80
+ del unet_dict
81
+ else:
82
+ checkpoint_dict = torch.load(config.unet_path, map_location="cpu")
83
+ state_dict = (
84
+ checkpoint_dict["state_dict"]
85
+ if "state_dict" in checkpoint_dict
86
+ else checkpoint_dict
87
+ )
88
+ if config.unet_ckpt_prefix is not None:
89
+ state_dict = {
90
+ k.replace(config.unet_ckpt_prefix, ""): v
91
+ for k, v in state_dict.items()
92
+ }
93
+ m, u = unet.load_state_dict(state_dict, strict=False)
94
+ assert len(u) == 0
95
+
96
+ if is_xformers_available() and int(torch.__version__.split(".")[0]) < 2:
97
+ unet.enable_xformers_memory_efficient_attention()
98
+
99
+ pipeline = ConditionalAnimationPipeline(
100
+ vae=vae,
101
+ text_encoder=text_encoder,
102
+ tokenizer=tokenizer,
103
+ unet=unet,
104
+ scheduler=noise_scheduler,
105
+ )
106
+
107
+ else:
108
+ pipeline = ConditionalAnimationPipeline.from_pretrained(
109
+ config.pipeline_pretrained_path
110
+ )
111
+
112
+ pipeline.to("cuda")
113
+
114
+ # (frameinit) initialize frequency filter for noise reinitialization -------------
115
+ if config.frameinit_kwargs.enable:
116
+ pipeline.init_filter(
117
+ width=config.sampling_kwargs.width,
118
+ height=config.sampling_kwargs.height,
119
+ video_length=config.sampling_kwargs.n_frames,
120
+ filter_params=config.frameinit_kwargs.filter_params,
121
+ )
122
+ # -------------------------------------------------------------------------------
123
+ ### <<< create validation pipeline <<< ###
124
+
125
+ if args.prompt is not None:
126
+ prompts = [args.prompt]
127
+ n_prompts = [args.n_prompt]
128
+ first_frame_paths = [args.path_to_first_frame]
129
+ random_seeds = [int(args.seed)] if args.seed != "random" else "random"
130
+ else:
131
+ prompt_config = OmegaConf.load(args.prompt_config)
132
+ prompts = prompt_config.prompts
133
+ n_prompts = (
134
+ list(prompt_config.n_prompts) * len(prompts)
135
+ if len(prompt_config.n_prompts) == 1
136
+ else prompt_config.n_prompts
137
+ )
138
+ first_frame_paths = prompt_config.path_to_first_frames
139
+ random_seeds = prompt_config.seeds
140
+
141
+ if random_seeds == "random":
142
+ random_seeds = [random.randint(0, 100000) for _ in range(len(prompts))]  # randint requires integer bounds
143
+ else:
144
+ random_seeds = (
145
+ [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
146
+ )
147
+ random_seeds = (
148
+ random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds
149
+ )
150
+
151
+ config.prompt_kwargs = OmegaConf.create(
152
+ {
153
+ "random_seeds": [],
154
+ "prompts": prompts,
155
+ "n_prompts": n_prompts,
156
+ "first_frame_paths": first_frame_paths,
157
+ }
158
+ )
159
+ for prompt_idx, (prompt, n_prompt, first_frame_path, random_seed) in enumerate(
160
+ zip(prompts, n_prompts, first_frame_paths, random_seeds)
161
+ ):
162
+ # manually set random seed for reproduction
163
+ if random_seed != -1:
164
+ torch.manual_seed(random_seed)
165
+ else:
166
+ torch.seed()
167
+ config.prompt_kwargs.random_seeds.append(torch.initial_seed())
168
+
169
+ print(f"current seed: {torch.initial_seed()}")
170
+ print(f"sampling {prompt} ...")
171
+ sample = pipeline(
172
+ prompt,
173
+ negative_prompt=n_prompt,
174
+ first_frame_paths=first_frame_path,
175
+ num_inference_steps=config.sampling_kwargs.steps,
176
+ guidance_scale_txt=config.sampling_kwargs.guidance_scale_txt,
177
+ guidance_scale_img=config.sampling_kwargs.guidance_scale_img,
178
+ width=config.sampling_kwargs.width,
179
+ height=config.sampling_kwargs.height,
180
+ video_length=config.sampling_kwargs.n_frames,
181
+ noise_sampling_method=config.unet_additional_kwargs[
182
+ "noise_sampling_method"
183
+ ],
184
+ noise_alpha=float(config.unet_additional_kwargs["noise_alpha"]),
185
+ eta=config.sampling_kwargs.ddim_eta,
186
+ frame_stride=config.sampling_kwargs.frame_stride,
187
+ guidance_rescale=config.sampling_kwargs.guidance_rescale,
188
+ num_videos_per_prompt=config.sampling_kwargs.num_videos_per_prompt,
189
+ use_frameinit=config.frameinit_kwargs.enable,
190
+ frameinit_noise_level=config.frameinit_kwargs.noise_level,
191
+ camera_motion=config.frameinit_kwargs.camera_motion,
192
+ ).videos
193
+ samples.append(sample)
194
+
195
+ prompt = "-".join((prompt.replace("/", "").split(" ")[:10])).replace(":", "")
196
+ if sample.shape[0] > 1:
197
+ for cnt, samp in enumerate(sample):
198
+ save_videos_grid(
199
+ samp.unsqueeze(0),
200
+ f"{savedir}/sample/{sample_idx}-{cnt + 1}-{prompt}.{args.format}",
201
+ format=args.format,
202
+ )
203
+ else:
204
+ save_videos_grid(
205
+ sample,
206
+ f"{savedir}/sample/{sample_idx}-{prompt}.{args.format}",
207
+ format=args.format,
208
+ )
209
+ print(f"save to {savedir}/sample/{prompt}.{args.format}")
210
+
211
+ sample_idx += 1
212
+
213
+ samples = torch.concat(samples)
214
+ # save_videos_grid(samples, f"{savedir}/sample.{args.format}", n_rows=4, format=args.format)
215
+
216
+ # OmegaConf.save(config, f"{savedir}/config.yaml")
217
+
218
+ # if args.save_model:
219
+ # pipeline.save_pretrained(f"{savedir}/model")
220
+
221
+ return samples
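+ # Example invocation (paths and prompt below are placeholders, not files shipped with the repo):
+ #   python scripts/animate.py --inference_config configs/inference/inference.yaml \
+ #       -p "a dog running on the beach" -f path/to/first_frame.png --format mp4 --seed 42
+ # Omitting -p falls back to the batch prompts in --prompt_config (configs/prompts/default.yaml).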
222
+
223
+
224
+ if __name__ == "__main__":
225
+ parser = argparse.ArgumentParser()
226
+ parser.add_argument(
227
+ "--inference_config", type=str, default="configs/inference/inference.yaml"
228
+ )
229
+ parser.add_argument("--prompt", "-p", type=str, default=None)
230
+ parser.add_argument("--n_prompt", "-n", type=str, default="")
231
+ parser.add_argument("--seed", type=str, default="random")
232
+ parser.add_argument("--path_to_first_frame", "-f", type=str, default=None)
233
+ parser.add_argument(
234
+ "--prompt_config", type=str, default="configs/prompts/default.yaml"
235
+ )
236
+ parser.add_argument("--format", type=str, default="mp4", choices=["gif", "mp4"])
237
+ parser.add_argument("--save_model", action="store_true")
238
+ parser.add_argument("optional_args", nargs="*", default=[])
239
+ args = parser.parse_args()
240
+
241
+ config = OmegaConf.load(args.inference_config)
242
+
243
+ if args.optional_args:
244
+ modified_config = OmegaConf.from_dotlist(args.optional_args)
245
+ config = OmegaConf.merge(config, modified_config)
246
+
247
+ main(args, config)