Spaces:
Runtime error
Runtime error
Baraaqasem
commited on
Upload 585 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- src/videogen_hub/pipelines/__init__.py +0 -0
- src/videogen_hub/pipelines/cogvideo/__init__.py +4 -0
- src/videogen_hub/pipelines/cogvideo/cogvideo_pipeline.py +612 -0
- src/videogen_hub/pipelines/cogvideo/cogvideo_src/LICENSE +201 -0
- src/videogen_hub/pipelines/cogvideo/cogvideo_src/Model_License +79 -0
- src/videogen_hub/pipelines/cogvideo/cogvideo_src/__init__.py +0 -0
- src/videogen_hub/pipelines/cogvideo/cogvideo_src/cluster_label2.npy +3 -0
- src/videogen_hub/pipelines/cogvideo/cogvideo_src/coglm_strategy.py +101 -0
- src/videogen_hub/pipelines/cogvideo/cogvideo_src/cogvideo_pipeline.py +1341 -0
- src/videogen_hub/pipelines/cogvideo/cogvideo_src/models/__init__.py +0 -0
- src/videogen_hub/pipelines/cogvideo/cogvideo_src/models/cogvideo_cache_model.py +695 -0
- src/videogen_hub/pipelines/cogvideo/cogvideo_src/models/cogvideo_model.py +543 -0
- src/videogen_hub/pipelines/cogvideo/cogvideo_src/pretrain_cogvideo.py +184 -0
- src/videogen_hub/pipelines/cogvideo/cogvideo_src/requirements.txt +4 -0
- src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/__init__.py +17 -0
- src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/cluster_label2.npy +3 -0
- src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/direct_sr.py +117 -0
- src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/dsr_model.py +225 -0
- src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/dsr_sampling.py +204 -0
- src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/iterative_sr.py +118 -0
- src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/itersr_model.py +232 -0
- src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/itersr_sampling.py +168 -0
- src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/sr_group.py +49 -0
- src/videogen_hub/pipelines/consisti2v/LICENSE +21 -0
- src/videogen_hub/pipelines/consisti2v/__init__.py +0 -0
- src/videogen_hub/pipelines/consisti2v/configs/__init__.py +0 -0
- src/videogen_hub/pipelines/consisti2v/configs/inference/__init__.py +0 -0
- src/videogen_hub/pipelines/consisti2v/configs/inference/inference.yaml +48 -0
- src/videogen_hub/pipelines/consisti2v/configs/inference/inference_autoregress.yaml +49 -0
- src/videogen_hub/pipelines/consisti2v/configs/prompts/__init__.py +0 -0
- src/videogen_hub/pipelines/consisti2v/configs/prompts/default.yaml +16 -0
- src/videogen_hub/pipelines/consisti2v/configs/training/__init__.py +0 -0
- src/videogen_hub/pipelines/consisti2v/configs/training/training.yaml +92 -0
- src/videogen_hub/pipelines/consisti2v/consisti2v/__init__.py +0 -0
- src/videogen_hub/pipelines/consisti2v/consisti2v/data/__init__.py +0 -0
- src/videogen_hub/pipelines/consisti2v/consisti2v/data/dataset.py +315 -0
- src/videogen_hub/pipelines/consisti2v/consisti2v/models/__init__.py +0 -0
- src/videogen_hub/pipelines/consisti2v/consisti2v/models/rotary_embedding.py +280 -0
- src/videogen_hub/pipelines/consisti2v/consisti2v/models/videoldm_attention.py +809 -0
- src/videogen_hub/pipelines/consisti2v/consisti2v/models/videoldm_transformer_blocks.py +564 -0
- src/videogen_hub/pipelines/consisti2v/consisti2v/models/videoldm_unet.py +1371 -0
- src/videogen_hub/pipelines/consisti2v/consisti2v/models/videoldm_unet_blocks.py +1159 -0
- src/videogen_hub/pipelines/consisti2v/consisti2v/pipelines/__init__.py +0 -0
- src/videogen_hub/pipelines/consisti2v/consisti2v/pipelines/pipeline_autoregress_animation.py +615 -0
- src/videogen_hub/pipelines/consisti2v/consisti2v/pipelines/pipeline_conditional_animation.py +695 -0
- src/videogen_hub/pipelines/consisti2v/consisti2v/utils/__init__.py +0 -0
- src/videogen_hub/pipelines/consisti2v/consisti2v/utils/frameinit_utils.py +142 -0
- src/videogen_hub/pipelines/consisti2v/consisti2v/utils/util.py +165 -0
- src/videogen_hub/pipelines/consisti2v/scripts/__init__.py +0 -0
- src/videogen_hub/pipelines/consisti2v/scripts/animate.py +247 -0
src/videogen_hub/pipelines/__init__.py
ADDED
File without changes
|
src/videogen_hub/pipelines/cogvideo/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
|
3 |
+
sys.path.insert(0, "./src/videogen_hub/pipelines/cogvideo/")
|
4 |
+
sys.path.insert(0, "./src/videogen_hub/pipelines/cogvideo/cogvideo_src")
|
src/videogen_hub/pipelines/cogvideo/cogvideo_pipeline.py
ADDED
@@ -0,0 +1,612 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from videogen_hub.pipelines.cogvideo.cogvideo_src.cogvideo_pipeline import (
|
2 |
+
InferenceModel_Interpolate,
|
3 |
+
InferenceModel_Sequential,
|
4 |
+
my_filling_sequence,
|
5 |
+
get_masks_and_position_ids_stage1,
|
6 |
+
get_masks_and_position_ids_stage2,
|
7 |
+
my_save_multiple_images,
|
8 |
+
)
|
9 |
+
from videogen_hub.depend.icetk import icetk as tokenizer
|
10 |
+
from videogen_hub.pipelines.cogvideo.cogvideo_src.coglm_strategy import (
|
11 |
+
CoglmStrategy,
|
12 |
+
)
|
13 |
+
from videogen_hub.pipelines.cogvideo.cogvideo_src.sr_pipeline import (
|
14 |
+
DirectSuperResolution,
|
15 |
+
)
|
16 |
+
from SwissArmyTransformer.resources import auto_create
|
17 |
+
import time, logging, sys, os, torch
|
18 |
+
import torch.distributed as dist
|
19 |
+
|
20 |
+
# path = os.path.join(args.output_path, f"{now_qi}_{raw_text}")
|
21 |
+
|
22 |
+
|
23 |
+
def pipeline(args, raw_text, height, width, duration):
|
24 |
+
# model_stage1, args = InferenceModel_Sequential.from_pretrained(args, 'cogvideo-stage1')
|
25 |
+
# model_stage1.eval()
|
26 |
+
# parent_givan_tokens = process_stage1(model_stage1, raw_text, duration=4.0, video_raw_text=raw_text, video_guidance_text="视频",
|
27 |
+
# image_text_suffix=" 高清摄影",
|
28 |
+
# outputdir=None, batch_size=args.batch_size)
|
29 |
+
|
30 |
+
# process_stage2(model_stage2, raw_text, duration=2.0, video_raw_text=raw_text+" 视频",
|
31 |
+
# video_guidance_text="视频", parent_given_tokens=parent_given_tokens,
|
32 |
+
# outputdir=path,
|
33 |
+
# gpu_rank=0, gpu_parallel_size=1) # TODO: 修改
|
34 |
+
|
35 |
+
assert int(args.stage_1) + int(args.stage_2) + int(args.both_stages) == 1
|
36 |
+
rank_id = args.device % args.parallel_size
|
37 |
+
generate_frame_num = args.generate_frame_num
|
38 |
+
|
39 |
+
if args.stage_1 or args.both_stages:
|
40 |
+
model_stage1, args = InferenceModel_Sequential.from_pretrained(
|
41 |
+
args, "cogvideo-stage1"
|
42 |
+
)
|
43 |
+
model_stage1.eval()
|
44 |
+
if args.both_stages:
|
45 |
+
model_stage1 = model_stage1.cpu()
|
46 |
+
|
47 |
+
if args.stage_2 or args.both_stages:
|
48 |
+
model_stage2, args = InferenceModel_Interpolate.from_pretrained(
|
49 |
+
args, "cogvideo-stage2"
|
50 |
+
)
|
51 |
+
model_stage2.eval()
|
52 |
+
if args.both_stages:
|
53 |
+
model_stage2 = model_stage2.cpu()
|
54 |
+
|
55 |
+
invalid_slices = [slice(tokenizer.num_image_tokens, None)]
|
56 |
+
strategy_cogview2 = CoglmStrategy(invalid_slices, temperature=1.0, top_k=16)
|
57 |
+
strategy_cogvideo = CoglmStrategy(
|
58 |
+
invalid_slices,
|
59 |
+
temperature=args.temperature,
|
60 |
+
top_k=args.top_k,
|
61 |
+
temperature2=args.coglm_temperature2,
|
62 |
+
)
|
63 |
+
if not args.stage_1:
|
64 |
+
# from sr_pipeline import DirectSuperResolution
|
65 |
+
dsr_path = auto_create(
|
66 |
+
"cogview2-dsr", path=None
|
67 |
+
) # path=os.getenv('SAT_HOME', '~/.sat_models')
|
68 |
+
dsr = DirectSuperResolution(args, dsr_path, max_bz=12, onCUDA=False)
|
69 |
+
|
70 |
+
def process_stage2(
|
71 |
+
model,
|
72 |
+
seq_text,
|
73 |
+
duration,
|
74 |
+
video_raw_text=None,
|
75 |
+
video_guidance_text="视频",
|
76 |
+
parent_given_tokens=None,
|
77 |
+
conddir=None,
|
78 |
+
outputdir=None,
|
79 |
+
gpu_rank=0,
|
80 |
+
gpu_parallel_size=1,
|
81 |
+
):
|
82 |
+
stage2_starttime = time.time()
|
83 |
+
use_guidance = args.use_guidance_stage2
|
84 |
+
if args.both_stages:
|
85 |
+
move_start_time = time.time()
|
86 |
+
logging.debug("moving stage-2 model to cuda")
|
87 |
+
model = model.cuda()
|
88 |
+
logging.debug(
|
89 |
+
"moving in stage-2 model takes time: {:.2f}".format(
|
90 |
+
time.time() - move_start_time
|
91 |
+
)
|
92 |
+
)
|
93 |
+
|
94 |
+
try:
|
95 |
+
if parent_given_tokens is None:
|
96 |
+
assert conddir is not None
|
97 |
+
parent_given_tokens = torch.load(
|
98 |
+
os.path.join(conddir, "frame_tokens.pt"), map_location="cpu"
|
99 |
+
)
|
100 |
+
sample_num_allgpu = parent_given_tokens.shape[0]
|
101 |
+
sample_num = sample_num_allgpu // gpu_parallel_size
|
102 |
+
assert sample_num * gpu_parallel_size == sample_num_allgpu
|
103 |
+
parent_given_tokens = parent_given_tokens[
|
104 |
+
gpu_rank * sample_num : (gpu_rank + 1) * sample_num
|
105 |
+
]
|
106 |
+
except:
|
107 |
+
logging.critical("No frame_tokens found in interpolation, skip")
|
108 |
+
return False
|
109 |
+
|
110 |
+
# CogVideo Stage2 Generation
|
111 |
+
while (
|
112 |
+
duration >= 0.5
|
113 |
+
): # TODO: You can change the boundary to change the frame rate
|
114 |
+
parent_given_tokens_num = parent_given_tokens.shape[1]
|
115 |
+
generate_batchsize_persample = (parent_given_tokens_num - 1) // 2
|
116 |
+
generate_batchsize_total = generate_batchsize_persample * sample_num
|
117 |
+
total_frames = generate_frame_num
|
118 |
+
frame_len = 400
|
119 |
+
enc_text = tokenizer.encode(seq_text)
|
120 |
+
enc_duration = tokenizer.encode(str(float(duration)) + "秒")
|
121 |
+
seq = (
|
122 |
+
enc_duration
|
123 |
+
+ [tokenizer["<n>"]]
|
124 |
+
+ enc_text
|
125 |
+
+ [tokenizer["<start_of_image>"]]
|
126 |
+
+ [-1] * 400 * generate_frame_num
|
127 |
+
)
|
128 |
+
text_len = len(seq) - frame_len * generate_frame_num - 1
|
129 |
+
|
130 |
+
logging.info(
|
131 |
+
"[Stage2: Generating Frames, Frame Rate {:d}]\nraw text: {:s}".format(
|
132 |
+
int(4 / duration), tokenizer.decode(enc_text)
|
133 |
+
)
|
134 |
+
)
|
135 |
+
|
136 |
+
# generation
|
137 |
+
seq = (
|
138 |
+
torch.cuda.LongTensor(seq, device=args.device)
|
139 |
+
.unsqueeze(0)
|
140 |
+
.repeat(generate_batchsize_total, 1)
|
141 |
+
)
|
142 |
+
for sample_i in range(sample_num):
|
143 |
+
for i in range(generate_batchsize_persample):
|
144 |
+
seq[sample_i * generate_batchsize_persample + i][
|
145 |
+
text_len + 1 : text_len + 1 + 400
|
146 |
+
] = parent_given_tokens[sample_i][2 * i]
|
147 |
+
seq[sample_i * generate_batchsize_persample + i][
|
148 |
+
text_len + 1 + 400 : text_len + 1 + 800
|
149 |
+
] = parent_given_tokens[sample_i][2 * i + 1]
|
150 |
+
seq[sample_i * generate_batchsize_persample + i][
|
151 |
+
text_len + 1 + 800 : text_len + 1 + 1200
|
152 |
+
] = parent_given_tokens[sample_i][2 * i + 2]
|
153 |
+
|
154 |
+
if use_guidance:
|
155 |
+
guider_seq = (
|
156 |
+
enc_duration
|
157 |
+
+ [tokenizer["<n>"]]
|
158 |
+
+ tokenizer.encode(video_guidance_text)
|
159 |
+
+ [tokenizer["<start_of_image>"]]
|
160 |
+
+ [-1] * 400 * generate_frame_num
|
161 |
+
)
|
162 |
+
guider_text_len = len(guider_seq) - frame_len * generate_frame_num - 1
|
163 |
+
guider_seq = (
|
164 |
+
torch.cuda.LongTensor(guider_seq, device=args.device)
|
165 |
+
.unsqueeze(0)
|
166 |
+
.repeat(generate_batchsize_total, 1)
|
167 |
+
)
|
168 |
+
for sample_i in range(sample_num):
|
169 |
+
for i in range(generate_batchsize_persample):
|
170 |
+
guider_seq[sample_i * generate_batchsize_persample + i][
|
171 |
+
text_len + 1 : text_len + 1 + 400
|
172 |
+
] = parent_given_tokens[sample_i][2 * i]
|
173 |
+
guider_seq[sample_i * generate_batchsize_persample + i][
|
174 |
+
text_len + 1 + 400 : text_len + 1 + 800
|
175 |
+
] = parent_given_tokens[sample_i][2 * i + 1]
|
176 |
+
guider_seq[sample_i * generate_batchsize_persample + i][
|
177 |
+
text_len + 1 + 800 : text_len + 1 + 1200
|
178 |
+
] = parent_given_tokens[sample_i][2 * i + 2]
|
179 |
+
video_log_text_attention_weights = 0
|
180 |
+
else:
|
181 |
+
guider_seq = None
|
182 |
+
guider_text_len = 0
|
183 |
+
video_log_text_attention_weights = 1.4
|
184 |
+
|
185 |
+
mbz = args.max_inference_batch_size
|
186 |
+
|
187 |
+
assert generate_batchsize_total < mbz or generate_batchsize_total % mbz == 0
|
188 |
+
output_list = []
|
189 |
+
start_time = time.time()
|
190 |
+
for tim in range(max(generate_batchsize_total // mbz, 1)):
|
191 |
+
input_seq = (
|
192 |
+
seq[: min(generate_batchsize_total, mbz)].clone()
|
193 |
+
if tim == 0
|
194 |
+
else seq[mbz * tim : mbz * (tim + 1)].clone()
|
195 |
+
)
|
196 |
+
guider_seq2 = (
|
197 |
+
(
|
198 |
+
guider_seq[: min(generate_batchsize_total, mbz)].clone()
|
199 |
+
if tim == 0
|
200 |
+
else guider_seq[mbz * tim : mbz * (tim + 1)].clone()
|
201 |
+
)
|
202 |
+
if guider_seq is not None
|
203 |
+
else None
|
204 |
+
)
|
205 |
+
output_list.append(
|
206 |
+
my_filling_sequence(
|
207 |
+
model,
|
208 |
+
args,
|
209 |
+
input_seq,
|
210 |
+
batch_size=min(generate_batchsize_total, mbz),
|
211 |
+
get_masks_and_position_ids=get_masks_and_position_ids_stage2,
|
212 |
+
text_len=text_len,
|
213 |
+
frame_len=frame_len,
|
214 |
+
strategy=strategy_cogview2,
|
215 |
+
strategy2=strategy_cogvideo,
|
216 |
+
log_text_attention_weights=video_log_text_attention_weights,
|
217 |
+
mode_stage1=False,
|
218 |
+
guider_seq=guider_seq2,
|
219 |
+
guider_text_len=guider_text_len,
|
220 |
+
guidance_alpha=args.guidance_alpha,
|
221 |
+
limited_spatial_channel_mem=True,
|
222 |
+
)[0]
|
223 |
+
)
|
224 |
+
logging.info(
|
225 |
+
"Duration {:.2f}, Taken time {:.2f}\n".format(
|
226 |
+
duration, time.time() - start_time
|
227 |
+
)
|
228 |
+
)
|
229 |
+
|
230 |
+
output_tokens = torch.cat(output_list, dim=0)
|
231 |
+
output_tokens = output_tokens[
|
232 |
+
:, text_len + 1 : text_len + 1 + (total_frames) * 400
|
233 |
+
].reshape(sample_num, -1, 400 * total_frames)
|
234 |
+
output_tokens_merge = torch.cat(
|
235 |
+
(
|
236 |
+
output_tokens[:, :, : 1 * 400],
|
237 |
+
output_tokens[:, :, 400 * 3 : 4 * 400],
|
238 |
+
output_tokens[:, :, 400 * 1 : 2 * 400],
|
239 |
+
output_tokens[:, :, 400 * 4 : (total_frames) * 400],
|
240 |
+
),
|
241 |
+
dim=2,
|
242 |
+
).reshape(sample_num, -1, 400)
|
243 |
+
|
244 |
+
output_tokens_merge = torch.cat(
|
245 |
+
(output_tokens_merge, output_tokens[:, -1:, 400 * 2 : 3 * 400]), dim=1
|
246 |
+
)
|
247 |
+
duration /= 2
|
248 |
+
parent_given_tokens = output_tokens_merge
|
249 |
+
|
250 |
+
if args.both_stages:
|
251 |
+
move_start_time = time.time()
|
252 |
+
logging.debug("moving stage 2 model to cpu")
|
253 |
+
model = model.cpu()
|
254 |
+
torch.cuda.empty_cache()
|
255 |
+
logging.debug(
|
256 |
+
"moving out model2 takes time: {:.2f}".format(
|
257 |
+
time.time() - move_start_time
|
258 |
+
)
|
259 |
+
)
|
260 |
+
|
261 |
+
logging.info(
|
262 |
+
"CogVideo Stage2 completed. Taken time {:.2f}\n".format(
|
263 |
+
time.time() - stage2_starttime
|
264 |
+
)
|
265 |
+
)
|
266 |
+
|
267 |
+
# decoding
|
268 |
+
# imgs = [torch.nn.functional.interpolate(tokenizer.decode(image_ids=seq.tolist()), size=(480, 480)) for seq in output_tokens_merge]
|
269 |
+
# os.makedirs(output_dir_full_path, exist_ok=True)
|
270 |
+
# my_save_multiple_images(imgs, output_dir_full_path,subdir="frames", debug=False)
|
271 |
+
# torch.save(output_tokens_merge.cpu(), os.path.join(output_dir_full_path, 'frame_token.pt'))
|
272 |
+
# os.system(f"gifmaker -i '{output_dir_full_path}'/frames/0*.jpg -o '{output_dir_full_path}/{str(float(duration))}_concat.gif' -d 0.2")
|
273 |
+
|
274 |
+
# direct super-resolution by CogView2
|
275 |
+
logging.info("[Direct super-resolution]")
|
276 |
+
dsr_starttime = time.time()
|
277 |
+
enc_text = tokenizer.encode(seq_text)
|
278 |
+
frame_num_per_sample = parent_given_tokens.shape[1]
|
279 |
+
parent_given_tokens_2d = parent_given_tokens.reshape(-1, 400)
|
280 |
+
text_seq = (
|
281 |
+
torch.cuda.LongTensor(enc_text, device=args.device)
|
282 |
+
.unsqueeze(0)
|
283 |
+
.repeat(parent_given_tokens_2d.shape[0], 1)
|
284 |
+
)
|
285 |
+
sred_tokens = dsr(text_seq, parent_given_tokens_2d)
|
286 |
+
decoded_sr_videos = []
|
287 |
+
|
288 |
+
for sample_i in range(sample_num):
|
289 |
+
decoded_sr_imgs = []
|
290 |
+
for frame_i in range(frame_num_per_sample):
|
291 |
+
decoded_sr_img = tokenizer.decode(
|
292 |
+
image_ids=sred_tokens[frame_i + sample_i * frame_num_per_sample][
|
293 |
+
-3600:
|
294 |
+
]
|
295 |
+
)
|
296 |
+
decoded_sr_imgs.append(
|
297 |
+
torch.nn.functional.interpolate(
|
298 |
+
decoded_sr_img, size=(height, width)
|
299 |
+
)
|
300 |
+
)
|
301 |
+
decoded_sr_videos.append(decoded_sr_imgs)
|
302 |
+
|
303 |
+
return decoded_sr_videos
|
304 |
+
# for sample_i in range(sample_num):
|
305 |
+
# my_save_multiple_images(decoded_sr_videos[sample_i], outputdir,subdir=f"frames/{sample_i+sample_num*gpu_rank}", debug=False)
|
306 |
+
# os.system(f"gifmaker -i '{outputdir}'/frames/'{sample_i+sample_num*gpu_rank}'/0*.jpg -o '{outputdir}/{sample_i+sample_num*gpu_rank}.gif' -d 0.125")
|
307 |
+
|
308 |
+
# logging.info("Direct super-resolution completed. Taken time {:.2f}\n".format(time.time() - dsr_starttime))
|
309 |
+
|
310 |
+
# return True
|
311 |
+
|
312 |
+
def process_stage1(
|
313 |
+
model,
|
314 |
+
seq_text,
|
315 |
+
duration,
|
316 |
+
video_raw_text=None,
|
317 |
+
video_guidance_text="视频",
|
318 |
+
image_text_suffix="",
|
319 |
+
outputdir=None,
|
320 |
+
batch_size=1,
|
321 |
+
):
|
322 |
+
process_start_time = time.time()
|
323 |
+
use_guide = args.use_guidance_stage1
|
324 |
+
if args.both_stages:
|
325 |
+
move_start_time = time.time()
|
326 |
+
logging.debug("moving stage 1 model to cuda")
|
327 |
+
model = model.cuda()
|
328 |
+
logging.debug(
|
329 |
+
"moving in model1 takes time: {:.2f}".format(
|
330 |
+
time.time() - move_start_time
|
331 |
+
)
|
332 |
+
)
|
333 |
+
|
334 |
+
if video_raw_text is None:
|
335 |
+
video_raw_text = seq_text
|
336 |
+
mbz = (
|
337 |
+
args.stage1_max_inference_batch_size
|
338 |
+
if args.stage1_max_inference_batch_size > 0
|
339 |
+
else args.max_inference_batch_size
|
340 |
+
)
|
341 |
+
assert batch_size < mbz or batch_size % mbz == 0
|
342 |
+
frame_len = 400
|
343 |
+
|
344 |
+
# generate the first frame:
|
345 |
+
enc_text = tokenizer.encode(seq_text + image_text_suffix)
|
346 |
+
seq_1st = (
|
347 |
+
enc_text + [tokenizer["<start_of_image>"]] + [-1] * 400
|
348 |
+
) # IV!! # test local!!! # test randboi!!!
|
349 |
+
logging.info(
|
350 |
+
"[Generating First Frame with CogView2]Raw text: {:s}".format(
|
351 |
+
tokenizer.decode(enc_text)
|
352 |
+
)
|
353 |
+
)
|
354 |
+
text_len_1st = len(seq_1st) - frame_len * 1 - 1
|
355 |
+
|
356 |
+
seq_1st = torch.cuda.LongTensor(seq_1st, device=args.device).unsqueeze(0)
|
357 |
+
output_list_1st = []
|
358 |
+
for tim in range(max(batch_size // mbz, 1)):
|
359 |
+
start_time = time.time()
|
360 |
+
output_list_1st.append(
|
361 |
+
my_filling_sequence(
|
362 |
+
model,
|
363 |
+
args,
|
364 |
+
seq_1st.clone(),
|
365 |
+
batch_size=min(batch_size, mbz),
|
366 |
+
get_masks_and_position_ids=get_masks_and_position_ids_stage1,
|
367 |
+
text_len=text_len_1st,
|
368 |
+
frame_len=frame_len,
|
369 |
+
strategy=strategy_cogview2,
|
370 |
+
strategy2=strategy_cogvideo,
|
371 |
+
log_text_attention_weights=1.4,
|
372 |
+
enforce_no_swin=True,
|
373 |
+
mode_stage1=True,
|
374 |
+
)[0]
|
375 |
+
)
|
376 |
+
logging.info(
|
377 |
+
"[First Frame]Taken time {:.2f}\n".format(time.time() - start_time)
|
378 |
+
)
|
379 |
+
output_tokens_1st = torch.cat(output_list_1st, dim=0)
|
380 |
+
given_tokens = output_tokens_1st[
|
381 |
+
:, text_len_1st + 1 : text_len_1st + 401
|
382 |
+
].unsqueeze(
|
383 |
+
1
|
384 |
+
) # given_tokens.shape: [bs, frame_num, 400]
|
385 |
+
|
386 |
+
# generate subsequent frames:
|
387 |
+
total_frames = generate_frame_num
|
388 |
+
enc_duration = tokenizer.encode(str(float(duration)) + "秒")
|
389 |
+
if use_guide:
|
390 |
+
video_raw_text = video_raw_text + " 视频"
|
391 |
+
enc_text_video = tokenizer.encode(video_raw_text)
|
392 |
+
seq = (
|
393 |
+
enc_duration
|
394 |
+
+ [tokenizer["<n>"]]
|
395 |
+
+ enc_text_video
|
396 |
+
+ [tokenizer["<start_of_image>"]]
|
397 |
+
+ [-1] * 400 * generate_frame_num
|
398 |
+
)
|
399 |
+
guider_seq = (
|
400 |
+
enc_duration
|
401 |
+
+ [tokenizer["<n>"]]
|
402 |
+
+ tokenizer.encode(video_guidance_text)
|
403 |
+
+ [tokenizer["<start_of_image>"]]
|
404 |
+
+ [-1] * 400 * generate_frame_num
|
405 |
+
)
|
406 |
+
logging.info(
|
407 |
+
"[Stage1: Generating Subsequent Frames, Frame Rate {:.1f}]\nraw text: {:s}".format(
|
408 |
+
4 / duration, tokenizer.decode(enc_text_video)
|
409 |
+
)
|
410 |
+
)
|
411 |
+
|
412 |
+
text_len = len(seq) - frame_len * generate_frame_num - 1
|
413 |
+
guider_text_len = len(guider_seq) - frame_len * generate_frame_num - 1
|
414 |
+
seq = (
|
415 |
+
torch.cuda.LongTensor(seq, device=args.device)
|
416 |
+
.unsqueeze(0)
|
417 |
+
.repeat(batch_size, 1)
|
418 |
+
)
|
419 |
+
guider_seq = (
|
420 |
+
torch.cuda.LongTensor(guider_seq, device=args.device)
|
421 |
+
.unsqueeze(0)
|
422 |
+
.repeat(batch_size, 1)
|
423 |
+
)
|
424 |
+
|
425 |
+
for given_frame_id in range(given_tokens.shape[1]):
|
426 |
+
seq[
|
427 |
+
:,
|
428 |
+
text_len
|
429 |
+
+ 1
|
430 |
+
+ given_frame_id * 400 : text_len
|
431 |
+
+ 1
|
432 |
+
+ (given_frame_id + 1) * 400,
|
433 |
+
] = given_tokens[:, given_frame_id]
|
434 |
+
guider_seq[
|
435 |
+
:,
|
436 |
+
guider_text_len
|
437 |
+
+ 1
|
438 |
+
+ given_frame_id * 400 : guider_text_len
|
439 |
+
+ 1
|
440 |
+
+ (given_frame_id + 1) * 400,
|
441 |
+
] = given_tokens[:, given_frame_id]
|
442 |
+
output_list = []
|
443 |
+
|
444 |
+
if use_guide:
|
445 |
+
video_log_text_attention_weights = 0
|
446 |
+
else:
|
447 |
+
guider_seq = None
|
448 |
+
video_log_text_attention_weights = 1.4
|
449 |
+
|
450 |
+
for tim in range(max(batch_size // mbz, 1)):
|
451 |
+
start_time = time.time()
|
452 |
+
input_seq = (
|
453 |
+
seq[: min(batch_size, mbz)].clone()
|
454 |
+
if tim == 0
|
455 |
+
else seq[mbz * tim : mbz * (tim + 1)].clone()
|
456 |
+
)
|
457 |
+
guider_seq2 = (
|
458 |
+
(
|
459 |
+
guider_seq[: min(batch_size, mbz)].clone()
|
460 |
+
if tim == 0
|
461 |
+
else guider_seq[mbz * tim : mbz * (tim + 1)].clone()
|
462 |
+
)
|
463 |
+
if guider_seq is not None
|
464 |
+
else None
|
465 |
+
)
|
466 |
+
output_list.append(
|
467 |
+
my_filling_sequence(
|
468 |
+
model,
|
469 |
+
args,
|
470 |
+
input_seq,
|
471 |
+
batch_size=min(batch_size, mbz),
|
472 |
+
get_masks_and_position_ids=get_masks_and_position_ids_stage1,
|
473 |
+
text_len=text_len,
|
474 |
+
frame_len=frame_len,
|
475 |
+
strategy=strategy_cogview2,
|
476 |
+
strategy2=strategy_cogvideo,
|
477 |
+
log_text_attention_weights=video_log_text_attention_weights,
|
478 |
+
guider_seq=guider_seq2,
|
479 |
+
guider_text_len=guider_text_len,
|
480 |
+
guidance_alpha=args.guidance_alpha,
|
481 |
+
limited_spatial_channel_mem=True,
|
482 |
+
mode_stage1=True,
|
483 |
+
)[0]
|
484 |
+
)
|
485 |
+
|
486 |
+
output_tokens = torch.cat(output_list, dim=0)[:, 1 + text_len :]
|
487 |
+
|
488 |
+
if args.both_stages:
|
489 |
+
move_start_time = time.time()
|
490 |
+
logging.debug("moving stage 1 model to cpu")
|
491 |
+
model = model.cpu()
|
492 |
+
torch.cuda.empty_cache()
|
493 |
+
logging.debug(
|
494 |
+
"moving in model1 takes time: {:.2f}".format(
|
495 |
+
time.time() - move_start_time
|
496 |
+
)
|
497 |
+
)
|
498 |
+
|
499 |
+
# decoding
|
500 |
+
imgs, sred_imgs, txts = [], [], []
|
501 |
+
for seq in output_tokens:
|
502 |
+
decoded_imgs = [
|
503 |
+
torch.nn.functional.interpolate(
|
504 |
+
tokenizer.decode(image_ids=seq.tolist()[i * 400 : (i + 1) * 400]),
|
505 |
+
size=(height, width),
|
506 |
+
)
|
507 |
+
for i in range(total_frames)
|
508 |
+
]
|
509 |
+
imgs.append(decoded_imgs) # only the last image (target)
|
510 |
+
|
511 |
+
assert len(imgs) == batch_size
|
512 |
+
return imgs
|
513 |
+
# save_tokens = output_tokens[:, :+total_frames*400].reshape(-1, total_frames, 400).cpu()
|
514 |
+
# if outputdir is not None:
|
515 |
+
# for clip_i in range(len(imgs)):
|
516 |
+
# # os.makedirs(output_dir_full_paths[clip_i], exist_ok=True)
|
517 |
+
# my_save_multiple_images(imgs[clip_i], outputdir, subdir=f"frames/{clip_i}", debug=False)
|
518 |
+
# os.system(f"gifmaker -i '{outputdir}'/frames/'{clip_i}'/0*.jpg -o '{outputdir}/{clip_i}.gif' -d 0.25")
|
519 |
+
# torch.save(save_tokens, os.path.join(outputdir, 'frame_tokens.pt'))
|
520 |
+
|
521 |
+
# logging.info("CogVideo Stage1 completed. Taken time {:.2f}\n".format(time.time() - process_start_time))
|
522 |
+
|
523 |
+
# return save_tokens
|
524 |
+
|
525 |
+
# ======================================================================================================
|
526 |
+
|
527 |
+
if args.stage_1 or args.both_stages:
|
528 |
+
if args.input_source != "interactive":
|
529 |
+
with open(args.input_source, "r") as fin:
|
530 |
+
promptlist = fin.readlines()
|
531 |
+
promptlist = [p.strip() for p in promptlist]
|
532 |
+
else:
|
533 |
+
promptlist = None
|
534 |
+
|
535 |
+
now_qi = -1
|
536 |
+
while True:
|
537 |
+
now_qi += 1
|
538 |
+
|
539 |
+
if promptlist is not None: # with input-source
|
540 |
+
if args.multi_gpu:
|
541 |
+
if now_qi % dist.get_world_size() != dist.get_rank():
|
542 |
+
continue
|
543 |
+
rk = dist.get_rank()
|
544 |
+
else:
|
545 |
+
rk = 0
|
546 |
+
raw_text = promptlist[now_qi]
|
547 |
+
raw_text = raw_text.strip()
|
548 |
+
print(f"Working on Line No. {now_qi} on {rk}... [{raw_text}]")
|
549 |
+
else: # interactive
|
550 |
+
raw_text = input("\nPlease Input Query (stop to exit) >>> ")
|
551 |
+
raw_text = raw_text.strip()
|
552 |
+
if not raw_text:
|
553 |
+
print("Query should not be empty!")
|
554 |
+
continue
|
555 |
+
if raw_text == "stop":
|
556 |
+
return
|
557 |
+
|
558 |
+
try:
|
559 |
+
path = os.path.join(args.output_path, f"{now_qi}_{raw_text}")
|
560 |
+
parent_given_tokens, imgs = process_stage1(
|
561 |
+
model_stage1,
|
562 |
+
raw_text,
|
563 |
+
duration=4.0,
|
564 |
+
video_raw_text=raw_text,
|
565 |
+
video_guidance_text="视频",
|
566 |
+
image_text_suffix=" 高清摄影",
|
567 |
+
outputdir=path if args.stage_1 else None,
|
568 |
+
batch_size=args.batch_size,
|
569 |
+
)
|
570 |
+
if args.stage_1 and not args.both_stages:
|
571 |
+
print("only stage 1")
|
572 |
+
return imgs
|
573 |
+
|
574 |
+
if args.both_stages:
|
575 |
+
videos = process_stage2(
|
576 |
+
model_stage2,
|
577 |
+
raw_text,
|
578 |
+
duration=duration,
|
579 |
+
video_raw_text=raw_text + " 视频",
|
580 |
+
video_guidance_text="视频",
|
581 |
+
parent_given_tokens=parent_given_tokens,
|
582 |
+
outputdir=path,
|
583 |
+
gpu_rank=0,
|
584 |
+
gpu_parallel_size=1,
|
585 |
+
) # TODO: 修改
|
586 |
+
return videos
|
587 |
+
except (ValueError, FileNotFoundError) as e:
|
588 |
+
print(e)
|
589 |
+
continue
|
590 |
+
|
591 |
+
elif args.stage_2:
|
592 |
+
sample_dirs = os.listdir(args.output_path)
|
593 |
+
for sample in sample_dirs:
|
594 |
+
raw_text = sample.split("_")[-1]
|
595 |
+
path = os.path.join(args.output_path, sample, "Interp")
|
596 |
+
parent_given_tokens = torch.load(
|
597 |
+
os.path.join(args.output_path, sample, "frame_tokens.pt")
|
598 |
+
)
|
599 |
+
|
600 |
+
process_stage2(
|
601 |
+
raw_text,
|
602 |
+
duration=2.0,
|
603 |
+
video_raw_text=raw_text + " 视频",
|
604 |
+
video_guidance_text="视频",
|
605 |
+
parent_given_tokens=parent_given_tokens,
|
606 |
+
outputdir=path,
|
607 |
+
gpu_rank=0,
|
608 |
+
gpu_parallel_size=1,
|
609 |
+
) # TODO: 修改
|
610 |
+
|
611 |
+
else:
|
612 |
+
assert False
|
src/videogen_hub/pipelines/cogvideo/cogvideo_src/LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
src/videogen_hub/pipelines/cogvideo/cogvideo_src/Model_License
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
The CogVideo License
|
2 |
+
|
3 |
+
Section I: PREAMBLE
|
4 |
+
|
5 |
+
Multimodal generative models are being widely adopted and used, and have the potential to transform the way artists, among other individuals, conceive and benefit from AI or ML technologies as a tool for content creation.
|
6 |
+
|
7 |
+
Notwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations.
|
8 |
+
|
9 |
+
In short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the Model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for art and content generation.
|
10 |
+
|
11 |
+
Even though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this License aims to strike a balance between both in order to enable responsible open-science in the field of AI.
|
12 |
+
|
13 |
+
This License governs the use of the model (and its derivatives) and is informed by the model card associated with the model.
|
14 |
+
|
15 |
+
NOW THEREFORE, You and Licensor agree as follows:
|
16 |
+
|
17 |
+
1. Definitions
|
18 |
+
|
19 |
+
- "License" means the terms and conditions for use, reproduction, and Distribution as defined in this document.
|
20 |
+
- "Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License.
|
21 |
+
- "Output" means the results of operating a Model as embodied in informational content resulting therefrom.
|
22 |
+
- "Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material.
|
23 |
+
- "Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model.
|
24 |
+
- "Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any.
|
25 |
+
- "Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access.
|
26 |
+
- "Licensor" means the copyright owner or entity authorized by the copyright owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model.
|
27 |
+
- "You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, image generator.
|
28 |
+
- "Third Parties" means individuals or legal entities that are not under common control with Licensor or You.
|
29 |
+
- "Contribution" means any work of authorship, including the original version of the Model and any modifications or additions to that Model or Derivatives of the Model thereof, that is intentionally submitted to Licensor for inclusion in the Model by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Model, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
|
30 |
+
- "Contributor" means Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Model.
|
31 |
+
|
32 |
+
Section II: INTELLECTUAL PROPERTY RIGHTS
|
33 |
+
|
34 |
+
Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III.
|
35 |
+
|
36 |
+
2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model.
|
37 |
+
3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Model to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution incorporated within the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or Work shall terminate as of the date such litigation is asserted or filed.
|
38 |
+
|
39 |
+
Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION
|
40 |
+
|
41 |
+
4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions:
|
42 |
+
Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material.
|
43 |
+
You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License;
|
44 |
+
You must cause any modified files to carry prominent notices stating that You changed the files;
|
45 |
+
You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model.
|
46 |
+
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. - for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License.
|
47 |
+
5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5).
|
48 |
+
6. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License.
|
49 |
+
|
50 |
+
Section IV: OTHER PROVISIONS
|
51 |
+
|
52 |
+
7. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License, update the Model through electronic means, or modify the Output of the Model based on updates. You shall undertake reasonable efforts to use the latest version of the Model.
|
53 |
+
8. Trademarks and related. Nothing in this License permits You to make use of Licensors’ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensors.
|
54 |
+
9. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Model and the Complementary Material (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License.
|
55 |
+
10. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
|
56 |
+
11. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
|
57 |
+
12. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.
|
58 |
+
|
59 |
+
END OF TERMS AND CONDITIONS
|
60 |
+
|
61 |
+
|
62 |
+
|
63 |
+
|
64 |
+
Attachment A
|
65 |
+
|
66 |
+
Use Restrictions
|
67 |
+
|
68 |
+
You agree not to use the Model or Derivatives of the Model:
|
69 |
+
- In any way that violates any applicable national, federal, state, local or international law or regulation;
|
70 |
+
- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
|
71 |
+
- To generate or disseminate verifiably false information and/or content with the purpose of harming others;
|
72 |
+
- To generate or disseminate personal identifiable information that can be used to harm an individual;
|
73 |
+
- To defame, disparage or otherwise harass others;
|
74 |
+
- For fully automated decision making that adversely impacts an individual’s legal rights or otherwise creates or modifies a binding, enforceable obligation;
|
75 |
+
- For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics;
|
76 |
+
- To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
|
77 |
+
- For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories;
|
78 |
+
- To provide medical advice and medical results interpretation;
|
79 |
+
- To generate or disseminate information for the purpose to be used for administration of justice, law enforcement, immigration or asylum processes, such as predicting an individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal relationships between assertions made in documents, indiscriminate and arbitrarily-targeted use).
|
src/videogen_hub/pipelines/cogvideo/cogvideo_src/__init__.py
ADDED
File without changes
|
src/videogen_hub/pipelines/cogvideo/cogvideo_src/cluster_label2.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b87880fdbe89670f12844377b9cf97a9733b1f54e3a9b73cbb9835084c4e02ec
|
3 |
+
size 160128
|
src/videogen_hub/pipelines/cogvideo/cogvideo_src/coglm_strategy.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- encoding: utf-8 -*-
|
2 |
+
'''
|
3 |
+
@File : coglm_strategy.py
|
4 |
+
@Time : 2021/10/08 22:22:42
|
5 |
+
@Author : Ming Ding
|
6 |
+
@Contact : [email protected]
|
7 |
+
'''
|
8 |
+
|
9 |
+
# here put the import lib
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
import math
|
13 |
+
import random
|
14 |
+
import torch
|
15 |
+
import numpy as np
|
16 |
+
import torch.nn.functional as F
|
17 |
+
|
18 |
+
|
19 |
+
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-65504):
|
20 |
+
# This function has been mostly taken from huggingface conversational ai code at
|
21 |
+
# https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
|
22 |
+
|
23 |
+
if top_k > 0:
|
24 |
+
# Remove all tokens with a probability less than the last token of the top-k
|
25 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
26 |
+
logits[indices_to_remove] = filter_value
|
27 |
+
|
28 |
+
if top_p > 0.0:
|
29 |
+
# convert to 1D
|
30 |
+
logits = logits.view(logits.size()[1]).contiguous()
|
31 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
32 |
+
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
33 |
+
|
34 |
+
# Remove tokens with cumulative probability above the threshold
|
35 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
36 |
+
# Shift the indices to the right to keep also the first token above the threshold
|
37 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
38 |
+
sorted_indices_to_remove[..., 0] = 0
|
39 |
+
indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
40 |
+
logits[indices_to_remove] = filter_value
|
41 |
+
# going back to 2D
|
42 |
+
logits = logits.view(1, -1).contiguous()
|
43 |
+
|
44 |
+
return logits
|
45 |
+
|
46 |
+
|
47 |
+
class CoglmStrategy:
|
48 |
+
def __init__(self, invalid_slices=[], temperature=1., top_k=200, eps=1e-4, top_p=0.0, end_tokens=None, temperature2=0.89):
|
49 |
+
self.invalid_slices = invalid_slices
|
50 |
+
self.temperature = temperature
|
51 |
+
self.temperature2 = temperature2
|
52 |
+
self.topk = top_k
|
53 |
+
self.top_p = top_p
|
54 |
+
self.eps = eps
|
55 |
+
if end_tokens is None:
|
56 |
+
end_tokens = []
|
57 |
+
self.end_tokens = end_tokens
|
58 |
+
self._is_done = False
|
59 |
+
self.outlier_count_down = torch.zeros(16)
|
60 |
+
self.vis_list = [[]for i in range(16)]
|
61 |
+
self.cluster_labels = torch.tensor(np.load('cluster_label2.npy'), device='cuda', dtype=torch.long)
|
62 |
+
self.start_pos = -1
|
63 |
+
self.white_cluster = []
|
64 |
+
# self.fout = open('tmp.txt', 'w')
|
65 |
+
|
66 |
+
@property
|
67 |
+
def is_done(self) -> bool:
|
68 |
+
return self._is_done
|
69 |
+
|
70 |
+
def forward(self, logits, tokens, mems, temperature=None, temperature2=None):
|
71 |
+
if temperature is None:
|
72 |
+
temperature = self.temperature
|
73 |
+
if temperature2 is None:
|
74 |
+
temperature2 = self.temperature2
|
75 |
+
logits = logits / temperature
|
76 |
+
for invalid_slice in self.invalid_slices:
|
77 |
+
logits[..., invalid_slice] = -65504
|
78 |
+
|
79 |
+
rprobs = F.softmax(logits.float(), dim=-1)
|
80 |
+
c = self.cluster_labels.expand(*rprobs.shape)
|
81 |
+
cprobs = torch.zeros(logits.shape[0], 500, device=logits.device).scatter_add_(1, c, rprobs)
|
82 |
+
# self.fout.write(str(tokens.shape[-1])+ ' ' + str(cprobs.topk(10)) + '\n')
|
83 |
+
# self.fout.flush()
|
84 |
+
best_scores, best_clusters = cprobs.topk(self.topk)
|
85 |
+
bz = logits.shape[0]
|
86 |
+
for i in range(bz):
|
87 |
+
selected_cluster = best_clusters[i][torch.multinomial(best_scores[i] / best_scores[i].sum(), num_samples=1)]
|
88 |
+
logits[i, self.cluster_labels != selected_cluster] = -65504
|
89 |
+
|
90 |
+
# logits = top_k_logits(logits, self.topk, self.top_p)
|
91 |
+
probs = F.softmax(logits.float()/temperature2, dim=-1) # float is essetial, due to a bug in Pytorch
|
92 |
+
pred = torch.multinomial(probs, num_samples=1)
|
93 |
+
|
94 |
+
if pred.numel() == 1 and pred.item() in self.end_tokens:
|
95 |
+
self._is_done = True
|
96 |
+
tokens = torch.cat((tokens, pred.view(tokens.shape[0], 1)), dim=1)
|
97 |
+
return tokens, mems
|
98 |
+
|
99 |
+
def finalize(self, tokens, mems):
|
100 |
+
self._is_done = False
|
101 |
+
return tokens, mems
|
src/videogen_hub/pipelines/cogvideo/cogvideo_src/cogvideo_pipeline.py
ADDED
@@ -0,0 +1,1341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- encoding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
@File : cogvideo_pipeline.py
|
4 |
+
@Time : 2022/07/15 11:24:56
|
5 |
+
@Author : Wenyi Hong
|
6 |
+
@Version : 1.0
|
7 |
+
@Contact : [email protected]
|
8 |
+
"""
|
9 |
+
|
10 |
+
# here put the import lib
|
11 |
+
|
12 |
+
import os
|
13 |
+
import sys
|
14 |
+
import torch
|
15 |
+
import argparse
|
16 |
+
import time
|
17 |
+
from torchvision.utils import save_image
|
18 |
+
import stat
|
19 |
+
from videogen_hub.depend.icetk import icetk as tokenizer
|
20 |
+
import logging, sys
|
21 |
+
|
22 |
+
import torch.distributed as dist
|
23 |
+
|
24 |
+
tokenizer.add_special_tokens(
|
25 |
+
["<start_of_image>", "<start_of_english>", "<start_of_chinese>"]
|
26 |
+
)
|
27 |
+
|
28 |
+
|
29 |
+
from SwissArmyTransformer import get_args
|
30 |
+
from SwissArmyTransformer.data_utils import BinaryDataset, make_loaders
|
31 |
+
from SwissArmyTransformer.generation.sampling_strategies import BaseStrategy
|
32 |
+
from SwissArmyTransformer.generation.utils import (
|
33 |
+
timed_name,
|
34 |
+
save_multiple_images,
|
35 |
+
generate_continually,
|
36 |
+
)
|
37 |
+
from SwissArmyTransformer.resources import auto_create
|
38 |
+
|
39 |
+
from .models.cogvideo_cache_model import CogVideoCacheModel
|
40 |
+
from .coglm_strategy import CoglmStrategy
|
41 |
+
|
42 |
+
|
43 |
+
def get_masks_and_position_ids_stage1(data, textlen, framelen):
|
44 |
+
# Extract batch size and sequence length.
|
45 |
+
tokens = data
|
46 |
+
seq_length = len(data[0])
|
47 |
+
# Attention mask (lower triangular).
|
48 |
+
attention_mask = torch.ones(
|
49 |
+
(1, textlen + framelen, textlen + framelen), device=data.device
|
50 |
+
)
|
51 |
+
attention_mask[:, :textlen, textlen:] = 0
|
52 |
+
attention_mask[:, textlen:, textlen:].tril_()
|
53 |
+
attention_mask.unsqueeze_(1)
|
54 |
+
# Unaligned version
|
55 |
+
position_ids = torch.zeros(seq_length, dtype=torch.long, device=data.device)
|
56 |
+
torch.arange(
|
57 |
+
textlen, out=position_ids[:textlen], dtype=torch.long, device=data.device
|
58 |
+
)
|
59 |
+
torch.arange(
|
60 |
+
512,
|
61 |
+
512 + seq_length - textlen,
|
62 |
+
out=position_ids[textlen:],
|
63 |
+
dtype=torch.long,
|
64 |
+
device=data.device,
|
65 |
+
)
|
66 |
+
position_ids = position_ids.unsqueeze(0)
|
67 |
+
|
68 |
+
return tokens, attention_mask, position_ids
|
69 |
+
|
70 |
+
|
71 |
+
def get_masks_and_position_ids_stage2(data, textlen, framelen):
|
72 |
+
# Extract batch size and sequence length.
|
73 |
+
tokens = data
|
74 |
+
seq_length = len(data[0])
|
75 |
+
|
76 |
+
# Attention mask (lower triangular).
|
77 |
+
attention_mask = torch.ones(
|
78 |
+
(1, textlen + framelen, textlen + framelen), device=data.device
|
79 |
+
)
|
80 |
+
attention_mask[:, :textlen, textlen:] = 0
|
81 |
+
attention_mask[:, textlen:, textlen:].tril_()
|
82 |
+
attention_mask.unsqueeze_(1)
|
83 |
+
|
84 |
+
# Unaligned version
|
85 |
+
position_ids = torch.zeros(seq_length, dtype=torch.long, device=data.device)
|
86 |
+
torch.arange(
|
87 |
+
textlen, out=position_ids[:textlen], dtype=torch.long, device=data.device
|
88 |
+
)
|
89 |
+
frame_num = (seq_length - textlen) // framelen
|
90 |
+
assert frame_num == 5
|
91 |
+
torch.arange(
|
92 |
+
512,
|
93 |
+
512 + framelen,
|
94 |
+
out=position_ids[textlen : textlen + framelen],
|
95 |
+
dtype=torch.long,
|
96 |
+
device=data.device,
|
97 |
+
)
|
98 |
+
torch.arange(
|
99 |
+
512 + framelen * 2,
|
100 |
+
512 + framelen * 3,
|
101 |
+
out=position_ids[textlen + framelen : textlen + framelen * 2],
|
102 |
+
dtype=torch.long,
|
103 |
+
device=data.device,
|
104 |
+
)
|
105 |
+
torch.arange(
|
106 |
+
512 + framelen * (frame_num - 1),
|
107 |
+
512 + framelen * frame_num,
|
108 |
+
out=position_ids[textlen + framelen * 2 : textlen + framelen * 3],
|
109 |
+
dtype=torch.long,
|
110 |
+
device=data.device,
|
111 |
+
)
|
112 |
+
torch.arange(
|
113 |
+
512 + framelen * 1,
|
114 |
+
512 + framelen * 2,
|
115 |
+
out=position_ids[textlen + framelen * 3 : textlen + framelen * 4],
|
116 |
+
dtype=torch.long,
|
117 |
+
device=data.device,
|
118 |
+
)
|
119 |
+
torch.arange(
|
120 |
+
512 + framelen * 3,
|
121 |
+
512 + framelen * 4,
|
122 |
+
out=position_ids[textlen + framelen * 4 : textlen + framelen * 5],
|
123 |
+
dtype=torch.long,
|
124 |
+
device=data.device,
|
125 |
+
)
|
126 |
+
|
127 |
+
position_ids = position_ids.unsqueeze(0)
|
128 |
+
|
129 |
+
return tokens, attention_mask, position_ids
|
130 |
+
|
131 |
+
|
132 |
+
def my_update_mems(
|
133 |
+
hiddens, mems_buffers, mems_indexs, limited_spatial_channel_mem, text_len, frame_len
|
134 |
+
):
|
135 |
+
if hiddens is None:
|
136 |
+
return None, mems_indexs
|
137 |
+
mem_num = len(hiddens)
|
138 |
+
ret_mem = []
|
139 |
+
with torch.no_grad():
|
140 |
+
for id in range(mem_num):
|
141 |
+
if hiddens[id][0] is None:
|
142 |
+
ret_mem.append(None)
|
143 |
+
else:
|
144 |
+
if (
|
145 |
+
id == 0
|
146 |
+
and limited_spatial_channel_mem
|
147 |
+
and mems_indexs[id] + hiddens[0][0].shape[1] >= text_len + frame_len
|
148 |
+
):
|
149 |
+
if mems_indexs[id] == 0:
|
150 |
+
for layer, hidden in enumerate(hiddens[id]):
|
151 |
+
mems_buffers[id][layer, :, :text_len] = hidden.expand(
|
152 |
+
mems_buffers[id].shape[1], -1, -1
|
153 |
+
)[:, :text_len]
|
154 |
+
new_mem_len_part2 = (
|
155 |
+
mems_indexs[id] + hiddens[0][0].shape[1] - text_len
|
156 |
+
) % frame_len
|
157 |
+
if new_mem_len_part2 > 0:
|
158 |
+
for layer, hidden in enumerate(hiddens[id]):
|
159 |
+
mems_buffers[id][
|
160 |
+
layer, :, text_len : text_len + new_mem_len_part2
|
161 |
+
] = hidden.expand(mems_buffers[id].shape[1], -1, -1)[
|
162 |
+
:, -new_mem_len_part2:
|
163 |
+
]
|
164 |
+
mems_indexs[id] = text_len + new_mem_len_part2
|
165 |
+
else:
|
166 |
+
for layer, hidden in enumerate(hiddens[id]):
|
167 |
+
mems_buffers[id][
|
168 |
+
layer,
|
169 |
+
:,
|
170 |
+
mems_indexs[id] : mems_indexs[id] + hidden.shape[1],
|
171 |
+
] = hidden.expand(mems_buffers[id].shape[1], -1, -1)
|
172 |
+
mems_indexs[id] += hidden.shape[1]
|
173 |
+
ret_mem.append(mems_buffers[id][:, :, : mems_indexs[id]])
|
174 |
+
return ret_mem, mems_indexs
|
175 |
+
|
176 |
+
|
177 |
+
def my_save_multiple_images(imgs, path, subdir, debug=True):
|
178 |
+
# imgs: list of tensor images
|
179 |
+
if debug:
|
180 |
+
imgs = torch.cat(imgs, dim=0)
|
181 |
+
print("\nSave to: ", path, flush=True)
|
182 |
+
save_image(imgs, path, normalize=True)
|
183 |
+
else:
|
184 |
+
print("\nSave to: ", path, flush=True)
|
185 |
+
single_frame_path = os.path.join(path, subdir)
|
186 |
+
os.makedirs(single_frame_path, exist_ok=True)
|
187 |
+
for i in range(len(imgs)):
|
188 |
+
save_image(
|
189 |
+
imgs[i],
|
190 |
+
os.path.join(single_frame_path, f'{str(i).rjust(4,"0")}.jpg'),
|
191 |
+
normalize=True,
|
192 |
+
)
|
193 |
+
os.chmod(
|
194 |
+
os.path.join(single_frame_path, f'{str(i).rjust(4,"0")}.jpg'),
|
195 |
+
stat.S_IRWXO + stat.S_IRWXG + stat.S_IRWXU,
|
196 |
+
)
|
197 |
+
save_image(
|
198 |
+
torch.cat(imgs, dim=0),
|
199 |
+
os.path.join(single_frame_path, f"frame_concat.jpg"),
|
200 |
+
normalize=True,
|
201 |
+
)
|
202 |
+
os.chmod(
|
203 |
+
os.path.join(single_frame_path, f"frame_concat.jpg"),
|
204 |
+
stat.S_IRWXO + stat.S_IRWXG + stat.S_IRWXU,
|
205 |
+
)
|
206 |
+
|
207 |
+
|
208 |
+
def calc_next_tokens_frame_begin_id(text_len, frame_len, total_len):
|
209 |
+
# The fisrt token's position id of the frame that the next token belongs to;
|
210 |
+
if total_len < text_len:
|
211 |
+
return None
|
212 |
+
return (total_len - text_len) // frame_len * frame_len + text_len
|
213 |
+
|
214 |
+
|
215 |
+
def my_filling_sequence(
|
216 |
+
model,
|
217 |
+
args,
|
218 |
+
seq,
|
219 |
+
batch_size,
|
220 |
+
get_masks_and_position_ids,
|
221 |
+
text_len,
|
222 |
+
frame_len,
|
223 |
+
strategy=BaseStrategy(),
|
224 |
+
strategy2=BaseStrategy(),
|
225 |
+
mems=None,
|
226 |
+
log_text_attention_weights=0, # default to 0: no artificial change
|
227 |
+
mode_stage1=True,
|
228 |
+
enforce_no_swin=False,
|
229 |
+
guider_seq=None,
|
230 |
+
guider_text_len=0,
|
231 |
+
guidance_alpha=1,
|
232 |
+
limited_spatial_channel_mem=False, # 空间通道的存储限制在本帧内
|
233 |
+
**kw_args,
|
234 |
+
):
|
235 |
+
"""
|
236 |
+
seq: [2, 3, 5, ..., -1(to be generated), -1, ...]
|
237 |
+
mems: [num_layers, batch_size, len_mems(index), mem_hidden_size]
|
238 |
+
cache, should be first mems.shape[1] parts of context_tokens.
|
239 |
+
mems are the first-level citizens here, but we don't assume what is memorized.
|
240 |
+
input mems are used when multi-phase generation.
|
241 |
+
"""
|
242 |
+
if guider_seq is not None:
|
243 |
+
logging.debug("Using Guidance In Inference")
|
244 |
+
if limited_spatial_channel_mem:
|
245 |
+
logging.debug("Limit spatial-channel's mem to current frame")
|
246 |
+
assert len(seq.shape) == 2
|
247 |
+
|
248 |
+
# building the initial tokens, attention_mask, and position_ids
|
249 |
+
actual_context_length = 0
|
250 |
+
|
251 |
+
while seq[-1][actual_context_length] >= 0: # the last seq has least given tokens
|
252 |
+
actual_context_length += 1 # [0, context_length-1] are given
|
253 |
+
assert actual_context_length > 0
|
254 |
+
current_frame_num = (actual_context_length - text_len) // frame_len
|
255 |
+
assert current_frame_num >= 0
|
256 |
+
context_length = text_len + current_frame_num * frame_len
|
257 |
+
|
258 |
+
tokens, attention_mask, position_ids = get_masks_and_position_ids(
|
259 |
+
seq, text_len, frame_len
|
260 |
+
)
|
261 |
+
tokens = tokens[..., :context_length]
|
262 |
+
input_tokens = tokens.clone()
|
263 |
+
|
264 |
+
if guider_seq is not None:
|
265 |
+
guider_index_delta = text_len - guider_text_len
|
266 |
+
guider_tokens, guider_attention_mask, guider_position_ids = (
|
267 |
+
get_masks_and_position_ids(guider_seq, guider_text_len, frame_len)
|
268 |
+
)
|
269 |
+
guider_tokens = guider_tokens[..., : context_length - guider_index_delta]
|
270 |
+
guider_input_tokens = guider_tokens.clone()
|
271 |
+
|
272 |
+
for fid in range(current_frame_num):
|
273 |
+
input_tokens[:, text_len + 400 * fid] = tokenizer["<start_of_image>"]
|
274 |
+
if guider_seq is not None:
|
275 |
+
guider_input_tokens[:, guider_text_len + 400 * fid] = tokenizer[
|
276 |
+
"<start_of_image>"
|
277 |
+
]
|
278 |
+
|
279 |
+
attention_mask = attention_mask.type_as(next(model.parameters())) # if fp16
|
280 |
+
# initialize generation
|
281 |
+
counter = context_length - 1 # Last fixed index is ``counter''
|
282 |
+
index = 0 # Next forward starting index, also the length of cache.
|
283 |
+
mems_buffers_on_GPU = False
|
284 |
+
mems_indexs = [0, 0]
|
285 |
+
mems_len = [
|
286 |
+
(400 + 74) if limited_spatial_channel_mem else 5 * 400 + 74,
|
287 |
+
5 * 400 + 74,
|
288 |
+
]
|
289 |
+
mems_buffers = [
|
290 |
+
torch.zeros(
|
291 |
+
args.num_layers,
|
292 |
+
batch_size,
|
293 |
+
mem_len,
|
294 |
+
args.hidden_size * 2,
|
295 |
+
dtype=next(model.parameters()).dtype,
|
296 |
+
)
|
297 |
+
for mem_len in mems_len
|
298 |
+
]
|
299 |
+
|
300 |
+
if guider_seq is not None:
|
301 |
+
guider_attention_mask = guider_attention_mask.type_as(
|
302 |
+
next(model.parameters())
|
303 |
+
) # if fp16
|
304 |
+
guider_mems_buffers = [
|
305 |
+
torch.zeros(
|
306 |
+
args.num_layers,
|
307 |
+
batch_size,
|
308 |
+
mem_len,
|
309 |
+
args.hidden_size * 2,
|
310 |
+
dtype=next(model.parameters()).dtype,
|
311 |
+
)
|
312 |
+
for mem_len in mems_len
|
313 |
+
]
|
314 |
+
guider_mems_indexs = [0, 0]
|
315 |
+
guider_mems = None
|
316 |
+
|
317 |
+
torch.cuda.empty_cache()
|
318 |
+
# step-by-step generation
|
319 |
+
while counter < len(seq[0]) - 1:
|
320 |
+
# we have generated counter+1 tokens
|
321 |
+
# Now, we want to generate seq[counter + 1],
|
322 |
+
# token[:, index: counter+1] needs forwarding.
|
323 |
+
if index == 0:
|
324 |
+
group_size = (
|
325 |
+
2
|
326 |
+
if (input_tokens.shape[0] == batch_size and not mode_stage1)
|
327 |
+
else batch_size
|
328 |
+
)
|
329 |
+
|
330 |
+
logits_all = None
|
331 |
+
for batch_idx in range(0, input_tokens.shape[0], group_size):
|
332 |
+
logits, *output_per_layers = model(
|
333 |
+
input_tokens[batch_idx : batch_idx + group_size, index:],
|
334 |
+
position_ids[..., index : counter + 1],
|
335 |
+
attention_mask, # TODO memlen
|
336 |
+
mems=mems,
|
337 |
+
text_len=text_len,
|
338 |
+
frame_len=frame_len,
|
339 |
+
counter=counter,
|
340 |
+
log_text_attention_weights=log_text_attention_weights,
|
341 |
+
enforce_no_swin=enforce_no_swin,
|
342 |
+
**kw_args,
|
343 |
+
)
|
344 |
+
logits_all = (
|
345 |
+
torch.cat((logits_all, logits), dim=0)
|
346 |
+
if logits_all is not None
|
347 |
+
else logits
|
348 |
+
)
|
349 |
+
mem_kv01 = [
|
350 |
+
[o["mem_kv"][0] for o in output_per_layers],
|
351 |
+
[o["mem_kv"][1] for o in output_per_layers],
|
352 |
+
]
|
353 |
+
next_tokens_frame_begin_id = calc_next_tokens_frame_begin_id(
|
354 |
+
text_len, frame_len, mem_kv01[0][0].shape[1]
|
355 |
+
)
|
356 |
+
for id, mem_kv in enumerate(mem_kv01):
|
357 |
+
for layer, mem_kv_perlayer in enumerate(mem_kv):
|
358 |
+
if limited_spatial_channel_mem and id == 0:
|
359 |
+
mems_buffers[id][
|
360 |
+
layer, batch_idx : batch_idx + group_size, :text_len
|
361 |
+
] = mem_kv_perlayer.expand(
|
362 |
+
min(group_size, input_tokens.shape[0] - batch_idx),
|
363 |
+
-1,
|
364 |
+
-1,
|
365 |
+
)[
|
366 |
+
:, :text_len
|
367 |
+
]
|
368 |
+
mems_buffers[id][
|
369 |
+
layer,
|
370 |
+
batch_idx : batch_idx + group_size,
|
371 |
+
text_len : text_len
|
372 |
+
+ mem_kv_perlayer.shape[1]
|
373 |
+
- next_tokens_frame_begin_id,
|
374 |
+
] = mem_kv_perlayer.expand(
|
375 |
+
min(group_size, input_tokens.shape[0] - batch_idx),
|
376 |
+
-1,
|
377 |
+
-1,
|
378 |
+
)[
|
379 |
+
:, next_tokens_frame_begin_id:
|
380 |
+
]
|
381 |
+
else:
|
382 |
+
mems_buffers[id][
|
383 |
+
layer,
|
384 |
+
batch_idx : batch_idx + group_size,
|
385 |
+
: mem_kv_perlayer.shape[1],
|
386 |
+
] = mem_kv_perlayer.expand(
|
387 |
+
min(group_size, input_tokens.shape[0] - batch_idx),
|
388 |
+
-1,
|
389 |
+
-1,
|
390 |
+
)
|
391 |
+
mems_indexs[0], mems_indexs[1] = (
|
392 |
+
mem_kv01[0][0].shape[1],
|
393 |
+
mem_kv01[1][0].shape[1],
|
394 |
+
)
|
395 |
+
if limited_spatial_channel_mem:
|
396 |
+
mems_indexs[0] -= next_tokens_frame_begin_id - text_len
|
397 |
+
|
398 |
+
mems = [mems_buffers[id][:, :, : mems_indexs[id]] for id in range(2)]
|
399 |
+
logits = logits_all
|
400 |
+
|
401 |
+
# Guider
|
402 |
+
if guider_seq is not None:
|
403 |
+
guider_logits_all = None
|
404 |
+
for batch_idx in range(0, guider_input_tokens.shape[0], group_size):
|
405 |
+
guider_logits, *guider_output_per_layers = model(
|
406 |
+
guider_input_tokens[
|
407 |
+
batch_idx : batch_idx + group_size,
|
408 |
+
max(index - guider_index_delta, 0) :,
|
409 |
+
],
|
410 |
+
guider_position_ids[
|
411 |
+
...,
|
412 |
+
max(index - guider_index_delta, 0) : counter
|
413 |
+
+ 1
|
414 |
+
- guider_index_delta,
|
415 |
+
],
|
416 |
+
guider_attention_mask,
|
417 |
+
mems=guider_mems,
|
418 |
+
text_len=guider_text_len,
|
419 |
+
frame_len=frame_len,
|
420 |
+
counter=counter - guider_index_delta,
|
421 |
+
log_text_attention_weights=log_text_attention_weights,
|
422 |
+
enforce_no_swin=enforce_no_swin,
|
423 |
+
**kw_args,
|
424 |
+
)
|
425 |
+
guider_logits_all = (
|
426 |
+
torch.cat((guider_logits_all, guider_logits), dim=0)
|
427 |
+
if guider_logits_all is not None
|
428 |
+
else guider_logits
|
429 |
+
)
|
430 |
+
guider_mem_kv01 = [
|
431 |
+
[o["mem_kv"][0] for o in guider_output_per_layers],
|
432 |
+
[o["mem_kv"][1] for o in guider_output_per_layers],
|
433 |
+
]
|
434 |
+
for id, guider_mem_kv in enumerate(guider_mem_kv01):
|
435 |
+
for layer, guider_mem_kv_perlayer in enumerate(guider_mem_kv):
|
436 |
+
if limited_spatial_channel_mem and id == 0:
|
437 |
+
guider_mems_buffers[id][
|
438 |
+
layer,
|
439 |
+
batch_idx : batch_idx + group_size,
|
440 |
+
:guider_text_len,
|
441 |
+
] = guider_mem_kv_perlayer.expand(
|
442 |
+
min(group_size, input_tokens.shape[0] - batch_idx),
|
443 |
+
-1,
|
444 |
+
-1,
|
445 |
+
)[
|
446 |
+
:, :guider_text_len
|
447 |
+
]
|
448 |
+
guider_next_tokens_frame_begin_id = (
|
449 |
+
calc_next_tokens_frame_begin_id(
|
450 |
+
guider_text_len,
|
451 |
+
frame_len,
|
452 |
+
guider_mem_kv_perlayer.shape[1],
|
453 |
+
)
|
454 |
+
)
|
455 |
+
guider_mems_buffers[id][
|
456 |
+
layer,
|
457 |
+
batch_idx : batch_idx + group_size,
|
458 |
+
guider_text_len : guider_text_len
|
459 |
+
+ guider_mem_kv_perlayer.shape[1]
|
460 |
+
- guider_next_tokens_frame_begin_id,
|
461 |
+
] = guider_mem_kv_perlayer.expand(
|
462 |
+
min(group_size, input_tokens.shape[0] - batch_idx),
|
463 |
+
-1,
|
464 |
+
-1,
|
465 |
+
)[
|
466 |
+
:, guider_next_tokens_frame_begin_id:
|
467 |
+
]
|
468 |
+
else:
|
469 |
+
guider_mems_buffers[id][
|
470 |
+
layer,
|
471 |
+
batch_idx : batch_idx + group_size,
|
472 |
+
: guider_mem_kv_perlayer.shape[1],
|
473 |
+
] = guider_mem_kv_perlayer.expand(
|
474 |
+
min(group_size, input_tokens.shape[0] - batch_idx),
|
475 |
+
-1,
|
476 |
+
-1,
|
477 |
+
)
|
478 |
+
guider_mems_indexs[0], guider_mems_indexs[1] = (
|
479 |
+
guider_mem_kv01[0][0].shape[1],
|
480 |
+
guider_mem_kv01[1][0].shape[1],
|
481 |
+
)
|
482 |
+
if limited_spatial_channel_mem:
|
483 |
+
guider_mems_indexs[0] -= (
|
484 |
+
guider_next_tokens_frame_begin_id - guider_text_len
|
485 |
+
)
|
486 |
+
guider_mems = [
|
487 |
+
guider_mems_buffers[id][:, :, : guider_mems_indexs[id]]
|
488 |
+
for id in range(2)
|
489 |
+
]
|
490 |
+
guider_logits = guider_logits_all
|
491 |
+
else:
|
492 |
+
if not mems_buffers_on_GPU:
|
493 |
+
if not mode_stage1:
|
494 |
+
torch.cuda.empty_cache()
|
495 |
+
for idx, mem in enumerate(mems):
|
496 |
+
mems[idx] = mem.to(next(model.parameters()).device)
|
497 |
+
if guider_seq is not None:
|
498 |
+
for idx, mem in enumerate(guider_mems):
|
499 |
+
guider_mems[idx] = mem.to(next(model.parameters()).device)
|
500 |
+
else:
|
501 |
+
torch.cuda.empty_cache()
|
502 |
+
for idx, mem_buffer in enumerate(mems_buffers):
|
503 |
+
mems_buffers[idx] = mem_buffer.to(
|
504 |
+
next(model.parameters()).device
|
505 |
+
)
|
506 |
+
mems = [
|
507 |
+
mems_buffers[id][:, :, : mems_indexs[id]] for id in range(2)
|
508 |
+
]
|
509 |
+
if guider_seq is not None:
|
510 |
+
for idx, guider_mem_buffer in enumerate(guider_mems_buffers):
|
511 |
+
guider_mems_buffers[idx] = guider_mem_buffer.to(
|
512 |
+
next(model.parameters()).device
|
513 |
+
)
|
514 |
+
guider_mems = [
|
515 |
+
guider_mems_buffers[id][:, :, : guider_mems_indexs[id]]
|
516 |
+
for id in range(2)
|
517 |
+
]
|
518 |
+
mems_buffers_on_GPU = True
|
519 |
+
|
520 |
+
logits, *output_per_layers = model(
|
521 |
+
input_tokens[:, index:],
|
522 |
+
position_ids[..., index : counter + 1],
|
523 |
+
attention_mask, # TODO memlen
|
524 |
+
mems=mems,
|
525 |
+
text_len=text_len,
|
526 |
+
frame_len=frame_len,
|
527 |
+
counter=counter,
|
528 |
+
log_text_attention_weights=log_text_attention_weights,
|
529 |
+
enforce_no_swin=enforce_no_swin,
|
530 |
+
limited_spatial_channel_mem=limited_spatial_channel_mem,
|
531 |
+
**kw_args,
|
532 |
+
)
|
533 |
+
mem_kv0, mem_kv1 = [o["mem_kv"][0] for o in output_per_layers], [
|
534 |
+
o["mem_kv"][1] for o in output_per_layers
|
535 |
+
]
|
536 |
+
|
537 |
+
if guider_seq is not None:
|
538 |
+
guider_logits, *guider_output_per_layers = model(
|
539 |
+
guider_input_tokens[:, max(index - guider_index_delta, 0) :],
|
540 |
+
guider_position_ids[
|
541 |
+
...,
|
542 |
+
max(index - guider_index_delta, 0) : counter
|
543 |
+
+ 1
|
544 |
+
- guider_index_delta,
|
545 |
+
],
|
546 |
+
guider_attention_mask,
|
547 |
+
mems=guider_mems,
|
548 |
+
text_len=guider_text_len,
|
549 |
+
frame_len=frame_len,
|
550 |
+
counter=counter - guider_index_delta,
|
551 |
+
log_text_attention_weights=0,
|
552 |
+
enforce_no_swin=enforce_no_swin,
|
553 |
+
limited_spatial_channel_mem=limited_spatial_channel_mem,
|
554 |
+
**kw_args,
|
555 |
+
)
|
556 |
+
guider_mem_kv0, guider_mem_kv1 = [
|
557 |
+
o["mem_kv"][0] for o in guider_output_per_layers
|
558 |
+
], [o["mem_kv"][1] for o in guider_output_per_layers]
|
559 |
+
|
560 |
+
if not mems_buffers_on_GPU:
|
561 |
+
torch.cuda.empty_cache()
|
562 |
+
for idx, mem_buffer in enumerate(mems_buffers):
|
563 |
+
mems_buffers[idx] = mem_buffer.to(next(model.parameters()).device)
|
564 |
+
if guider_seq is not None:
|
565 |
+
for idx, guider_mem_buffer in enumerate(guider_mems_buffers):
|
566 |
+
guider_mems_buffers[idx] = guider_mem_buffer.to(
|
567 |
+
next(model.parameters()).device
|
568 |
+
)
|
569 |
+
mems_buffers_on_GPU = True
|
570 |
+
|
571 |
+
mems, mems_indexs = my_update_mems(
|
572 |
+
[mem_kv0, mem_kv1],
|
573 |
+
mems_buffers,
|
574 |
+
mems_indexs,
|
575 |
+
limited_spatial_channel_mem,
|
576 |
+
text_len,
|
577 |
+
frame_len,
|
578 |
+
)
|
579 |
+
if guider_seq is not None:
|
580 |
+
guider_mems, guider_mems_indexs = my_update_mems(
|
581 |
+
[guider_mem_kv0, guider_mem_kv1],
|
582 |
+
guider_mems_buffers,
|
583 |
+
guider_mems_indexs,
|
584 |
+
limited_spatial_channel_mem,
|
585 |
+
guider_text_len,
|
586 |
+
frame_len,
|
587 |
+
)
|
588 |
+
|
589 |
+
counter += 1
|
590 |
+
index = counter
|
591 |
+
|
592 |
+
logits = logits[:, -1].expand(batch_size, -1) # [batch size, vocab size]
|
593 |
+
tokens = tokens.expand(batch_size, -1)
|
594 |
+
if guider_seq is not None:
|
595 |
+
guider_logits = guider_logits[:, -1].expand(batch_size, -1)
|
596 |
+
guider_tokens = guider_tokens.expand(batch_size, -1)
|
597 |
+
|
598 |
+
if seq[-1][counter].item() < 0:
|
599 |
+
# sampling
|
600 |
+
guided_logits = (
|
601 |
+
guider_logits + (logits - guider_logits) * guidance_alpha
|
602 |
+
if guider_seq is not None
|
603 |
+
else logits
|
604 |
+
)
|
605 |
+
if mode_stage1 and counter < text_len + 400:
|
606 |
+
tokens, mems = strategy.forward(guided_logits, tokens, mems)
|
607 |
+
else:
|
608 |
+
tokens, mems = strategy2.forward(guided_logits, tokens, mems)
|
609 |
+
if guider_seq is not None:
|
610 |
+
guider_tokens = torch.cat((guider_tokens, tokens[:, -1:]), dim=1)
|
611 |
+
|
612 |
+
if seq[0][counter].item() >= 0:
|
613 |
+
for si in range(seq.shape[0]):
|
614 |
+
if seq[si][counter].item() >= 0:
|
615 |
+
tokens[si, -1] = seq[si, counter]
|
616 |
+
if guider_seq is not None:
|
617 |
+
guider_tokens[si, -1] = guider_seq[
|
618 |
+
si, counter - guider_index_delta
|
619 |
+
]
|
620 |
+
|
621 |
+
else:
|
622 |
+
tokens = torch.cat(
|
623 |
+
(
|
624 |
+
tokens,
|
625 |
+
seq[:, counter : counter + 1]
|
626 |
+
.clone()
|
627 |
+
.expand(tokens.shape[0], 1)
|
628 |
+
.to(device=tokens.device, dtype=tokens.dtype),
|
629 |
+
),
|
630 |
+
dim=1,
|
631 |
+
)
|
632 |
+
if guider_seq is not None:
|
633 |
+
guider_tokens = torch.cat(
|
634 |
+
(
|
635 |
+
guider_tokens,
|
636 |
+
guider_seq[
|
637 |
+
:,
|
638 |
+
counter
|
639 |
+
- guider_index_delta : counter
|
640 |
+
+ 1
|
641 |
+
- guider_index_delta,
|
642 |
+
]
|
643 |
+
.clone()
|
644 |
+
.expand(guider_tokens.shape[0], 1)
|
645 |
+
.to(device=guider_tokens.device, dtype=guider_tokens.dtype),
|
646 |
+
),
|
647 |
+
dim=1,
|
648 |
+
)
|
649 |
+
|
650 |
+
input_tokens = tokens.clone()
|
651 |
+
if guider_seq is not None:
|
652 |
+
guider_input_tokens = guider_tokens.clone()
|
653 |
+
if (index - text_len - 1) // 400 < (
|
654 |
+
input_tokens.shape[-1] - text_len - 1
|
655 |
+
) // 400:
|
656 |
+
boi_idx = ((index - text_len - 1) // 400 + 1) * 400 + text_len
|
657 |
+
while boi_idx < input_tokens.shape[-1]:
|
658 |
+
input_tokens[:, boi_idx] = tokenizer["<start_of_image>"]
|
659 |
+
if guider_seq is not None:
|
660 |
+
guider_input_tokens[:, boi_idx - guider_index_delta] = tokenizer[
|
661 |
+
"<start_of_image>"
|
662 |
+
]
|
663 |
+
boi_idx += 400
|
664 |
+
|
665 |
+
if strategy.is_done:
|
666 |
+
break
|
667 |
+
return strategy.finalize(tokens, mems)
|
668 |
+
|
669 |
+
|
670 |
+
class InferenceModel_Sequential(CogVideoCacheModel):
|
671 |
+
def __init__(self, args, transformer=None, parallel_output=True):
|
672 |
+
super().__init__(
|
673 |
+
args,
|
674 |
+
transformer=transformer,
|
675 |
+
parallel_output=parallel_output,
|
676 |
+
window_size=-1,
|
677 |
+
cogvideo_stage=1,
|
678 |
+
)
|
679 |
+
|
680 |
+
# TODO: check it
|
681 |
+
|
682 |
+
def final_forward(self, logits, **kwargs):
|
683 |
+
logits_parallel = logits
|
684 |
+
logits_parallel = torch.nn.functional.linear(
|
685 |
+
logits_parallel.float(),
|
686 |
+
self.transformer.word_embeddings.weight[:20000].float(),
|
687 |
+
)
|
688 |
+
return logits_parallel
|
689 |
+
|
690 |
+
|
691 |
+
class InferenceModel_Interpolate(CogVideoCacheModel):
|
692 |
+
def __init__(self, args, transformer=None, parallel_output=True):
|
693 |
+
super().__init__(
|
694 |
+
args,
|
695 |
+
transformer=transformer,
|
696 |
+
parallel_output=parallel_output,
|
697 |
+
window_size=10,
|
698 |
+
cogvideo_stage=2,
|
699 |
+
)
|
700 |
+
|
701 |
+
# TODO: check it
|
702 |
+
|
703 |
+
def final_forward(self, logits, **kwargs):
|
704 |
+
logits_parallel = logits
|
705 |
+
logits_parallel = torch.nn.functional.linear(
|
706 |
+
logits_parallel.float(),
|
707 |
+
self.transformer.word_embeddings.weight[:20000].float(),
|
708 |
+
)
|
709 |
+
return logits_parallel
|
710 |
+
|
711 |
+
|
712 |
+
def main(args):
|
713 |
+
assert int(args.stage_1) + int(args.stage_2) + int(args.both_stages) == 1
|
714 |
+
rank_id = args.device % args.parallel_size
|
715 |
+
generate_frame_num = args.generate_frame_num
|
716 |
+
|
717 |
+
if args.stage_1 or args.both_stages:
|
718 |
+
model_stage1, args = InferenceModel_Sequential.from_pretrained(
|
719 |
+
args, "cogvideo-stage1"
|
720 |
+
)
|
721 |
+
model_stage1.eval()
|
722 |
+
if args.both_stages:
|
723 |
+
model_stage1 = model_stage1.cpu()
|
724 |
+
|
725 |
+
if args.stage_2 or args.both_stages:
|
726 |
+
model_stage2, args = InferenceModel_Interpolate.from_pretrained(
|
727 |
+
args, "cogvideo-stage2"
|
728 |
+
)
|
729 |
+
model_stage2.eval()
|
730 |
+
if args.both_stages:
|
731 |
+
model_stage2 = model_stage2.cpu()
|
732 |
+
|
733 |
+
invalid_slices = [slice(tokenizer.num_image_tokens, None)]
|
734 |
+
strategy_cogview2 = CoglmStrategy(invalid_slices, temperature=1.0, top_k=16)
|
735 |
+
strategy_cogvideo = CoglmStrategy(
|
736 |
+
invalid_slices,
|
737 |
+
temperature=args.temperature,
|
738 |
+
top_k=args.top_k,
|
739 |
+
temperature2=args.coglm_temperature2,
|
740 |
+
)
|
741 |
+
if not args.stage_1:
|
742 |
+
from sr_pipeline import DirectSuperResolution
|
743 |
+
|
744 |
+
dsr_path = auto_create(
|
745 |
+
"cogview2-dsr", path=None
|
746 |
+
) # path=os.getenv('SAT_HOME', '~/.sat_models')
|
747 |
+
dsr = DirectSuperResolution(args, dsr_path, max_bz=12, onCUDA=False)
|
748 |
+
|
749 |
+
def process_stage2(
|
750 |
+
model,
|
751 |
+
seq_text,
|
752 |
+
duration,
|
753 |
+
video_raw_text=None,
|
754 |
+
video_guidance_text="视频",
|
755 |
+
parent_given_tokens=None,
|
756 |
+
conddir=None,
|
757 |
+
outputdir=None,
|
758 |
+
gpu_rank=0,
|
759 |
+
gpu_parallel_size=1,
|
760 |
+
):
|
761 |
+
stage2_starttime = time.time()
|
762 |
+
use_guidance = args.use_guidance_stage2
|
763 |
+
if args.both_stages:
|
764 |
+
move_start_time = time.time()
|
765 |
+
logging.debug("moving stage-2 model to cuda")
|
766 |
+
model = model.cuda()
|
767 |
+
logging.debug(
|
768 |
+
"moving in stage-2 model takes time: {:.2f}".format(
|
769 |
+
time.time() - move_start_time
|
770 |
+
)
|
771 |
+
)
|
772 |
+
|
773 |
+
try:
|
774 |
+
if parent_given_tokens is None:
|
775 |
+
assert conddir is not None
|
776 |
+
parent_given_tokens = torch.load(
|
777 |
+
os.path.join(conddir, "frame_tokens.pt"), map_location="cpu"
|
778 |
+
)
|
779 |
+
sample_num_allgpu = parent_given_tokens.shape[0]
|
780 |
+
sample_num = sample_num_allgpu // gpu_parallel_size
|
781 |
+
assert sample_num * gpu_parallel_size == sample_num_allgpu
|
782 |
+
parent_given_tokens = parent_given_tokens[
|
783 |
+
gpu_rank * sample_num : (gpu_rank + 1) * sample_num
|
784 |
+
]
|
785 |
+
except:
|
786 |
+
logging.critical("No frame_tokens found in interpolation, skip")
|
787 |
+
return False
|
788 |
+
|
789 |
+
# CogVideo Stage2 Generation
|
790 |
+
while (
|
791 |
+
duration >= 0.5
|
792 |
+
): # TODO: You can change the boundary to change the frame rate
|
793 |
+
parent_given_tokens_num = parent_given_tokens.shape[1]
|
794 |
+
generate_batchsize_persample = (parent_given_tokens_num - 1) // 2
|
795 |
+
generate_batchsize_total = generate_batchsize_persample * sample_num
|
796 |
+
total_frames = generate_frame_num
|
797 |
+
frame_len = 400
|
798 |
+
enc_text = tokenizer.encode(seq_text)
|
799 |
+
enc_duration = tokenizer.encode(str(float(duration)) + "秒")
|
800 |
+
seq = (
|
801 |
+
enc_duration
|
802 |
+
+ [tokenizer["<n>"]]
|
803 |
+
+ enc_text
|
804 |
+
+ [tokenizer["<start_of_image>"]]
|
805 |
+
+ [-1] * 400 * generate_frame_num
|
806 |
+
)
|
807 |
+
text_len = len(seq) - frame_len * generate_frame_num - 1
|
808 |
+
|
809 |
+
logging.info(
|
810 |
+
"[Stage2: Generating Frames, Frame Rate {:d}]\nraw text: {:s}".format(
|
811 |
+
int(4 / duration), tokenizer.decode(enc_text)
|
812 |
+
)
|
813 |
+
)
|
814 |
+
|
815 |
+
# generation
|
816 |
+
seq = (
|
817 |
+
torch.cuda.LongTensor(seq, device=args.device)
|
818 |
+
.unsqueeze(0)
|
819 |
+
.repeat(generate_batchsize_total, 1)
|
820 |
+
)
|
821 |
+
for sample_i in range(sample_num):
|
822 |
+
for i in range(generate_batchsize_persample):
|
823 |
+
seq[sample_i * generate_batchsize_persample + i][
|
824 |
+
text_len + 1 : text_len + 1 + 400
|
825 |
+
] = parent_given_tokens[sample_i][2 * i]
|
826 |
+
seq[sample_i * generate_batchsize_persample + i][
|
827 |
+
text_len + 1 + 400 : text_len + 1 + 800
|
828 |
+
] = parent_given_tokens[sample_i][2 * i + 1]
|
829 |
+
seq[sample_i * generate_batchsize_persample + i][
|
830 |
+
text_len + 1 + 800 : text_len + 1 + 1200
|
831 |
+
] = parent_given_tokens[sample_i][2 * i + 2]
|
832 |
+
|
833 |
+
if use_guidance:
|
834 |
+
guider_seq = (
|
835 |
+
enc_duration
|
836 |
+
+ [tokenizer["<n>"]]
|
837 |
+
+ tokenizer.encode(video_guidance_text)
|
838 |
+
+ [tokenizer["<start_of_image>"]]
|
839 |
+
+ [-1] * 400 * generate_frame_num
|
840 |
+
)
|
841 |
+
guider_text_len = len(guider_seq) - frame_len * generate_frame_num - 1
|
842 |
+
guider_seq = (
|
843 |
+
torch.cuda.LongTensor(guider_seq, device=args.device)
|
844 |
+
.unsqueeze(0)
|
845 |
+
.repeat(generate_batchsize_total, 1)
|
846 |
+
)
|
847 |
+
for sample_i in range(sample_num):
|
848 |
+
for i in range(generate_batchsize_persample):
|
849 |
+
guider_seq[sample_i * generate_batchsize_persample + i][
|
850 |
+
text_len + 1 : text_len + 1 + 400
|
851 |
+
] = parent_given_tokens[sample_i][2 * i]
|
852 |
+
guider_seq[sample_i * generate_batchsize_persample + i][
|
853 |
+
text_len + 1 + 400 : text_len + 1 + 800
|
854 |
+
] = parent_given_tokens[sample_i][2 * i + 1]
|
855 |
+
guider_seq[sample_i * generate_batchsize_persample + i][
|
856 |
+
text_len + 1 + 800 : text_len + 1 + 1200
|
857 |
+
] = parent_given_tokens[sample_i][2 * i + 2]
|
858 |
+
video_log_text_attention_weights = 0
|
859 |
+
else:
|
860 |
+
guider_seq = None
|
861 |
+
guider_text_len = 0
|
862 |
+
video_log_text_attention_weights = 1.4
|
863 |
+
|
864 |
+
mbz = args.max_inference_batch_size
|
865 |
+
|
866 |
+
assert generate_batchsize_total < mbz or generate_batchsize_total % mbz == 0
|
867 |
+
output_list = []
|
868 |
+
start_time = time.time()
|
869 |
+
for tim in range(max(generate_batchsize_total // mbz, 1)):
|
870 |
+
input_seq = (
|
871 |
+
seq[: min(generate_batchsize_total, mbz)].clone()
|
872 |
+
if tim == 0
|
873 |
+
else seq[mbz * tim : mbz * (tim + 1)].clone()
|
874 |
+
)
|
875 |
+
guider_seq2 = (
|
876 |
+
(
|
877 |
+
guider_seq[: min(generate_batchsize_total, mbz)].clone()
|
878 |
+
if tim == 0
|
879 |
+
else guider_seq[mbz * tim : mbz * (tim + 1)].clone()
|
880 |
+
)
|
881 |
+
if guider_seq is not None
|
882 |
+
else None
|
883 |
+
)
|
884 |
+
output_list.append(
|
885 |
+
my_filling_sequence(
|
886 |
+
model,
|
887 |
+
args,
|
888 |
+
input_seq,
|
889 |
+
batch_size=min(generate_batchsize_total, mbz),
|
890 |
+
get_masks_and_position_ids=get_masks_and_position_ids_stage2,
|
891 |
+
text_len=text_len,
|
892 |
+
frame_len=frame_len,
|
893 |
+
strategy=strategy_cogview2,
|
894 |
+
strategy2=strategy_cogvideo,
|
895 |
+
log_text_attention_weights=video_log_text_attention_weights,
|
896 |
+
mode_stage1=False,
|
897 |
+
guider_seq=guider_seq2,
|
898 |
+
guider_text_len=guider_text_len,
|
899 |
+
guidance_alpha=args.guidance_alpha,
|
900 |
+
limited_spatial_channel_mem=True,
|
901 |
+
)[0]
|
902 |
+
)
|
903 |
+
logging.info(
|
904 |
+
"Duration {:.2f}, Taken time {:.2f}\n".format(
|
905 |
+
duration, time.time() - start_time
|
906 |
+
)
|
907 |
+
)
|
908 |
+
|
909 |
+
output_tokens = torch.cat(output_list, dim=0)
|
910 |
+
output_tokens = output_tokens[
|
911 |
+
:, text_len + 1 : text_len + 1 + (total_frames) * 400
|
912 |
+
].reshape(sample_num, -1, 400 * total_frames)
|
913 |
+
output_tokens_merge = torch.cat(
|
914 |
+
(
|
915 |
+
output_tokens[:, :, : 1 * 400],
|
916 |
+
output_tokens[:, :, 400 * 3 : 4 * 400],
|
917 |
+
output_tokens[:, :, 400 * 1 : 2 * 400],
|
918 |
+
output_tokens[:, :, 400 * 4 : (total_frames) * 400],
|
919 |
+
),
|
920 |
+
dim=2,
|
921 |
+
).reshape(sample_num, -1, 400)
|
922 |
+
|
923 |
+
output_tokens_merge = torch.cat(
|
924 |
+
(output_tokens_merge, output_tokens[:, -1:, 400 * 2 : 3 * 400]), dim=1
|
925 |
+
)
|
926 |
+
duration /= 2
|
927 |
+
parent_given_tokens = output_tokens_merge
|
928 |
+
|
929 |
+
if args.both_stages:
|
930 |
+
move_start_time = time.time()
|
931 |
+
logging.debug("moving stage 2 model to cpu")
|
932 |
+
model = model.cpu()
|
933 |
+
torch.cuda.empty_cache()
|
934 |
+
logging.debug(
|
935 |
+
"moving out model2 takes time: {:.2f}".format(
|
936 |
+
time.time() - move_start_time
|
937 |
+
)
|
938 |
+
)
|
939 |
+
|
940 |
+
logging.info(
|
941 |
+
"CogVideo Stage2 completed. Taken time {:.2f}\n".format(
|
942 |
+
time.time() - stage2_starttime
|
943 |
+
)
|
944 |
+
)
|
945 |
+
|
946 |
+
# decoding
|
947 |
+
# imgs = [torch.nn.functional.interpolate(tokenizer.decode(image_ids=seq.tolist()), size=(480, 480)) for seq in output_tokens_merge]
|
948 |
+
# os.makedirs(output_dir_full_path, exist_ok=True)
|
949 |
+
# my_save_multiple_images(imgs, output_dir_full_path,subdir="frames", debug=False)
|
950 |
+
# torch.save(output_tokens_merge.cpu(), os.path.join(output_dir_full_path, 'frame_token.pt'))
|
951 |
+
# os.system(f"gifmaker -i '{output_dir_full_path}'/frames/0*.jpg -o '{output_dir_full_path}/{str(float(duration))}_concat.gif' -d 0.2")
|
952 |
+
|
953 |
+
# direct super-resolution by CogView2
|
954 |
+
logging.info("[Direct super-resolution]")
|
955 |
+
dsr_starttime = time.time()
|
956 |
+
enc_text = tokenizer.encode(seq_text)
|
957 |
+
frame_num_per_sample = parent_given_tokens.shape[1]
|
958 |
+
parent_given_tokens_2d = parent_given_tokens.reshape(-1, 400)
|
959 |
+
text_seq = (
|
960 |
+
torch.cuda.LongTensor(enc_text, device=args.device)
|
961 |
+
.unsqueeze(0)
|
962 |
+
.repeat(parent_given_tokens_2d.shape[0], 1)
|
963 |
+
)
|
964 |
+
sred_tokens = dsr(text_seq, parent_given_tokens_2d)
|
965 |
+
decoded_sr_videos = []
|
966 |
+
|
967 |
+
for sample_i in range(sample_num):
|
968 |
+
decoded_sr_imgs = []
|
969 |
+
for frame_i in range(frame_num_per_sample):
|
970 |
+
decoded_sr_img = tokenizer.decode(
|
971 |
+
image_ids=sred_tokens[frame_i + sample_i * frame_num_per_sample][
|
972 |
+
-3600:
|
973 |
+
]
|
974 |
+
)
|
975 |
+
decoded_sr_imgs.append(
|
976 |
+
torch.nn.functional.interpolate(decoded_sr_img, size=(480, 480))
|
977 |
+
)
|
978 |
+
decoded_sr_videos.append(decoded_sr_imgs)
|
979 |
+
|
980 |
+
for sample_i in range(sample_num):
|
981 |
+
my_save_multiple_images(
|
982 |
+
decoded_sr_videos[sample_i],
|
983 |
+
outputdir,
|
984 |
+
subdir=f"frames/{sample_i+sample_num*gpu_rank}",
|
985 |
+
debug=False,
|
986 |
+
)
|
987 |
+
os.system(
|
988 |
+
f"gifmaker -i '{outputdir}'/frames/'{sample_i+sample_num*gpu_rank}'/0*.jpg -o '{outputdir}/{sample_i+sample_num*gpu_rank}.gif' -d 0.125"
|
989 |
+
)
|
990 |
+
|
991 |
+
logging.info(
|
992 |
+
"Direct super-resolution completed. Taken time {:.2f}\n".format(
|
993 |
+
time.time() - dsr_starttime
|
994 |
+
)
|
995 |
+
)
|
996 |
+
|
997 |
+
return True
|
998 |
+
|
999 |
+
def process_stage1(
|
1000 |
+
model,
|
1001 |
+
seq_text,
|
1002 |
+
duration,
|
1003 |
+
video_raw_text=None,
|
1004 |
+
video_guidance_text="视频",
|
1005 |
+
image_text_suffix="",
|
1006 |
+
outputdir=None,
|
1007 |
+
batch_size=1,
|
1008 |
+
):
|
1009 |
+
process_start_time = time.time()
|
1010 |
+
use_guide = args.use_guidance_stage1
|
1011 |
+
if args.both_stages:
|
1012 |
+
move_start_time = time.time()
|
1013 |
+
logging.debug("moving stage 1 model to cuda")
|
1014 |
+
model = model.cuda()
|
1015 |
+
logging.debug(
|
1016 |
+
"moving in model1 takes time: {:.2f}".format(
|
1017 |
+
time.time() - move_start_time
|
1018 |
+
)
|
1019 |
+
)
|
1020 |
+
|
1021 |
+
if video_raw_text is None:
|
1022 |
+
video_raw_text = seq_text
|
1023 |
+
mbz = (
|
1024 |
+
args.stage1_max_inference_batch_size
|
1025 |
+
if args.stage1_max_inference_batch_size > 0
|
1026 |
+
else args.max_inference_batch_size
|
1027 |
+
)
|
1028 |
+
assert batch_size < mbz or batch_size % mbz == 0
|
1029 |
+
frame_len = 400
|
1030 |
+
|
1031 |
+
# generate the first frame:
|
1032 |
+
enc_text = tokenizer.encode(seq_text + image_text_suffix)
|
1033 |
+
seq_1st = (
|
1034 |
+
enc_text + [tokenizer["<start_of_image>"]] + [-1] * 400
|
1035 |
+
) # IV!! # test local!!! # test randboi!!!
|
1036 |
+
logging.info(
|
1037 |
+
"[Generating First Frame with CogView2]Raw text: {:s}".format(
|
1038 |
+
tokenizer.decode(enc_text)
|
1039 |
+
)
|
1040 |
+
)
|
1041 |
+
text_len_1st = len(seq_1st) - frame_len * 1 - 1
|
1042 |
+
|
1043 |
+
seq_1st = torch.cuda.LongTensor(seq_1st, device=args.device).unsqueeze(0)
|
1044 |
+
output_list_1st = []
|
1045 |
+
for tim in range(max(batch_size // mbz, 1)):
|
1046 |
+
start_time = time.time()
|
1047 |
+
output_list_1st.append(
|
1048 |
+
my_filling_sequence(
|
1049 |
+
model,
|
1050 |
+
args,
|
1051 |
+
seq_1st.clone(),
|
1052 |
+
batch_size=min(batch_size, mbz),
|
1053 |
+
get_masks_and_position_ids=get_masks_and_position_ids_stage1,
|
1054 |
+
text_len=text_len_1st,
|
1055 |
+
frame_len=frame_len,
|
1056 |
+
strategy=strategy_cogview2,
|
1057 |
+
strategy2=strategy_cogvideo,
|
1058 |
+
log_text_attention_weights=1.4,
|
1059 |
+
enforce_no_swin=True,
|
1060 |
+
mode_stage1=True,
|
1061 |
+
)[0]
|
1062 |
+
)
|
1063 |
+
logging.info(
|
1064 |
+
"[First Frame]Taken time {:.2f}\n".format(time.time() - start_time)
|
1065 |
+
)
|
1066 |
+
output_tokens_1st = torch.cat(output_list_1st, dim=0)
|
1067 |
+
given_tokens = output_tokens_1st[
|
1068 |
+
:, text_len_1st + 1 : text_len_1st + 401
|
1069 |
+
].unsqueeze(
|
1070 |
+
1
|
1071 |
+
) # given_tokens.shape: [bs, frame_num, 400]
|
1072 |
+
|
1073 |
+
# generate subsequent frames:
|
1074 |
+
total_frames = generate_frame_num
|
1075 |
+
enc_duration = tokenizer.encode(str(float(duration)) + "秒")
|
1076 |
+
if use_guide:
|
1077 |
+
video_raw_text = video_raw_text + " 视频"
|
1078 |
+
enc_text_video = tokenizer.encode(video_raw_text)
|
1079 |
+
seq = (
|
1080 |
+
enc_duration
|
1081 |
+
+ [tokenizer["<n>"]]
|
1082 |
+
+ enc_text_video
|
1083 |
+
+ [tokenizer["<start_of_image>"]]
|
1084 |
+
+ [-1] * 400 * generate_frame_num
|
1085 |
+
)
|
1086 |
+
guider_seq = (
|
1087 |
+
enc_duration
|
1088 |
+
+ [tokenizer["<n>"]]
|
1089 |
+
+ tokenizer.encode(video_guidance_text)
|
1090 |
+
+ [tokenizer["<start_of_image>"]]
|
1091 |
+
+ [-1] * 400 * generate_frame_num
|
1092 |
+
)
|
1093 |
+
logging.info(
|
1094 |
+
"[Stage1: Generating Subsequent Frames, Frame Rate {:.1f}]\nraw text: {:s}".format(
|
1095 |
+
4 / duration, tokenizer.decode(enc_text_video)
|
1096 |
+
)
|
1097 |
+
)
|
1098 |
+
|
1099 |
+
text_len = len(seq) - frame_len * generate_frame_num - 1
|
1100 |
+
guider_text_len = len(guider_seq) - frame_len * generate_frame_num - 1
|
1101 |
+
seq = (
|
1102 |
+
torch.cuda.LongTensor(seq, device=args.device)
|
1103 |
+
.unsqueeze(0)
|
1104 |
+
.repeat(batch_size, 1)
|
1105 |
+
)
|
1106 |
+
guider_seq = (
|
1107 |
+
torch.cuda.LongTensor(guider_seq, device=args.device)
|
1108 |
+
.unsqueeze(0)
|
1109 |
+
.repeat(batch_size, 1)
|
1110 |
+
)
|
1111 |
+
|
1112 |
+
for given_frame_id in range(given_tokens.shape[1]):
|
1113 |
+
seq[
|
1114 |
+
:,
|
1115 |
+
text_len
|
1116 |
+
+ 1
|
1117 |
+
+ given_frame_id * 400 : text_len
|
1118 |
+
+ 1
|
1119 |
+
+ (given_frame_id + 1) * 400,
|
1120 |
+
] = given_tokens[:, given_frame_id]
|
1121 |
+
guider_seq[
|
1122 |
+
:,
|
1123 |
+
guider_text_len
|
1124 |
+
+ 1
|
1125 |
+
+ given_frame_id * 400 : guider_text_len
|
1126 |
+
+ 1
|
1127 |
+
+ (given_frame_id + 1) * 400,
|
1128 |
+
] = given_tokens[:, given_frame_id]
|
1129 |
+
output_list = []
|
1130 |
+
|
1131 |
+
if use_guide:
|
1132 |
+
video_log_text_attention_weights = 0
|
1133 |
+
else:
|
1134 |
+
guider_seq = None
|
1135 |
+
video_log_text_attention_weights = 1.4
|
1136 |
+
|
1137 |
+
for tim in range(max(batch_size // mbz, 1)):
|
1138 |
+
start_time = time.time()
|
1139 |
+
input_seq = (
|
1140 |
+
seq[: min(batch_size, mbz)].clone()
|
1141 |
+
if tim == 0
|
1142 |
+
else seq[mbz * tim : mbz * (tim + 1)].clone()
|
1143 |
+
)
|
1144 |
+
guider_seq2 = (
|
1145 |
+
(
|
1146 |
+
guider_seq[: min(batch_size, mbz)].clone()
|
1147 |
+
if tim == 0
|
1148 |
+
else guider_seq[mbz * tim : mbz * (tim + 1)].clone()
|
1149 |
+
)
|
1150 |
+
if guider_seq is not None
|
1151 |
+
else None
|
1152 |
+
)
|
1153 |
+
output_list.append(
|
1154 |
+
my_filling_sequence(
|
1155 |
+
model,
|
1156 |
+
args,
|
1157 |
+
input_seq,
|
1158 |
+
batch_size=min(batch_size, mbz),
|
1159 |
+
get_masks_and_position_ids=get_masks_and_position_ids_stage1,
|
1160 |
+
text_len=text_len,
|
1161 |
+
frame_len=frame_len,
|
1162 |
+
strategy=strategy_cogview2,
|
1163 |
+
strategy2=strategy_cogvideo,
|
1164 |
+
log_text_attention_weights=video_log_text_attention_weights,
|
1165 |
+
guider_seq=guider_seq2,
|
1166 |
+
guider_text_len=guider_text_len,
|
1167 |
+
guidance_alpha=args.guidance_alpha,
|
1168 |
+
limited_spatial_channel_mem=True,
|
1169 |
+
mode_stage1=True,
|
1170 |
+
)[0]
|
1171 |
+
)
|
1172 |
+
|
1173 |
+
output_tokens = torch.cat(output_list, dim=0)[:, 1 + text_len :]
|
1174 |
+
|
1175 |
+
if args.both_stages:
|
1176 |
+
move_start_time = time.time()
|
1177 |
+
logging.debug("moving stage 1 model to cpu")
|
1178 |
+
model = model.cpu()
|
1179 |
+
torch.cuda.empty_cache()
|
1180 |
+
logging.debug(
|
1181 |
+
"moving in model1 takes time: {:.2f}".format(
|
1182 |
+
time.time() - move_start_time
|
1183 |
+
)
|
1184 |
+
)
|
1185 |
+
|
1186 |
+
# decoding
|
1187 |
+
imgs, sred_imgs, txts = [], [], []
|
1188 |
+
for seq in output_tokens:
|
1189 |
+
decoded_imgs = [
|
1190 |
+
torch.nn.functional.interpolate(
|
1191 |
+
tokenizer.decode(image_ids=seq.tolist()[i * 400 : (i + 1) * 400]),
|
1192 |
+
size=(480, 480),
|
1193 |
+
)
|
1194 |
+
for i in range(total_frames)
|
1195 |
+
]
|
1196 |
+
imgs.append(decoded_imgs) # only the last image (target)
|
1197 |
+
|
1198 |
+
assert len(imgs) == batch_size
|
1199 |
+
save_tokens = (
|
1200 |
+
output_tokens[:, : +total_frames * 400].reshape(-1, total_frames, 400).cpu()
|
1201 |
+
)
|
1202 |
+
if outputdir is not None:
|
1203 |
+
for clip_i in range(len(imgs)):
|
1204 |
+
# os.makedirs(output_dir_full_paths[clip_i], exist_ok=True)
|
1205 |
+
my_save_multiple_images(
|
1206 |
+
imgs[clip_i], outputdir, subdir=f"frames/{clip_i}", debug=False
|
1207 |
+
)
|
1208 |
+
os.system(
|
1209 |
+
f"gifmaker -i '{outputdir}'/frames/'{clip_i}'/0*.jpg -o '{outputdir}/{clip_i}.gif' -d 0.25"
|
1210 |
+
)
|
1211 |
+
torch.save(save_tokens, os.path.join(outputdir, "frame_tokens.pt"))
|
1212 |
+
|
1213 |
+
logging.info(
|
1214 |
+
"CogVideo Stage1 completed. Taken time {:.2f}\n".format(
|
1215 |
+
time.time() - process_start_time
|
1216 |
+
)
|
1217 |
+
)
|
1218 |
+
|
1219 |
+
return save_tokens
|
1220 |
+
|
1221 |
+
# ======================================================================================================
|
1222 |
+
|
1223 |
+
if args.stage_1 or args.both_stages:
|
1224 |
+
if args.input_source != "interactive":
|
1225 |
+
with open(args.input_source, "r") as fin:
|
1226 |
+
promptlist = fin.readlines()
|
1227 |
+
promptlist = [p.strip() for p in promptlist]
|
1228 |
+
else:
|
1229 |
+
promptlist = None
|
1230 |
+
|
1231 |
+
now_qi = -1
|
1232 |
+
while True:
|
1233 |
+
now_qi += 1
|
1234 |
+
|
1235 |
+
if promptlist is not None: # with input-source
|
1236 |
+
if args.multi_gpu:
|
1237 |
+
if now_qi % dist.get_world_size() != dist.get_rank():
|
1238 |
+
continue
|
1239 |
+
rk = dist.get_rank()
|
1240 |
+
else:
|
1241 |
+
rk = 0
|
1242 |
+
raw_text = promptlist[now_qi]
|
1243 |
+
raw_text = raw_text.strip()
|
1244 |
+
print(f"Working on Line No. {now_qi} on {rk}... [{raw_text}]")
|
1245 |
+
else: # interactive
|
1246 |
+
raw_text = input("\nPlease Input Query (stop to exit) >>> ")
|
1247 |
+
raw_text = raw_text.strip()
|
1248 |
+
if not raw_text:
|
1249 |
+
print("Query should not be empty!")
|
1250 |
+
continue
|
1251 |
+
if raw_text == "stop":
|
1252 |
+
return
|
1253 |
+
|
1254 |
+
try:
|
1255 |
+
path = os.path.join(args.output_path, f"{now_qi}_{raw_text}")
|
1256 |
+
parent_given_tokens = process_stage1(
|
1257 |
+
model_stage1,
|
1258 |
+
raw_text,
|
1259 |
+
duration=4.0,
|
1260 |
+
video_raw_text=raw_text,
|
1261 |
+
video_guidance_text="视频",
|
1262 |
+
image_text_suffix=" 高清摄影",
|
1263 |
+
outputdir=path if args.stage_1 else None,
|
1264 |
+
batch_size=args.batch_size,
|
1265 |
+
)
|
1266 |
+
if args.both_stages:
|
1267 |
+
process_stage2(
|
1268 |
+
model_stage2,
|
1269 |
+
raw_text,
|
1270 |
+
duration=2.0,
|
1271 |
+
video_raw_text=raw_text + " 视频",
|
1272 |
+
video_guidance_text="视频",
|
1273 |
+
parent_given_tokens=parent_given_tokens,
|
1274 |
+
outputdir=path,
|
1275 |
+
gpu_rank=0,
|
1276 |
+
gpu_parallel_size=1,
|
1277 |
+
) # TODO: 修改
|
1278 |
+
except (ValueError, FileNotFoundError) as e:
|
1279 |
+
print(e)
|
1280 |
+
continue
|
1281 |
+
|
1282 |
+
elif args.stage_2:
|
1283 |
+
sample_dirs = os.listdir(args.output_path)
|
1284 |
+
for sample in sample_dirs:
|
1285 |
+
raw_text = sample.split("_")[-1]
|
1286 |
+
path = os.path.join(args.output_path, sample, "Interp")
|
1287 |
+
parent_given_tokens = torch.load(
|
1288 |
+
os.path.join(args.output_path, sample, "frame_tokens.pt")
|
1289 |
+
)
|
1290 |
+
|
1291 |
+
process_stage2(
|
1292 |
+
raw_text,
|
1293 |
+
duration=2.0,
|
1294 |
+
video_raw_text=raw_text + " 视频",
|
1295 |
+
video_guidance_text="视频",
|
1296 |
+
parent_given_tokens=parent_given_tokens,
|
1297 |
+
outputdir=path,
|
1298 |
+
gpu_rank=0,
|
1299 |
+
gpu_parallel_size=1,
|
1300 |
+
) # TODO: 修改
|
1301 |
+
|
1302 |
+
else:
|
1303 |
+
assert False
|
1304 |
+
|
1305 |
+
|
1306 |
+
if __name__ == "__main__":
|
1307 |
+
logging.basicConfig(stream=sys.stderr, level=logging.DEBUG)
|
1308 |
+
|
1309 |
+
py_parser = argparse.ArgumentParser(add_help=False)
|
1310 |
+
py_parser.add_argument("--generate-frame-num", type=int, default=5)
|
1311 |
+
py_parser.add_argument("--coglm-temperature2", type=float, default=0.89)
|
1312 |
+
# py_parser.add_argument("--interp-duration", type=float, default=-1) # -1是顺序生成,0是超分,0.5/1/2是插帧
|
1313 |
+
# py_parser.add_argument("--total-duration", type=float, default=4.0) # 整个的时间
|
1314 |
+
py_parser.add_argument("--use-guidance-stage1", action="store_true")
|
1315 |
+
py_parser.add_argument("--use-guidance-stage2", action="store_true")
|
1316 |
+
py_parser.add_argument("--guidance-alpha", type=float, default=3.0)
|
1317 |
+
py_parser.add_argument(
|
1318 |
+
"--stage-1", action="store_true"
|
1319 |
+
) # stage 1: sequential generation
|
1320 |
+
py_parser.add_argument("--stage-2", action="store_true") # stage 2: interp + dsr
|
1321 |
+
py_parser.add_argument(
|
1322 |
+
"--both-stages", action="store_true"
|
1323 |
+
) # stage 1&2: sequential generation; interp + dsr
|
1324 |
+
py_parser.add_argument("--parallel-size", type=int, default=1)
|
1325 |
+
py_parser.add_argument(
|
1326 |
+
"--stage1-max-inference-batch-size", type=int, default=-1
|
1327 |
+
) # -1: use max-inference-batch-size
|
1328 |
+
py_parser.add_argument("--multi-gpu", action="store_true")
|
1329 |
+
|
1330 |
+
CogVideoCacheModel.add_model_specific_args(py_parser)
|
1331 |
+
|
1332 |
+
known, args_list = py_parser.parse_known_args()
|
1333 |
+
args = get_args(args_list)
|
1334 |
+
args = argparse.Namespace(**vars(args), **vars(known))
|
1335 |
+
args.layout = [int(x) for x in args.layout.split(",")]
|
1336 |
+
args.do_train = False
|
1337 |
+
|
1338 |
+
torch.cuda.set_device(args.device)
|
1339 |
+
|
1340 |
+
with torch.no_grad():
|
1341 |
+
main(args)
|
src/videogen_hub/pipelines/cogvideo/cogvideo_src/models/__init__.py
ADDED
File without changes
|
src/videogen_hub/pipelines/cogvideo/cogvideo_src/models/cogvideo_cache_model.py
ADDED
@@ -0,0 +1,695 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- encoding: utf-8 -*-
|
2 |
+
'''
|
3 |
+
@File : cogvideo_cache_model.py
|
4 |
+
@Time : 2022/07/15 11:22:19
|
5 |
+
@Author : Wenyi Hong
|
6 |
+
@Version : 1.0
|
7 |
+
@Contact : [email protected]
|
8 |
+
'''
|
9 |
+
|
10 |
+
# here put the import lib
|
11 |
+
|
12 |
+
from multiprocessing import context
|
13 |
+
from tkinter import E
|
14 |
+
import torch
|
15 |
+
from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin
|
16 |
+
|
17 |
+
from SwissArmyTransformer.mpu.utils import split_tensor_along_last_dim
|
18 |
+
from SwissArmyTransformer.model.transformer import unscaled_init_method
|
19 |
+
from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear
|
20 |
+
import torch.nn.functional as F
|
21 |
+
from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker
|
22 |
+
import math
|
23 |
+
|
24 |
+
|
25 |
+
class PositionEmbeddingMixin(BaseMixin):
|
26 |
+
def __init__(self, additional_sequence_length, hidden_size,
|
27 |
+
init_method_std=0.02, reinit_slice=slice(512, 912),
|
28 |
+
):
|
29 |
+
super(PositionEmbeddingMixin, self).__init__()
|
30 |
+
self.reinit_slice = reinit_slice
|
31 |
+
self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size)
|
32 |
+
torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
|
33 |
+
|
34 |
+
def reinit(self, parent_model=None):
|
35 |
+
old_weights = self.transformer.position_embeddings.weight.data[self.reinit_slice]
|
36 |
+
old_len, hidden_size = old_weights.shape
|
37 |
+
assert hidden_size == self.position_embeddings.weight.shape[-1]
|
38 |
+
self.position_embeddings.weight.data.view(-1, old_len, hidden_size).copy_(old_weights)
|
39 |
+
|
40 |
+
|
41 |
+
def window_partition(x, window_size):
|
42 |
+
"""
|
43 |
+
Args:
|
44 |
+
x: (B, framenum, H, W, C)
|
45 |
+
window_size (int): window size
|
46 |
+
Returns:
|
47 |
+
windows: (num_windows*B, frame_num, window_size, window_size, C)
|
48 |
+
"""
|
49 |
+
B, framenum, H, W, C = x.shape
|
50 |
+
x = x.view(B, framenum, H // window_size, window_size, W // window_size, window_size, C)
|
51 |
+
windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(-1, framenum, window_size, window_size, C)
|
52 |
+
return windows
|
53 |
+
|
54 |
+
def window_reverse(windows, window_size, H, W):
|
55 |
+
"""
|
56 |
+
Args:
|
57 |
+
windows: (num_windows*B, frame_num, window_size, window_size, C)
|
58 |
+
window_size (int): Window size
|
59 |
+
H (int): Height of image
|
60 |
+
W (int): Width of image
|
61 |
+
Returns:
|
62 |
+
x: (B, frame_num, H, W, C)
|
63 |
+
"""
|
64 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
65 |
+
framenum = windows.shape[1]
|
66 |
+
x = windows.view(B, H // window_size, W // window_size, framenum, window_size, window_size, -1)
|
67 |
+
x = x.permute(0, 3, 1, 4, 2, 5, 6).contiguous().view(B, framenum, H, W, -1)
|
68 |
+
return x
|
69 |
+
|
70 |
+
class WindowAttentionMixin(BaseMixin):
|
71 |
+
def __init__(self, num_layers,
|
72 |
+
hidden_size,
|
73 |
+
frame_resolution,
|
74 |
+
window_size,
|
75 |
+
shift_size,
|
76 |
+
n_head,
|
77 |
+
frame_num,
|
78 |
+
init_method=unscaled_init_method(0.02),
|
79 |
+
output_layer_init_method=unscaled_init_method(0.02),
|
80 |
+
time_dim_attend_length=0
|
81 |
+
):
|
82 |
+
super(WindowAttentionMixin, self).__init__()
|
83 |
+
self.num_layers = num_layers # replace attention in the LAST n layers
|
84 |
+
self.query_key_value = torch.nn.ModuleList(
|
85 |
+
[ColumnParallelLinear(hidden_size, 3*hidden_size,stride=3,
|
86 |
+
gather_output=False,init_method=init_method)
|
87 |
+
for layer_id in range(num_layers)
|
88 |
+
])
|
89 |
+
self.dense = torch.nn.ModuleList(
|
90 |
+
[RowParallelLinear(
|
91 |
+
hidden_size,
|
92 |
+
hidden_size,
|
93 |
+
input_is_parallel=True,
|
94 |
+
init_method=output_layer_init_method,
|
95 |
+
bias=True,
|
96 |
+
module=self,
|
97 |
+
name="dense")
|
98 |
+
for layer_id in range(num_layers)
|
99 |
+
])
|
100 |
+
|
101 |
+
self.n_head = n_head
|
102 |
+
self.window_size = window_size
|
103 |
+
self.frame_resolution = frame_resolution
|
104 |
+
self.frame_len = frame_resolution * frame_resolution
|
105 |
+
self.time_dim_attend_length = time_dim_attend_length
|
106 |
+
assert frame_resolution % window_size == 0
|
107 |
+
assert 0 < shift_size < window_size
|
108 |
+
nW = (self.frame_resolution // self.window_size) ** 2
|
109 |
+
ws_squre = self.window_size * self.window_size
|
110 |
+
|
111 |
+
# odd non-shift, even shift
|
112 |
+
img_mask = torch.zeros((1, 1, frame_resolution, frame_resolution, 1))
|
113 |
+
h_slices = (slice(0, -shift_size),
|
114 |
+
slice(-shift_size, None))
|
115 |
+
w_slices = (slice(0, -shift_size),
|
116 |
+
slice(-shift_size, None))
|
117 |
+
cnt = 0
|
118 |
+
for h in h_slices:
|
119 |
+
for w in w_slices:
|
120 |
+
img_mask[:, :, h, w, :] = cnt
|
121 |
+
cnt += 1
|
122 |
+
mask_windows = window_partition(img_mask, self.window_size) # nW, 1, window_size, window_size, 1
|
123 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
124 |
+
sub_attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) #[nW, self.window_size * self.window_size, self.window_size * self.window_size]
|
125 |
+
sub_attn_mask = sub_attn_mask.masked_fill(sub_attn_mask != 0, float(0.0)).masked_fill(sub_attn_mask == 0, float(1.00))
|
126 |
+
attn_mask = sub_attn_mask.repeat(1, frame_num, frame_num)
|
127 |
+
attn_mask = attn_mask.tril()
|
128 |
+
|
129 |
+
causal_mask = torch.ones(ws_squre*frame_num, ws_squre*frame_num)
|
130 |
+
causal_mask = causal_mask.tril()
|
131 |
+
|
132 |
+
self.shift_sizes = [0, shift_size]
|
133 |
+
self.attn_mask = attn_mask
|
134 |
+
self.causal_mask = causal_mask
|
135 |
+
self.mask_initialized = False
|
136 |
+
|
137 |
+
self.attn_distribution = torch.nn.ParameterList([
|
138 |
+
torch.nn.Parameter(torch.zeros(hidden_size))
|
139 |
+
for _ in range(num_layers)
|
140 |
+
])
|
141 |
+
|
142 |
+
def reinit(self, *pre_mixins):
|
143 |
+
start_layer = len(self.transformer.layers) - self.num_layers
|
144 |
+
assert start_layer >= 0
|
145 |
+
for layer_id in range(self.num_layers):
|
146 |
+
old_attention = self.transformer.layers[start_layer + layer_id].attention
|
147 |
+
self.query_key_value[layer_id].weight.data.copy_(old_attention.query_key_value.weight.data)
|
148 |
+
self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data)
|
149 |
+
|
150 |
+
def attention_extra_NAR_inference(self, frame_hidden_state, layer_id, attn_dropout=None, memkv_text=None, stage=1):
|
151 |
+
# frame_hidden_state [batchsize, frame_num*frame_size, n_head*hiddensize_perhead]
|
152 |
+
if not self.mask_initialized:
|
153 |
+
self.attn_mask = self.attn_mask.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype)
|
154 |
+
self.causal_mask = self.causal_mask.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype)
|
155 |
+
self.mask_initialized = True
|
156 |
+
b0, s1, h0 = frame_hidden_state.shape
|
157 |
+
h = h0 // self.n_head
|
158 |
+
frame_len = self.frame_resolution * self.frame_resolution
|
159 |
+
frame_num = s1 // frame_len
|
160 |
+
if stage == 2:
|
161 |
+
assert frame_num == 3
|
162 |
+
assert frame_num*frame_len == s1
|
163 |
+
wind_square = self.window_size * self.window_size
|
164 |
+
nW = frame_len // wind_square
|
165 |
+
bswin = b0 * nW
|
166 |
+
|
167 |
+
if memkv_text is not None:
|
168 |
+
s0 = memkv_text.shape[-2]
|
169 |
+
k_text = memkv_text[..., :h0].expand(b0, -1, -1).reshape(b0, s0, self.n_head, h).permute(0, 2, 1, 3)
|
170 |
+
v_text = memkv_text[..., h0:].expand(b0, -1, -1).reshape(b0, s0, self.n_head, h).permute(0, 2, 1, 3)
|
171 |
+
|
172 |
+
# shift
|
173 |
+
frame_hidden_state = frame_hidden_state.reshape(b0, frame_num, self.frame_resolution, self.frame_resolution, h0)
|
174 |
+
if self.shift_sizes[layer_id%2] > 0:
|
175 |
+
frame_hidden_state = torch.roll(frame_hidden_state, shifts=(-self.shift_sizes[layer_id%2], -self.shift_sizes[layer_id%2]), dims=(2,3))
|
176 |
+
# window partition
|
177 |
+
frame_hidden_state = window_partition(frame_hidden_state, self.window_size).reshape(bswin, frame_num*wind_square, h0)
|
178 |
+
qkv = self.query_key_value[layer_id](frame_hidden_state).reshape(bswin, frame_num*wind_square, 3, self.n_head, h)\
|
179 |
+
.permute(2, 0, 3, 1, 4) #[3, bswin, n_head, frame_num*wind_size*wind_size, h]
|
180 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
181 |
+
attn = torch.matmul(q / math.sqrt(h), k.transpose(-1, -2))
|
182 |
+
|
183 |
+
if stage == 1:
|
184 |
+
if self.shift_sizes[layer_id%2] > 0:
|
185 |
+
attn = torch.mul(attn.view(bswin // nW, nW, self.n_head, frame_num*wind_square, frame_num*wind_square),
|
186 |
+
self.attn_mask[:,:frame_num*wind_square, :frame_num*wind_square].unsqueeze(1).unsqueeze(0))\
|
187 |
+
- 10000.0 * (1.0 - self.attn_mask[:,:frame_num*wind_square, :frame_num*wind_square].unsqueeze(1).unsqueeze(0))
|
188 |
+
attn = attn.view(bswin, self.n_head, frame_num*wind_square, frame_num*wind_square)
|
189 |
+
else:
|
190 |
+
attn = torch.mul(attn, self.causal_mask[:frame_num*wind_square, :frame_num*wind_square].unsqueeze(0).unsqueeze(0))\
|
191 |
+
- 10000.0 * (1.0 - self.causal_mask[:frame_num*wind_square, :frame_num*wind_square].unsqueeze(0).unsqueeze(0))
|
192 |
+
|
193 |
+
if memkv_text is None:
|
194 |
+
attn = F.softmax(attn, dim=-1)
|
195 |
+
if attn_dropout is not None:
|
196 |
+
with get_cuda_rng_tracker().fork():
|
197 |
+
attn = attn_dropout(attn)
|
198 |
+
context_swin = torch.matmul(attn, v).permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0)
|
199 |
+
else:
|
200 |
+
attn_frame2text = torch.matmul(q.reshape(b0, -1, self.n_head, frame_num*wind_square, h) / math.sqrt(h), k_text.unsqueeze(1).transpose(-1, -2))
|
201 |
+
attn_frame2text = attn_frame2text.reshape(bswin, self.n_head, frame_num*wind_square, s0)
|
202 |
+
attn = torch.cat((attn, attn_frame2text), dim=-1)
|
203 |
+
attn = F.softmax(attn, dim=-1)
|
204 |
+
|
205 |
+
if attn_dropout is not None:
|
206 |
+
with get_cuda_rng_tracker().fork():
|
207 |
+
attn = attn_dropout(attn)
|
208 |
+
|
209 |
+
context_swin = (torch.matmul(attn[..., :-s0], v) +
|
210 |
+
torch.matmul(attn[..., -s0:].reshape(b0, -1, self.n_head,frame_num*wind_square, s0), v_text.unsqueeze(1))\
|
211 |
+
.reshape(bswin, self.n_head, frame_num*wind_square, h))\
|
212 |
+
.permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0)
|
213 |
+
|
214 |
+
context_swin = window_reverse(context_swin, self.window_size, self.frame_resolution, self.frame_resolution)
|
215 |
+
|
216 |
+
# reverse cycle shift
|
217 |
+
if self.shift_sizes[layer_id%2] > 0:
|
218 |
+
context_swin = torch.roll(context_swin, shifts=(self.shift_sizes[layer_id%2], self.shift_sizes[layer_id%2]), dims=(2,3))
|
219 |
+
ret_context = context_swin.reshape(b0, s1, h0)
|
220 |
+
|
221 |
+
# for mem
|
222 |
+
memk = k.permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0)
|
223 |
+
memv = v.permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0)
|
224 |
+
memk = window_reverse(memk, self.window_size, self.frame_resolution, self.frame_resolution)
|
225 |
+
memv = window_reverse(memv, self.window_size, self.frame_resolution, self.frame_resolution)
|
226 |
+
if self.shift_sizes[layer_id%2] > 0:
|
227 |
+
memk = torch.roll(memk, shifts=(self.shift_sizes[layer_id%2], self.shift_sizes[layer_id%2]), dims=(2,3))
|
228 |
+
memv = torch.roll(memv, shifts=(self.shift_sizes[layer_id%2], self.shift_sizes[layer_id%2]), dims=(2,3))
|
229 |
+
memk, memv = memk.reshape(b0, s1, h0), memv.reshape(b0, s1, h0)
|
230 |
+
|
231 |
+
ret_mem = torch.cat((memk, memv), dim=-1)
|
232 |
+
return ret_context, ret_mem
|
233 |
+
|
234 |
+
def attention_extra_AR_inference(self, frame_hidden_state, memkv, pos, layer_id, log_text_attention_weights=0, attn_dropout=None, memkv_text=None, stage=1):
|
235 |
+
# frame_hidden_state [batchsize, 1, n_head*hiddensize_perhead]
|
236 |
+
# memkv [batchsize, pos, hidden_size*2] (include frames only)
|
237 |
+
# if memkv_text is not None: will attend to text
|
238 |
+
# pos: token's pos
|
239 |
+
b0, sin, h0 = frame_hidden_state.shape
|
240 |
+
h = h0 // self.n_head
|
241 |
+
assert sin == 1
|
242 |
+
this_qkv = self.query_key_value[layer_id](frame_hidden_state)
|
243 |
+
thisq, thisk, thisv = this_qkv[..., :h0], this_qkv[..., h0:2*h0], this_qkv[..., 2*h0:]
|
244 |
+
s1 = memkv.shape[1] if memkv is not None else 0
|
245 |
+
frame_len = self.frame_resolution * self.frame_resolution
|
246 |
+
frame_num_before = s1 // frame_len
|
247 |
+
|
248 |
+
|
249 |
+
if memkv is not None:
|
250 |
+
pos_inframe = pos - frame_num_before * frame_len
|
251 |
+
|
252 |
+
xpos = pos_inframe // self.frame_resolution # pos = xpos*self.frame_resolution + ypos
|
253 |
+
ypos = pos_inframe % self.frame_resolution
|
254 |
+
# [start, end)
|
255 |
+
if self.shift_sizes[layer_id%2] > 0:
|
256 |
+
xstart = ((xpos+self.shift_sizes[layer_id%2]) // self.window_size) * self.window_size - self.shift_sizes[layer_id%2]
|
257 |
+
ystart = ((ypos+self.shift_sizes[layer_id%2]) // self.window_size) * self.window_size - self.shift_sizes[layer_id%2]
|
258 |
+
xend = xstart + self.window_size
|
259 |
+
yend = ystart + self.window_size
|
260 |
+
xstart, ystart = max(0, xstart), max(0, ystart)
|
261 |
+
xend, yend = min(xend, self.frame_resolution), min(yend, self.frame_resolution)
|
262 |
+
else:
|
263 |
+
xstart = (xpos // self.window_size) * self.window_size
|
264 |
+
ystart = (ypos // self.window_size) * self.window_size
|
265 |
+
xend, yend = xstart + self.window_size, ystart+self.window_size
|
266 |
+
|
267 |
+
# select index
|
268 |
+
selected_index = list()
|
269 |
+
if frame_num_before > 0:
|
270 |
+
# frames before
|
271 |
+
frame_attended_start = max(0, frame_num_before-self.time_dim_attend_length+1) if self.time_dim_attend_length > 0 else 0
|
272 |
+
for x in range(xstart, xend):
|
273 |
+
for y in range(ystart, yend):
|
274 |
+
selected_index.append(x*self.frame_resolution+y+frame_len*frame_attended_start)
|
275 |
+
cnt_per_frame = len(selected_index)
|
276 |
+
for _ in range((frame_num_before-frame_attended_start-1)*cnt_per_frame):
|
277 |
+
selected_index.append(selected_index[-cnt_per_frame]+frame_len)
|
278 |
+
|
279 |
+
# the last frame
|
280 |
+
for x in range(xstart, xend):
|
281 |
+
for y in range(ystart, yend):
|
282 |
+
tmppos = x*self.frame_resolution+y + frame_num_before * frame_len
|
283 |
+
if tmppos < pos:
|
284 |
+
selected_index.append(tmppos)
|
285 |
+
else:
|
286 |
+
break
|
287 |
+
cnt_all = len(selected_index)+1
|
288 |
+
selected_index = torch.tensor(selected_index, device=memkv.device)
|
289 |
+
used_memkv = torch.index_select(memkv, 1, selected_index)
|
290 |
+
used_k, used_v = used_memkv[..., :h0], used_memkv[..., h0:]
|
291 |
+
used_k = torch.cat((used_k.expand(thisk.shape[0], -1, -1), thisk), dim=-2)
|
292 |
+
used_v = torch.cat((used_v.expand(thisv.shape[0], -1, -1), thisv), dim=-2)
|
293 |
+
if memkv_text is not None:
|
294 |
+
cnt_all += memkv_text.shape[-2]
|
295 |
+
used_k = torch.cat((memkv_text[..., :h0].expand(thisk.shape[0], -1, -1), used_k), dim=-2)
|
296 |
+
used_v = torch.cat((memkv_text[..., h0:].expand(thisv.shape[0], -1, -1), used_v), dim=-2)
|
297 |
+
used_k = used_k.reshape(b0, cnt_all, self.n_head, h).permute(0, 2, 1, 3)
|
298 |
+
used_v = used_v.reshape(b0, cnt_all, self.n_head, h).permute(0, 2, 1, 3)
|
299 |
+
else:
|
300 |
+
used_k = thisk
|
301 |
+
used_v = thisv
|
302 |
+
|
303 |
+
if memkv_text is not None:
|
304 |
+
used_k = torch.cat((memkv_text[..., :h0].expand(thisk.shape[0], -1, -1), used_k), dim=-2)
|
305 |
+
used_v = torch.cat((memkv_text[..., h0:].expand(thisv.shape[0], -1, -1), used_v), dim=-2)
|
306 |
+
used_k = used_k.reshape(b0, 1+memkv_text.shape[-2], self.n_head, h).permute(0, 2, 1, 3)
|
307 |
+
used_v = used_v.reshape(b0, 1+memkv_text.shape[-2], self.n_head, h).permute(0, 2, 1, 3)
|
308 |
+
else:
|
309 |
+
used_k = used_k.reshape(b0, 1, self.n_head, h).permute(0, 2, 1, 3)
|
310 |
+
used_v = used_v.reshape(b0, 1, self.n_head, h).permute(0, 2, 1, 3)
|
311 |
+
|
312 |
+
thisq = thisq.reshape(b0, 1, self.n_head, h).permute(0, 2, 1, 3) # [b0, n_head, 1, h]
|
313 |
+
attn = torch.matmul(thisq / math.sqrt(h), used_k.transpose(-1, -2))
|
314 |
+
if memkv_text is not None:
|
315 |
+
attn[..., :memkv_text.shape[-2]] += log_text_attention_weights
|
316 |
+
attn = F.softmax(attn, dim=-1)
|
317 |
+
context_swin = torch.matmul(attn, used_v).permute(0, 2, 1, 3).reshape(b0, 1, h0)
|
318 |
+
|
319 |
+
return context_swin, this_qkv[..., h0:]
|
320 |
+
|
321 |
+
class FullAttentionMixin(BaseMixin):
|
322 |
+
def __init__(self, num_layers,
|
323 |
+
hidden_size,
|
324 |
+
frame_resolution,
|
325 |
+
n_head,
|
326 |
+
frame_num,
|
327 |
+
init_method=unscaled_init_method(0.02),
|
328 |
+
output_layer_init_method=unscaled_init_method(0.02),
|
329 |
+
**kwargs,
|
330 |
+
):
|
331 |
+
super(FullAttentionMixin, self).__init__()
|
332 |
+
self.num_layers = num_layers # replace attention in the LAST n layers
|
333 |
+
self.query_key_value = torch.nn.ModuleList(
|
334 |
+
[ColumnParallelLinear(hidden_size, 3*hidden_size,stride=3,
|
335 |
+
gather_output=False,init_method=init_method)
|
336 |
+
for layer_id in range(num_layers)
|
337 |
+
])
|
338 |
+
self.dense = torch.nn.ModuleList(
|
339 |
+
[RowParallelLinear(
|
340 |
+
hidden_size,
|
341 |
+
hidden_size,
|
342 |
+
input_is_parallel=True,
|
343 |
+
init_method=output_layer_init_method,
|
344 |
+
bias=True,
|
345 |
+
module=self,
|
346 |
+
name="dense")
|
347 |
+
for layer_id in range(num_layers)
|
348 |
+
])
|
349 |
+
|
350 |
+
self.n_head = n_head
|
351 |
+
self.frame_resolution = frame_resolution
|
352 |
+
self.frame_len = frame_resolution * frame_resolution
|
353 |
+
|
354 |
+
self.attn_distribution = torch.nn.ParameterList([
|
355 |
+
torch.nn.Parameter(torch.zeros(hidden_size))
|
356 |
+
for _ in range(num_layers)
|
357 |
+
])
|
358 |
+
|
359 |
+
def reinit(self, *pre_mixins):
|
360 |
+
start_layer = len(self.transformer.layers) - self.num_layers
|
361 |
+
assert start_layer >= 0
|
362 |
+
for layer_id in range(self.num_layers):
|
363 |
+
old_attention = self.transformer.layers[start_layer + layer_id].attention
|
364 |
+
self.query_key_value[layer_id].weight.data.copy_(old_attention.query_key_value.weight.data)
|
365 |
+
self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data)
|
366 |
+
|
367 |
+
|
368 |
+
def attention_extra_NAR_inference(self, frame_hidden_state, layer_id, attn_dropout=None, memkv_text=None, stage=1):
|
369 |
+
# frame_hidden_state [batchsize, frame_num*frame_size, n_head*hiddensize_perhead]
|
370 |
+
assert stage == 1
|
371 |
+
|
372 |
+
b0, s1, h0 = frame_hidden_state.shape
|
373 |
+
h = h0 // self.n_head
|
374 |
+
frame_len = self.frame_resolution * self.frame_resolution
|
375 |
+
frame_num = s1 // frame_len
|
376 |
+
assert frame_num*frame_len == s1
|
377 |
+
|
378 |
+
if memkv_text is not None:
|
379 |
+
s0 = memkv_text.shape[-2]
|
380 |
+
k_text = memkv_text[..., :h0].expand(b0, -1, -1).reshape(b0, s0, self.n_head, h).permute(0, 2, 1, 3)
|
381 |
+
v_text = memkv_text[..., h0:].expand(b0, -1, -1).reshape(b0, s0, self.n_head, h).permute(0, 2, 1, 3)
|
382 |
+
qkv = self.query_key_value[layer_id](frame_hidden_state).reshape(b0, s1, 3, self.n_head, h)\
|
383 |
+
.permute(2, 0, 3, 1, 4) #[3, b0, n_head, s1, h]
|
384 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
385 |
+
attn = torch.matmul(q / math.sqrt(h), k.transpose(-1, -2))
|
386 |
+
attn = attn - 10000.0 * (1.0-torch.ones(b0, self.n_head, s1, s1, device=attn.device, dtype=attn.dtype).tril())
|
387 |
+
|
388 |
+
if memkv_text is None:
|
389 |
+
attn = F.softmax(attn, dim=-1)
|
390 |
+
if attn_dropout is not None:
|
391 |
+
with get_cuda_rng_tracker().fork():
|
392 |
+
attn = attn_dropout(attn)
|
393 |
+
context_swin = torch.matmul(attn, v).permute(0, 2, 1, 3).reshape(b0, s1, h0)
|
394 |
+
else:
|
395 |
+
attn_frame2text = torch.matmul(q / math.sqrt(h), k_text.transpose(-1, -2)) #[b0, s1, s0]
|
396 |
+
attn = torch.cat((attn, attn_frame2text), dim=-1)
|
397 |
+
attn = F.softmax(attn, dim=-1)
|
398 |
+
if attn_dropout is not None:
|
399 |
+
with get_cuda_rng_tracker().fork():
|
400 |
+
attn = attn_dropout(attn)
|
401 |
+
context_swin = (torch.matmul(attn[..., :-s0], v) + torch.matmul(attn[..., -s0:], v_text))\
|
402 |
+
.permute(0, 2, 1, 3).reshape(b0, s1, h0)
|
403 |
+
|
404 |
+
# for mem
|
405 |
+
memk = k.permute(0, 2, 1, 3).reshape(b0, s1, h0)
|
406 |
+
memv = v.permute(0, 2, 1, 3).reshape(b0, s1, h0)
|
407 |
+
ret_mem = torch.cat((memk, memv), dim=-1)
|
408 |
+
|
409 |
+
return context_swin, ret_mem
|
410 |
+
|
411 |
+
def attention_extra_AR_inference(self, frame_hidden_state, memkv, pos, layer_id, log_text_attention_weights=0, attn_dropout=None, memkv_text=None, stage=1):
|
412 |
+
# pos: current token's pos
|
413 |
+
b0, sin, h0 = frame_hidden_state.shape
|
414 |
+
h = h0 // self.n_head
|
415 |
+
assert sin == 1
|
416 |
+
assert stage == 1
|
417 |
+
|
418 |
+
this_qkv = self.query_key_value[layer_id](frame_hidden_state)
|
419 |
+
thisq, thisk, thisv = this_qkv[..., :h0], this_qkv[..., h0:2*h0], this_qkv[..., 2*h0:]
|
420 |
+
|
421 |
+
if memkv is not None:
|
422 |
+
used_k, used_v = memkv[..., :h0], memkv[..., h0:]
|
423 |
+
used_k = torch.cat((used_k.expand(thisk.shape[0], -1, -1), thisk), dim=-2)
|
424 |
+
used_v = torch.cat((used_v.expand(thisv.shape[0], -1, -1), thisv), dim=-2)
|
425 |
+
else:
|
426 |
+
used_k, used_v = thisk, thisv
|
427 |
+
|
428 |
+
if memkv_text is not None:
|
429 |
+
used_k = torch.cat((memkv_text[..., :h0].expand(thisk.shape[0], -1, -1), used_k), dim=-2)
|
430 |
+
used_v = torch.cat((memkv_text[..., h0:].expand(thisv.shape[0], -1, -1), used_v), dim=-2)
|
431 |
+
|
432 |
+
used_k = used_k.reshape(b0, -1, self.n_head, h).permute(0, 2, 1, 3)
|
433 |
+
used_v = used_v.reshape(b0, -1, self.n_head, h).permute(0, 2, 1, 3)
|
434 |
+
thisq = thisq.reshape(b0, 1, self.n_head, h).permute(0, 2, 1, 3) # [b0, n_head, 1, h]
|
435 |
+
attn = torch.matmul(thisq / math.sqrt(h), used_k.transpose(-1, -2))
|
436 |
+
if memkv_text is not None:
|
437 |
+
attn[..., :memkv_text.shape[-2]] += log_text_attention_weights
|
438 |
+
attn = F.softmax(attn, dim=-1)
|
439 |
+
|
440 |
+
context_swin = torch.matmul(attn, used_v).permute(0, 2, 1, 3).reshape(b0, 1, h0)
|
441 |
+
|
442 |
+
return context_swin, this_qkv[..., h0:]
|
443 |
+
|
444 |
+
|
445 |
+
def attention_localframe_and_text_NAR(q0, k0, v0, attention_mask,
|
446 |
+
n_head, text_len, frame_len, frame_num,
|
447 |
+
attention_dropout=None, log_text_attention_weights=0, stage=1, **kwargs):
|
448 |
+
b, s0, h0 = q0.shape
|
449 |
+
s1 = s0 - text_len
|
450 |
+
h = h0 // n_head
|
451 |
+
assert q0.shape[1] == v0.shape[1] == k0.shape[1] == text_len+frame_len*frame_num
|
452 |
+
# attention_mask.shape [4, b or 1, 1, text_len+frame_len, text_len+frame_len]
|
453 |
+
if stage == 2:
|
454 |
+
assert frame_num == 3
|
455 |
+
|
456 |
+
q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
|
457 |
+
v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
|
458 |
+
k0 = k0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
|
459 |
+
k0T = k0.transpose(-1, -2)
|
460 |
+
|
461 |
+
score_any2text = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T[..., :text_len])
|
462 |
+
score_any2text += log_text_attention_weights
|
463 |
+
score_any2text_part1 = torch.mul(score_any2text[..., :text_len, :], attention_mask[..., :text_len, :text_len]) \
|
464 |
+
- 10000.0 * (1.0 - attention_mask[..., :text_len, :text_len])
|
465 |
+
# context for text
|
466 |
+
attention_probs_text = F.softmax(score_any2text_part1, dim=-1)
|
467 |
+
if attention_dropout is not None:
|
468 |
+
with get_cuda_rng_tracker().fork():
|
469 |
+
attention_probs_text = attention_dropout(attention_probs_text)
|
470 |
+
context_text2text = torch.matmul(attention_probs_text, v0[..., :text_len, :])
|
471 |
+
context_text2text = context_text2text.transpose(1, 2).reshape(b, text_len, h0)
|
472 |
+
|
473 |
+
if frame_num > 0:
|
474 |
+
score_any2text_part2 = score_any2text[..., text_len:, :]
|
475 |
+
|
476 |
+
# score: frame local
|
477 |
+
q0_frame = q0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h)
|
478 |
+
v0_frame = v0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h)
|
479 |
+
k0T_frame = k0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h).transpose(-1, -2)
|
480 |
+
score_frame_local0 = torch.matmul(q0_frame / math.sqrt(q0_frame.shape[-1]), k0T_frame)
|
481 |
+
if stage == 1:
|
482 |
+
score_frame_local0 = torch.mul(score_frame_local0, attention_mask[..., text_len:, text_len:].unsqueeze(1)) \
|
483 |
+
- 10000.0 * (1.0 - attention_mask[..., text_len:, text_len:].unsqueeze(1))
|
484 |
+
|
485 |
+
# context for frame
|
486 |
+
score_frame_all = torch.cat((score_any2text_part2,
|
487 |
+
score_frame_local0.view(b, n_head, s1, frame_len)), dim=-1)
|
488 |
+
attention_probs_frame = F.softmax(score_frame_all, dim=-1)
|
489 |
+
if attention_dropout is not None:
|
490 |
+
with get_cuda_rng_tracker().fork():
|
491 |
+
attention_probs_frame = attention_dropout(attention_probs_frame)
|
492 |
+
context_frame2text = torch.matmul(attention_probs_frame[..., :text_len], v0[..., :text_len, :]) # [b, n_head, s1, h]
|
493 |
+
context_frame_local0 = torch.matmul(attention_probs_frame[..., text_len:text_len+frame_len].\
|
494 |
+
view(b, n_head, frame_num, frame_len, frame_len), v0_frame).view(b, n_head, s1, h)
|
495 |
+
|
496 |
+
context_frame = (context_frame2text + context_frame_local0).transpose(1, 2).reshape(b, s1, h0)
|
497 |
+
else:
|
498 |
+
context_frame = None
|
499 |
+
|
500 |
+
return context_text2text, context_frame
|
501 |
+
|
502 |
+
def attention_localframe_and_text_AR(q0, k0, v0, n_head, text_len, frame_len, frame_num,
|
503 |
+
attention_dropout=None, log_text_attention_weights=0, layer_id=None, limited_spatial_channel_mem=False, stage=1, **kwargs):
|
504 |
+
# limited_spatial_channel_mem=True means: mems in spatial channel is consisted of {mem_text, mem_current_frame}
|
505 |
+
b, s0, h0 = k0.shape
|
506 |
+
frame_num_before = (s0-text_len-1) // frame_len # frame_num == frame_num_before or frame_num == frame_num_before+1
|
507 |
+
h = h0 // n_head
|
508 |
+
assert q0.shape[1] == 1
|
509 |
+
assert v0.shape[1] == k0.shape[1]
|
510 |
+
|
511 |
+
q0 = q0.reshape(b, 1, n_head, h).permute(0, 2, 1, 3)
|
512 |
+
v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
|
513 |
+
k0T = k0.reshape(b, s0, n_head, h).permute(0, 2, 3, 1)
|
514 |
+
|
515 |
+
if limited_spatial_channel_mem:
|
516 |
+
assert frame_num_before == 0
|
517 |
+
assert stage == 1 # not implemented for stage-2 yet
|
518 |
+
score = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T)
|
519 |
+
score[..., :text_len] += log_text_attention_weights
|
520 |
+
attention_probs_frame = F.softmax(score, dim=-1)
|
521 |
+
context_frame = torch.matmul(attention_probs_frame, v0).transpose(1, 2).reshape(b, 1, h0)
|
522 |
+
|
523 |
+
else:
|
524 |
+
score_token2text = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T[..., :text_len])
|
525 |
+
score_token2text += log_text_attention_weights
|
526 |
+
score_frame_local0 = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T[..., text_len+frame_num_before*frame_len:])
|
527 |
+
score_frame_all = torch.cat((score_token2text,
|
528 |
+
score_frame_local0), dim=-1)
|
529 |
+
attention_probs_frame = F.softmax(score_frame_all, dim=-1)
|
530 |
+
|
531 |
+
context_token2text = torch.matmul(attention_probs_frame[..., :text_len], v0[..., :text_len, :]) # [b, n_head, s1, h]
|
532 |
+
context_frame_local0 = torch.matmul(attention_probs_frame[..., text_len:], \
|
533 |
+
v0[:, :, text_len+frame_num_before*frame_len:, :])
|
534 |
+
context_frame = (context_token2text + context_frame_local0).transpose(1, 2).reshape(b, 1, h0)
|
535 |
+
|
536 |
+
return context_frame
|
537 |
+
|
538 |
+
|
539 |
+
class CogVideoCacheModel(BaseModel):
|
540 |
+
def __init__(self, args, transformer=None, parallel_output=True, window_size=None, cogvideo_stage=None):
|
541 |
+
super().__init__(args, transformer=transformer, parallel_output=parallel_output)
|
542 |
+
self.layout = args.layout # [64, 64+1024, 64+6*1024]
|
543 |
+
self.stage = cogvideo_stage if cogvideo_stage is not None else args.cogvideo_stage # 1 or 2
|
544 |
+
self.n_head = args.num_attention_heads
|
545 |
+
self.window_size = window_size if window_size is not None else args.window_size
|
546 |
+
|
547 |
+
frame_resolution = int(math.sqrt(self.layout[1]-self.layout[0]))
|
548 |
+
self.add_mixin('extra_position_embedding', PositionEmbeddingMixin(
|
549 |
+
args.additional_seqlen, args.hidden_size
|
550 |
+
))
|
551 |
+
|
552 |
+
if self.stage == 1:
|
553 |
+
self.add_mixin('attention_plus', FullAttentionMixin(
|
554 |
+
num_layers=args.num_layers,
|
555 |
+
hidden_size=args.hidden_size,
|
556 |
+
frame_resolution=frame_resolution,
|
557 |
+
n_head=args.num_attention_heads,
|
558 |
+
frame_num=(args.layout[2]-args.layout[0])//(args.layout[1]-args.layout[0]),
|
559 |
+
))
|
560 |
+
else:
|
561 |
+
self.add_mixin('attention_plus', WindowAttentionMixin(
|
562 |
+
num_layers=args.num_layers,
|
563 |
+
hidden_size=args.hidden_size,
|
564 |
+
frame_resolution=frame_resolution,
|
565 |
+
window_size=self.window_size,
|
566 |
+
shift_size=self.window_size//2,
|
567 |
+
n_head=args.num_attention_heads,
|
568 |
+
frame_num=(args.layout[2]-args.layout[0])//(args.layout[1]-args.layout[0]),
|
569 |
+
))
|
570 |
+
|
571 |
+
|
572 |
+
@classmethod
|
573 |
+
def add_model_specific_args(cls, parser):
|
574 |
+
group = parser.add_argument_group('VideoSwinLocalModel', 'video swin local model configurations')
|
575 |
+
group.add_argument("--layout", type=str, default='64, 464, 2064')
|
576 |
+
group.add_argument("--window-size", type=int, default=10) # 优先级在直接参数赋值之后
|
577 |
+
group.add_argument("--additional-seqlen", type=int, default=2000)
|
578 |
+
group.add_argument("--cogvideo-stage", type=int, default=1, choices=[1,2]) # 优先级在直接参数赋值之后
|
579 |
+
return parser
|
580 |
+
|
581 |
+
def disable_untrainable_params(self):
|
582 |
+
pass
|
583 |
+
|
584 |
+
def position_embedding_forward(self, position_ids, **kw_args):
|
585 |
+
if position_ids.shape[-1] > 1:
|
586 |
+
if self.stage == 1:
|
587 |
+
if position_ids[0,-1] >= (512+400):
|
588 |
+
frame_num = position_ids.shape[-1] // 400
|
589 |
+
position_embeddings = torch.cat(
|
590 |
+
(
|
591 |
+
self.transformer.position_embeddings(position_ids[..., :-400*(frame_num-1)]),
|
592 |
+
self.get_mixin('extra_position_embedding').position_embeddings(position_ids[..., -400*(frame_num-1):]-(512+400))
|
593 |
+
),
|
594 |
+
dim=-2
|
595 |
+
)
|
596 |
+
else:
|
597 |
+
position_embeddings = self.transformer.position_embeddings(position_ids)
|
598 |
+
else:
|
599 |
+
# given 3, interpolate 2
|
600 |
+
position_embeddings = torch.cat(
|
601 |
+
(
|
602 |
+
self.transformer.position_embeddings(position_ids[..., :-800]),
|
603 |
+
self.get_mixin('extra_position_embedding').position_embeddings(position_ids[..., -800:]-(512+400))
|
604 |
+
),
|
605 |
+
dim=-2
|
606 |
+
)
|
607 |
+
else:
|
608 |
+
if position_ids[0, 0] >= (512+400):
|
609 |
+
position_embeddings = self.get_mixin('extra_position_embedding').position_embeddings(position_ids-(512+400))
|
610 |
+
else:
|
611 |
+
position_embeddings = self.transformer.position_embeddings(position_ids)
|
612 |
+
return position_embeddings
|
613 |
+
|
614 |
+
def attention_forward(self, hidden_states, mask, layer_id, mems=None, log_text_attention_weights=0, text_len=0, frame_len=0, counter=0, enforce_no_swin=False, limited_spatial_channel_mem=False, **kw_args):
|
615 |
+
attn_module = self.transformer.layers[layer_id].attention
|
616 |
+
hidden_size = hidden_states.shape[-1]
|
617 |
+
|
618 |
+
# base model qkv
|
619 |
+
if mems is None:
|
620 |
+
mixed_raw_layer = attn_module.query_key_value(hidden_states)
|
621 |
+
q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer, 3)
|
622 |
+
assert (q0.shape[1]-text_len) % frame_len == 0
|
623 |
+
memkv0 = torch.cat((k0, v0), dim=-1)
|
624 |
+
context_text, context_frame_local_text = attention_localframe_and_text_NAR(
|
625 |
+
q0, k0, v0,
|
626 |
+
mask,
|
627 |
+
n_head=attn_module.num_attention_heads_per_partition,
|
628 |
+
text_len=text_len,
|
629 |
+
frame_len=frame_len,
|
630 |
+
frame_num=(q0.shape[1]-text_len)//frame_len,
|
631 |
+
log_text_attention_weights=log_text_attention_weights,
|
632 |
+
stage=self.stage
|
633 |
+
)
|
634 |
+
|
635 |
+
# change: self.swin_attend_to_text默认为True:
|
636 |
+
memkv1_text = self.get_mixin('attention_plus').query_key_value[layer_id](hidden_states[..., :text_len, :])[..., hidden_size:]
|
637 |
+
output_text = attn_module.dense(context_text)
|
638 |
+
|
639 |
+
if (q0.shape[1]-text_len)//frame_len > 0:
|
640 |
+
assert (q0.shape[1]-text_len) % frame_len == 0
|
641 |
+
context_frame_swin, memkv1_frame = self.get_mixin('attention_plus').attention_extra_NAR_inference(
|
642 |
+
hidden_states[:,text_len:], layer_id, memkv_text=memkv1_text, stage=self.stage)
|
643 |
+
if not enforce_no_swin:
|
644 |
+
attn_distrib = torch.sigmoid(self.get_mixin('attention_plus').attn_distribution[layer_id])
|
645 |
+
attn_distrib = attn_distrib.unsqueeze(0).unsqueeze(0)
|
646 |
+
output_frame = torch.mul(attn_module.dense(context_frame_local_text), attn_distrib)\
|
647 |
+
+torch.mul(self.get_mixin('attention_plus').dense[layer_id](context_frame_swin), 1-attn_distrib)
|
648 |
+
else:
|
649 |
+
output_frame = attn_module.dense(context_frame_local_text[..., :frame_len, :])
|
650 |
+
output = torch.cat((output_text, output_frame), dim=-2)
|
651 |
+
memkv1 = torch.cat((memkv1_text, memkv1_frame), dim=-2) if memkv1_text is not None else memkv1_frame
|
652 |
+
else:
|
653 |
+
output = output_text
|
654 |
+
memkv1 = memkv1_text
|
655 |
+
kw_args['output_this_layer']['mem_kv'] = (memkv0, memkv1)
|
656 |
+
|
657 |
+
|
658 |
+
else:
|
659 |
+
mixed_raw_layer = attn_module.query_key_value(hidden_states)
|
660 |
+
q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer, 3)
|
661 |
+
new_memkv0 = torch.cat((k0, v0), dim=-1)
|
662 |
+
old_k0, old_v0 = mems[0][layer_id][..., :hidden_size], mems[0][layer_id][..., hidden_size:]
|
663 |
+
|
664 |
+
context_frame_local_text = attention_localframe_and_text_AR(
|
665 |
+
q0,
|
666 |
+
torch.cat((old_k0.expand(k0.shape[0], -1, -1), k0), dim=-2),
|
667 |
+
torch.cat((old_v0.expand(v0.shape[0], -1, -1), v0), dim=-2),
|
668 |
+
n_head=attn_module.num_attention_heads_per_partition,
|
669 |
+
text_len=text_len,
|
670 |
+
frame_len=frame_len,
|
671 |
+
frame_num=None,
|
672 |
+
log_text_attention_weights=log_text_attention_weights,
|
673 |
+
layer_id=layer_id,
|
674 |
+
limited_spatial_channel_mem=limited_spatial_channel_mem,
|
675 |
+
)
|
676 |
+
|
677 |
+
old_memkv1 = mems[1][layer_id] if mems[1] is not None else None
|
678 |
+
|
679 |
+
context_frame_swin, new_memkv1 = self.get_mixin('attention_plus').attention_extra_AR_inference(hidden_states,
|
680 |
+
old_memkv1[..., text_len:, :] if old_memkv1.shape[-2]>text_len else None,
|
681 |
+
counter-text_len,
|
682 |
+
layer_id,
|
683 |
+
memkv_text=old_memkv1[..., :text_len, :],
|
684 |
+
log_text_attention_weights=log_text_attention_weights)
|
685 |
+
if not enforce_no_swin:
|
686 |
+
attn_distrib = torch.sigmoid(self.get_mixin('attention_plus').attn_distribution[layer_id])
|
687 |
+
attn_distrib = attn_distrib.unsqueeze(0).unsqueeze(0)
|
688 |
+
output = torch.mul(attn_module.dense(context_frame_local_text), attn_distrib)\
|
689 |
+
+torch.mul(self.get_mixin('attention_plus').dense[layer_id](context_frame_swin), 1-attn_distrib)
|
690 |
+
else:
|
691 |
+
output = attn_module.dense(context_frame_local_text)
|
692 |
+
|
693 |
+
kw_args['output_this_layer']['mem_kv'] = (new_memkv0, new_memkv1)
|
694 |
+
|
695 |
+
return output
|
src/videogen_hub/pipelines/cogvideo/cogvideo_src/models/cogvideo_model.py
ADDED
@@ -0,0 +1,543 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- encoding: utf-8 -*-
|
2 |
+
'''
|
3 |
+
@File : cogvideo_model.py
|
4 |
+
@Time : 2022/07/11 16:12:05
|
5 |
+
@Author : Wenyi Hong
|
6 |
+
@Version : 1.0
|
7 |
+
@Contact : [email protected]
|
8 |
+
'''
|
9 |
+
|
10 |
+
# here put the import lib
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin
|
14 |
+
|
15 |
+
from SwissArmyTransformer.mpu.utils import split_tensor_along_last_dim
|
16 |
+
from SwissArmyTransformer.model.transformer import unscaled_init_method
|
17 |
+
from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear
|
18 |
+
import torch.nn.functional as F
|
19 |
+
from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker
|
20 |
+
import math
|
21 |
+
|
22 |
+
class PositionEmbeddingMixin(BaseMixin):
|
23 |
+
def __init__(self, additional_sequence_length, hidden_size,
|
24 |
+
init_method_std=0.02, reinit_slice=slice(512, 912),
|
25 |
+
):
|
26 |
+
super(PositionEmbeddingMixin, self).__init__()
|
27 |
+
self.reinit_slice = reinit_slice
|
28 |
+
self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size)
|
29 |
+
torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
|
30 |
+
|
31 |
+
def reinit(self, parent_model=None):
|
32 |
+
old_weights = self.transformer.position_embeddings.weight.data[self.reinit_slice]
|
33 |
+
old_len, hidden_size = old_weights.shape
|
34 |
+
assert hidden_size == self.position_embeddings.weight.shape[-1]
|
35 |
+
self.position_embeddings.weight.data.view(-1, old_len, hidden_size).copy_(old_weights)
|
36 |
+
|
37 |
+
def window_partition(x, window_size):
|
38 |
+
"""
|
39 |
+
Args:
|
40 |
+
x: (B, framenum, H, W, C)
|
41 |
+
window_size (int): window size
|
42 |
+
Returns:
|
43 |
+
windows: (num_windows*B, frame_num, window_size, window_size, C)
|
44 |
+
"""
|
45 |
+
B, framenum, H, W, C = x.shape
|
46 |
+
x = x.view(B, framenum, H // window_size, window_size, W // window_size, window_size, C)
|
47 |
+
windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(-1, framenum, window_size, window_size, C)
|
48 |
+
return windows
|
49 |
+
|
50 |
+
def window_reverse(windows, window_size, H, W):
|
51 |
+
"""
|
52 |
+
Args:
|
53 |
+
windows: (num_windows*B, frame_num, window_size, window_size, C)
|
54 |
+
window_size (int): Window size
|
55 |
+
H (int): Height of image
|
56 |
+
W (int): Width of image
|
57 |
+
Returns:
|
58 |
+
x: (B, frame_num, H, W, C)
|
59 |
+
"""
|
60 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
61 |
+
framenum = windows.shape[1]
|
62 |
+
x = windows.view(B, H // window_size, W // window_size, framenum, window_size, window_size, -1)
|
63 |
+
x = x.permute(0, 3, 1, 4, 2, 5, 6).contiguous().view(B, framenum, H, W, -1)
|
64 |
+
return x
|
65 |
+
|
66 |
+
class WindowAttentionMixin(BaseMixin):
|
67 |
+
def __init__(self, num_layers,
|
68 |
+
hidden_size,
|
69 |
+
frame_resolution,
|
70 |
+
window_size,
|
71 |
+
shift_size,
|
72 |
+
n_head,
|
73 |
+
frame_num,
|
74 |
+
init_method=unscaled_init_method(0.02),
|
75 |
+
output_layer_init_method=unscaled_init_method(0.02),
|
76 |
+
):
|
77 |
+
super(WindowAttentionMixin, self).__init__()
|
78 |
+
self.num_layers = num_layers # replace attention in the LAST n layers
|
79 |
+
self.query_key_value = torch.nn.ModuleList(
|
80 |
+
[ColumnParallelLinear(hidden_size, 3*hidden_size,stride=3,
|
81 |
+
gather_output=False,init_method=init_method)
|
82 |
+
for layer_id in range(num_layers)
|
83 |
+
])
|
84 |
+
self.dense = torch.nn.ModuleList(
|
85 |
+
[RowParallelLinear(
|
86 |
+
hidden_size,
|
87 |
+
hidden_size,
|
88 |
+
input_is_parallel=True,
|
89 |
+
init_method=output_layer_init_method,
|
90 |
+
bias=True,
|
91 |
+
module=self,
|
92 |
+
name="dense",
|
93 |
+
)
|
94 |
+
for layer_id in range(num_layers)
|
95 |
+
])
|
96 |
+
|
97 |
+
self.n_head = n_head
|
98 |
+
self.window_size = window_size
|
99 |
+
self.frame_resolution = frame_resolution
|
100 |
+
self.frame_len = frame_resolution * frame_resolution
|
101 |
+
assert frame_resolution % window_size == 0
|
102 |
+
assert 0 < shift_size < window_size
|
103 |
+
nW = (self.frame_resolution // self.window_size) ** 2
|
104 |
+
ws_squre = self.window_size * self.window_size
|
105 |
+
|
106 |
+
# odd non-shift, even shift
|
107 |
+
img_mask = torch.zeros((1, 1, frame_resolution, frame_resolution, 1))
|
108 |
+
h_slices = (slice(0, -shift_size),
|
109 |
+
slice(-shift_size, None))
|
110 |
+
w_slices = (slice(0, -shift_size),
|
111 |
+
slice(-shift_size, None))
|
112 |
+
cnt = 0
|
113 |
+
for h in h_slices:
|
114 |
+
for w in w_slices:
|
115 |
+
img_mask[:, :, h, w, :] = cnt
|
116 |
+
cnt += 1
|
117 |
+
mask_windows = window_partition(img_mask, self.window_size) # nW, 1, window_size, window_size, 1
|
118 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
119 |
+
sub_attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) #[nW, self.window_size * self.window_size, self.window_size * self.window_size]
|
120 |
+
sub_attn_mask = sub_attn_mask.masked_fill(sub_attn_mask != 0, float(0.0)).masked_fill(sub_attn_mask == 0, float(1.00))
|
121 |
+
attn_mask = sub_attn_mask.repeat(1, frame_num, frame_num)
|
122 |
+
|
123 |
+
self.attn_mask_sequential = attn_mask.clone().tril()
|
124 |
+
self.causal_mask_sequential = torch.ones(1, ws_squre*frame_num, ws_squre*frame_num).tril()
|
125 |
+
|
126 |
+
self.causal_mask_interp = torch.ones(1, ws_squre*frame_num, ws_squre*frame_num)
|
127 |
+
self.attn_mask_interp = attn_mask.clone()
|
128 |
+
|
129 |
+
# bi-dir
|
130 |
+
for bi_idx in range(0, frame_num, 2):
|
131 |
+
for uni_idx in range(1, frame_num, 2):
|
132 |
+
self.attn_mask_interp[:, bi_idx*ws_squre:(bi_idx+1)*ws_squre, uni_idx*ws_squre:(uni_idx+1)*ws_squre] = 0
|
133 |
+
self.causal_mask_interp[:, bi_idx*ws_squre:(bi_idx+1)*ws_squre, uni_idx*ws_squre:(uni_idx+1)*ws_squre] = 0
|
134 |
+
# uni-dir
|
135 |
+
for uni_idx in range(1, frame_num, 2):
|
136 |
+
self.attn_mask_interp[:, ws_squre*uni_idx:ws_squre*(uni_idx+1), ws_squre*uni_idx:ws_squre*(uni_idx+1)].tril_()
|
137 |
+
self.causal_mask_interp[:, ws_squre*uni_idx:ws_squre*(uni_idx+1), ws_squre*uni_idx:ws_squre*(uni_idx+1)].tril_()
|
138 |
+
for uni_idx2 in range(uni_idx+2, frame_num, 2):
|
139 |
+
self.attn_mask_interp[:, ws_squre*uni_idx:ws_squre*(uni_idx+1), ws_squre*uni_idx2:ws_squre*(uni_idx2+1)] = 0
|
140 |
+
self.causal_mask_interp[:, ws_squre*uni_idx:ws_squre*(uni_idx+1), ws_squre*uni_idx2:ws_squre*(uni_idx2+1)] = 0
|
141 |
+
|
142 |
+
# expand dim
|
143 |
+
self.attn_mask_sequential = self.attn_mask_sequential[None, None, :, None]
|
144 |
+
self.attn_mask_interp = self.attn_mask_interp[None, None, :, None]
|
145 |
+
self.causal_mask_sequential = self.causal_mask_sequential[None, None, :, None]
|
146 |
+
self.causal_mask_interp = self.causal_mask_interp[None, None, :, None]
|
147 |
+
|
148 |
+
self.shift_sizes = [0, shift_size]
|
149 |
+
# self.register_buffer("attn_mask", attn_mask)
|
150 |
+
# self.register_buffer("causal_mask", causal_mask)
|
151 |
+
self.mask_initialized = False
|
152 |
+
|
153 |
+
self.attn_distribution = torch.nn.ParameterList([
|
154 |
+
torch.nn.Parameter(torch.zeros(hidden_size))
|
155 |
+
for _ in range(num_layers)
|
156 |
+
])
|
157 |
+
|
158 |
+
def reinit(self, *pre_mixins):
|
159 |
+
start_layer = len(self.transformer.layers) - self.num_layers
|
160 |
+
assert start_layer >= 0
|
161 |
+
for layer_id in range(self.num_layers):
|
162 |
+
old_attention = self.transformer.layers[start_layer + layer_id].attention
|
163 |
+
self.query_key_value[layer_id].weight.data.copy_(old_attention.query_key_value.weight.data)
|
164 |
+
self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data)
|
165 |
+
|
166 |
+
def attention_extra(self, frame_hidden_state, layer_id, attn_dropout, text_hidden_state=None,
|
167 |
+
text_attn_mask=None, mode_sequential=True):
|
168 |
+
# pb relax
|
169 |
+
swin_pb_relax = True
|
170 |
+
alpha = 16
|
171 |
+
|
172 |
+
# frame_hidden_state [batchsize, frame_num*frame_size, n_head*hiddensize_perhead]
|
173 |
+
if not self.mask_initialized:
|
174 |
+
self.attn_mask_sequential = self.attn_mask_sequential.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype)
|
175 |
+
self.causal_mask_sequential = self.causal_mask_sequential.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype)
|
176 |
+
self.attn_mask_interp = self.attn_mask_interp.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype)
|
177 |
+
self.causal_mask_interp = self.causal_mask_interp.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype)
|
178 |
+
self.mask_initialized = True
|
179 |
+
b0, s1, h0 = frame_hidden_state.shape
|
180 |
+
h = h0 // self.n_head
|
181 |
+
frame_len = self.frame_resolution * self.frame_resolution
|
182 |
+
frame_num = s1 // frame_len
|
183 |
+
assert frame_num*frame_len == s1
|
184 |
+
wind_square = self.window_size * self.window_size
|
185 |
+
nW = frame_len // wind_square
|
186 |
+
bswin = b0 * nW
|
187 |
+
|
188 |
+
causal_mask = self.causal_mask_sequential if mode_sequential else self.causal_mask_interp
|
189 |
+
attn_mask = self.attn_mask_sequential if mode_sequential else self.attn_mask_interp
|
190 |
+
if text_hidden_state is not None:
|
191 |
+
s0 = text_hidden_state.shape[1]
|
192 |
+
qkv_text = self.query_key_value[layer_id](text_hidden_state).reshape(b0, s0, 3, self.n_head, h).permute(2, 0, 3, 1, 4) #[3, b0, n_head, s0, h]
|
193 |
+
q_text, k_text, v_text = qkv_text[0], qkv_text[1], qkv_text[2]
|
194 |
+
|
195 |
+
# shift
|
196 |
+
frame_hidden_state = frame_hidden_state.reshape(b0, frame_num, self.frame_resolution, self.frame_resolution, h0)
|
197 |
+
if self.shift_sizes[layer_id%2] > 0:
|
198 |
+
frame_hidden_state = torch.roll(frame_hidden_state, shifts=(-self.shift_sizes[layer_id%2], -self.shift_sizes[layer_id%2]), dims=(2,3))
|
199 |
+
# window partition
|
200 |
+
frame_hidden_state = window_partition(frame_hidden_state, self.window_size).reshape(bswin, frame_num*wind_square, h0)
|
201 |
+
qkv = self.query_key_value[layer_id](frame_hidden_state).reshape(bswin, frame_num*wind_square, 3, self.n_head, h)\
|
202 |
+
.permute(2, 0, 3, 1, 4) #[3, bswin, n_head, frame_num*wind_size*wind_size, h]
|
203 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
204 |
+
|
205 |
+
# pb-relax
|
206 |
+
if swin_pb_relax:
|
207 |
+
attn = torch.matmul(q / (math.sqrt(h)*alpha), k.transpose(-1, -2))
|
208 |
+
else:
|
209 |
+
attn = torch.matmul(q / math.sqrt(h), k.transpose(-1, -2))
|
210 |
+
|
211 |
+
if self.shift_sizes[layer_id%2] > 0:
|
212 |
+
# attn = attn.view(bswin // nW, nW, self.n_head, frame_num*wind_square, frame_num*wind_square) + self.attn_mask.unsqueeze(1).unsqueeze(0)
|
213 |
+
attn = torch.mul(attn.view(bswin // nW, nW, self.n_head, frame_num*wind_square, frame_num*wind_square), attn_mask)\
|
214 |
+
- 10000.0 * (1.0 - attn_mask)
|
215 |
+
attn = attn.view(bswin, self.n_head, frame_num*wind_square, frame_num*wind_square)
|
216 |
+
else:
|
217 |
+
attn = torch.mul(attn.view(bswin // nW, nW, self.n_head, frame_num*wind_square, frame_num*wind_square), causal_mask)\
|
218 |
+
- 10000.0 * (1.0 - causal_mask)
|
219 |
+
attn = attn.view(bswin, self.n_head, frame_num*wind_square, frame_num*wind_square)
|
220 |
+
if swin_pb_relax:
|
221 |
+
swin_pb_relax_const = torch.max(attn.reshape(bswin, self.n_head, -1), dim=-1, keepdim=True)[0].detach().unsqueeze(-1)
|
222 |
+
attn = (attn - swin_pb_relax_const)*alpha
|
223 |
+
|
224 |
+
if text_hidden_state is None:
|
225 |
+
attn = F.softmax(attn, dim=-1)
|
226 |
+
if attn_dropout is not None:
|
227 |
+
with get_cuda_rng_tracker().fork():
|
228 |
+
attn = attn_dropout(attn)
|
229 |
+
context_swin = torch.matmul(attn, v).permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0)
|
230 |
+
else:
|
231 |
+
assert text_attn_mask is not None
|
232 |
+
text_attn_mask = text_attn_mask.unsqueeze(2).unsqueeze(2)
|
233 |
+
# pb-relax
|
234 |
+
if swin_pb_relax:
|
235 |
+
attn_frame2text = torch.matmul(q.reshape(b0, -1, self.n_head, frame_num*wind_square, h) / (math.sqrt(h)*alpha), k_text.unsqueeze(1).transpose(-1, -2))
|
236 |
+
attn_frame2text = (attn_frame2text-swin_pb_relax_const.reshape(b0, -1, self.n_head, 1, 1))*alpha
|
237 |
+
else:
|
238 |
+
attn_frame2text = torch.matmul(q.reshape(b0, -1, self.n_head, frame_num*wind_square, h) / math.sqrt(h), k_text.unsqueeze(1).transpose(-1, -2))
|
239 |
+
|
240 |
+
attn_frame2text = torch.mul(text_attn_mask, attn_frame2text) - 10000.0 * (1.0 - text_attn_mask)
|
241 |
+
attn_frame2text = attn_frame2text.reshape(bswin, self.n_head, frame_num*wind_square, s0)
|
242 |
+
attn = torch.cat((attn, attn_frame2text), dim=-1)
|
243 |
+
attn = F.softmax(attn, dim=-1)
|
244 |
+
|
245 |
+
if attn_dropout is not None:
|
246 |
+
with get_cuda_rng_tracker().fork():
|
247 |
+
attn = attn_dropout(attn)
|
248 |
+
|
249 |
+
context_swin = (torch.matmul(attn[..., :-s0], v) +
|
250 |
+
torch.matmul(attn[..., -s0:].reshape(b0, -1, self.n_head,frame_num*wind_square, s0), v_text.unsqueeze(1))\
|
251 |
+
.reshape(bswin, self.n_head, frame_num*wind_square, h))\
|
252 |
+
.permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0)
|
253 |
+
|
254 |
+
context_swin = window_reverse(context_swin, self.window_size, self.frame_resolution, self.frame_resolution)
|
255 |
+
# reverse cycle shift
|
256 |
+
if self.shift_sizes[layer_id%2] > 0:
|
257 |
+
context_swin = torch.roll(context_swin, shifts=(self.shift_sizes[layer_id%2], self.shift_sizes[layer_id%2]), dims=(2,3))
|
258 |
+
context_swin = context_swin.reshape(b0, s1, h0)
|
259 |
+
|
260 |
+
return context_swin
|
261 |
+
|
262 |
+
|
263 |
+
class FullAttentionMixin(BaseMixin):
|
264 |
+
def __init__(self, num_layers,
|
265 |
+
hidden_size,
|
266 |
+
frame_resolution,
|
267 |
+
n_head,
|
268 |
+
frame_num,
|
269 |
+
init_method=unscaled_init_method(0.02),
|
270 |
+
output_layer_init_method=unscaled_init_method(0.02),
|
271 |
+
):
|
272 |
+
super(FullAttentionMixin, self).__init__()
|
273 |
+
self.num_layers = num_layers # replace attention in the LAST n layers
|
274 |
+
self.query_key_value = torch.nn.ModuleList(
|
275 |
+
[ColumnParallelLinear(hidden_size, 3*hidden_size,stride=3,
|
276 |
+
gather_output=False,init_method=init_method)
|
277 |
+
for layer_id in range(num_layers)
|
278 |
+
])
|
279 |
+
self.dense = torch.nn.ModuleList(
|
280 |
+
[RowParallelLinear(
|
281 |
+
hidden_size,
|
282 |
+
hidden_size,
|
283 |
+
input_is_parallel=True,
|
284 |
+
init_method=output_layer_init_method,
|
285 |
+
bias=True,
|
286 |
+
module=self,
|
287 |
+
name="dense",)
|
288 |
+
for layer_id in range(num_layers)
|
289 |
+
])
|
290 |
+
|
291 |
+
self.n_head = n_head
|
292 |
+
self.frame_resolution = frame_resolution
|
293 |
+
self.frame_len = frame_resolution * frame_resolution
|
294 |
+
self.causal_mask = torch.ones(1, 1, self.frame_len*frame_num, self.frame_len*frame_num).tril()
|
295 |
+
|
296 |
+
self.mask_initialized = False
|
297 |
+
|
298 |
+
self.attn_distribution = torch.nn.ParameterList([
|
299 |
+
torch.nn.Parameter(torch.zeros(hidden_size))
|
300 |
+
for _ in range(num_layers)
|
301 |
+
])
|
302 |
+
|
303 |
+
def reinit(self, *pre_mixins):
|
304 |
+
start_layer = len(self.transformer.layers) - self.num_layers
|
305 |
+
assert start_layer >= 0
|
306 |
+
for layer_id in range(self.num_layers):
|
307 |
+
base_attention = self.transformer.layers[start_layer + layer_id].attention
|
308 |
+
self.query_key_value[layer_id].weight.data.copy_(base_attention.query_key_value.weight.data)
|
309 |
+
self.query_key_value[layer_id].bias.data.copy_(base_attention.query_key_value.bias.data)
|
310 |
+
|
311 |
+
def attention_extra(self, frame_hidden_state, layer_id, attn_dropout, text_hidden_state=None,
|
312 |
+
text_attn_mask=None, mode_sequential=False):
|
313 |
+
# pb relax
|
314 |
+
# frame_hidden_state [batchsize, frame_num*frame_size, n_head*hiddensize_perhead]
|
315 |
+
assert mode_sequential == True # only
|
316 |
+
swin_pb_relax = True
|
317 |
+
alpha = 16
|
318 |
+
|
319 |
+
if not self.mask_initialized:
|
320 |
+
self.causal_mask = self.causal_mask.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype)
|
321 |
+
self.mask_initialized = True
|
322 |
+
b0, s1, h0 = frame_hidden_state.shape
|
323 |
+
h = h0 // self.n_head
|
324 |
+
frame_len = self.frame_resolution * self.frame_resolution
|
325 |
+
frame_num = s1 // frame_len
|
326 |
+
assert frame_num*frame_len == s1
|
327 |
+
|
328 |
+
qkv = self.query_key_value[layer_id](frame_hidden_state).reshape(b0, s1, 3, self.n_head, h)\
|
329 |
+
.permute(2, 0, 3, 1, 4) #[3, b0, n_head, s1, h]
|
330 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
331 |
+
|
332 |
+
# frames-to-frames
|
333 |
+
if swin_pb_relax:
|
334 |
+
attn = torch.matmul(q / (math.sqrt(h)*alpha), k.transpose(-1, -2))
|
335 |
+
else:
|
336 |
+
attn = torch.matmul(q / math.sqrt(h), k.transpose(-1, -2))
|
337 |
+
attn = torch.mul(attn, self.causal_mask) - 10000.0 * (1.0 - self.causal_mask)
|
338 |
+
if swin_pb_relax:
|
339 |
+
swin_pb_relax_const = torch.max(attn.reshape(b0, self.n_head, -1), dim=-1, keepdim=True)[0].detach().unsqueeze(-1)
|
340 |
+
attn = (attn - swin_pb_relax_const)*alpha
|
341 |
+
|
342 |
+
if text_hidden_state is None:
|
343 |
+
attn = F.softmax(attn, dim=-1)
|
344 |
+
if attn_dropout is not None:
|
345 |
+
with get_cuda_rng_tracker().fork():
|
346 |
+
attn = attn_dropout(attn)
|
347 |
+
context_swin = torch.matmul(attn, v).permute(0, 2, 1, 3).reshape(b0, s1, h0)
|
348 |
+
else:
|
349 |
+
# frame-to-text
|
350 |
+
assert text_attn_mask is not None
|
351 |
+
s0 = text_hidden_state.shape[1]
|
352 |
+
qkv_text = self.query_key_value[layer_id](text_hidden_state).reshape(b0, s0, 3, self.n_head, h).permute(2, 0, 3, 1, 4) #[3, b0, n_head, s0, h]
|
353 |
+
q_text, k_text, v_text = qkv_text[0], qkv_text[1], qkv_text[2]
|
354 |
+
text_attn_mask = text_attn_mask.unsqueeze(2)
|
355 |
+
if swin_pb_relax:
|
356 |
+
attn_frame2text = torch.matmul(q.reshape(b0, self.n_head, s1, h) / (math.sqrt(h)*alpha), k_text.transpose(-1, -2))
|
357 |
+
attn_frame2text = (attn_frame2text-swin_pb_relax_const.reshape(b0, self.n_head, 1, 1))*alpha
|
358 |
+
else:
|
359 |
+
attn_frame2text = torch.matmul(q.reshape(b0, self.n_head, s1, h) / math.sqrt(h), k_text.transpose(-1, -2))
|
360 |
+
attn_frame2text = torch.mul(text_attn_mask, attn_frame2text) - 10000.0 * (1.0 - text_attn_mask)
|
361 |
+
attn_frame2text = attn_frame2text.reshape(b0, self.n_head, s1, s0)
|
362 |
+
|
363 |
+
attn = torch.cat((attn, attn_frame2text), dim=-1)
|
364 |
+
attn = F.softmax(attn, dim=-1)
|
365 |
+
|
366 |
+
if attn_dropout is not None:
|
367 |
+
with get_cuda_rng_tracker().fork():
|
368 |
+
attn = attn_dropout(attn)
|
369 |
+
|
370 |
+
context_frame = (torch.matmul(attn[..., :-s0], v) +
|
371 |
+
torch.matmul(attn[..., -s0:].reshape(b0, self.n_head,s1, s0), v_text))\
|
372 |
+
.permute(0, 2, 1, 3).reshape(b0, s1, h0)
|
373 |
+
|
374 |
+
return context_frame
|
375 |
+
|
376 |
+
|
377 |
+
def attention_localframe_and_text(q0, k0, v0, attention_mask_totxt, attention_mask_local,
|
378 |
+
n_head, text_len, frame_len, frame_num, attention_dropout=None, layer_id=0, **kwargs):
|
379 |
+
b, s0, h0 = q0.shape
|
380 |
+
s1 = s0 - text_len
|
381 |
+
h = h0 // n_head
|
382 |
+
assert q0.shape[1] == v0.shape[1] == k0.shape[1] == text_len+frame_len*frame_num
|
383 |
+
# attention_mask_totxt [b, 1, 1, text_len]
|
384 |
+
# attention_mask_local [1, 1, frame_num, frame_len, frame_len]
|
385 |
+
# attention_mask: [1, 1, text_len+frame_len, text_len+frame_len]
|
386 |
+
|
387 |
+
q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
|
388 |
+
v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
|
389 |
+
k0 = k0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
|
390 |
+
k0T = k0.transpose(-1, -2)
|
391 |
+
|
392 |
+
# score: any2text
|
393 |
+
score_any2text = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T[..., :text_len])
|
394 |
+
score_any2text_part1 = torch.mul(score_any2text[..., :text_len, :], attention_mask_totxt) \
|
395 |
+
- 10000.0 * (1.0 - attention_mask_totxt)
|
396 |
+
score_any2text_part2 = torch.mul(score_any2text[..., text_len:, :], attention_mask_totxt) - \
|
397 |
+
10000.0 * (1.0 - attention_mask_totxt)
|
398 |
+
|
399 |
+
# score: frame local
|
400 |
+
q0_frame = q0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h)
|
401 |
+
v0_frame = v0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h)
|
402 |
+
k0T_frame = k0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h).transpose(-1, -2)
|
403 |
+
score_frame_local0 = torch.matmul(q0_frame / math.sqrt(q0_frame.shape[-1]), k0T_frame)
|
404 |
+
score_frame_local0 = torch.mul(score_frame_local0, attention_mask_local) \
|
405 |
+
- 10000.0 * (1.0 - attention_mask_local)
|
406 |
+
|
407 |
+
# context for frame
|
408 |
+
score_frame_all = torch.cat((score_any2text_part2,
|
409 |
+
score_frame_local0.view(b, n_head, s1, frame_len)), dim=-1)
|
410 |
+
attention_probs_frame = F.softmax(score_frame_all, dim=-1)
|
411 |
+
|
412 |
+
if attention_dropout is not None:
|
413 |
+
with get_cuda_rng_tracker().fork():
|
414 |
+
attention_probs_frame = attention_dropout(attention_probs_frame)
|
415 |
+
|
416 |
+
context_frame2text = torch.matmul(attention_probs_frame[..., :text_len], v0[..., :text_len, :]) # [b, n_head, s1, h]
|
417 |
+
context_frame_local0 = torch.matmul(attention_probs_frame[..., text_len:text_len+frame_len].\
|
418 |
+
view(b, n_head, frame_num, frame_len, frame_len), v0_frame).view(b, n_head, s1, h)
|
419 |
+
context_frame = (context_frame2text + context_frame_local0).transpose(1, 2).reshape(b, s1, h0)
|
420 |
+
|
421 |
+
# context for text
|
422 |
+
attention_probs_text = F.softmax(score_any2text_part1, dim=-1)
|
423 |
+
if attention_dropout is not None:
|
424 |
+
with get_cuda_rng_tracker().fork():
|
425 |
+
attention_probs_text = attention_dropout(attention_probs_text)
|
426 |
+
context_text2text = torch.matmul(attention_probs_text, v0[..., :text_len, :])
|
427 |
+
context_text2text = context_text2text.transpose(1, 2).reshape(b, text_len, h0)
|
428 |
+
|
429 |
+
return context_text2text, context_frame
|
430 |
+
|
431 |
+
|
432 |
+
class CogVideoModel(BaseModel):
|
433 |
+
def __init__(self, args, transformer=None, parallel_output=True):
|
434 |
+
super().__init__(args, transformer=transformer, parallel_output=parallel_output)
|
435 |
+
self.stage = args.cogvideo_stage # 1 or 2
|
436 |
+
self.mode_sequential = True if self.stage==1 else False
|
437 |
+
self.layout = args.layout # [64, 64+400, 64+5*400]
|
438 |
+
self.n_head = args.num_attention_heads
|
439 |
+
frame_resolution = int(math.sqrt(self.layout[1]-self.layout[0]))
|
440 |
+
frame_num = (args.layout[2]-args.layout[0])//(args.layout[1]-args.layout[0])
|
441 |
+
frame_len = self.layout[1]-self.layout[0]
|
442 |
+
|
443 |
+
self.add_mixin('extra_position_embedding', PositionEmbeddingMixin(
|
444 |
+
args.additional_seqlen, args.hidden_size
|
445 |
+
))
|
446 |
+
|
447 |
+
if args.window_size == -1:
|
448 |
+
# full attention
|
449 |
+
assert self.stage == 1
|
450 |
+
self.add_mixin('attention_plus', FullAttentionMixin(
|
451 |
+
num_layers=args.num_layers,
|
452 |
+
hidden_size=args.hidden_size,
|
453 |
+
frame_resolution=frame_resolution,
|
454 |
+
n_head=args.num_attention_heads,
|
455 |
+
frame_num=frame_num,
|
456 |
+
))
|
457 |
+
else:
|
458 |
+
self.add_mixin('attention_plus', WindowAttentionMixin(
|
459 |
+
num_layers=args.num_layers,
|
460 |
+
hidden_size=args.hidden_size,
|
461 |
+
frame_resolution=frame_resolution,
|
462 |
+
window_size=args.window_size,
|
463 |
+
shift_size=args.window_size//2,
|
464 |
+
n_head=args.num_attention_heads,
|
465 |
+
frame_num=frame_num,
|
466 |
+
))
|
467 |
+
# attention_mask_local
|
468 |
+
self.attention_mask_local_sequential = torch.ones(1, 1, frame_num, frame_len, frame_len).tril().unsqueeze(0)
|
469 |
+
self.attention_mask_local_interp = torch.ones(1, 1, frame_num, frame_len, frame_len)
|
470 |
+
|
471 |
+
for idx in range(1, frame_num, 2):
|
472 |
+
self.attention_mask_local_interp[:, :, idx:idx+1].tril_()
|
473 |
+
self.attention_mask_local_interp = self.attention_mask_local_interp.unsqueeze(0)
|
474 |
+
self.mask_initialized = False
|
475 |
+
|
476 |
+
@classmethod
|
477 |
+
def add_model_specific_args(cls, parser):
|
478 |
+
group = parser.add_argument_group('CogVideoModel', 'CogVideo model configurations')
|
479 |
+
group.add_argument("--layout", type=str, default='64, 464, 2064', help='text_len, textlen+frame_len, textlen+frame_len*frame_num')
|
480 |
+
group.add_argument("--window-size", type=int, default=10, help="swin attention's window size in temperal channel, -1 represents full attention")
|
481 |
+
group.add_argument("--additional-seqlen", type=int, default=2000)
|
482 |
+
group.add_argument("--cogvideo-stage", type=int, default=1, choices=[1,2])
|
483 |
+
return parser
|
484 |
+
|
485 |
+
def disable_untrainable_params(self):
|
486 |
+
self.transformer.requires_grad_(False)
|
487 |
+
|
488 |
+
def position_embedding_forward(self, position_ids, **kw_args):
|
489 |
+
position = position_ids[..., :(64+400)]
|
490 |
+
position_plus = position_ids[..., (64+400):]
|
491 |
+
position_embeddings = torch.cat(
|
492 |
+
(
|
493 |
+
self.transformer.position_embeddings(position),
|
494 |
+
self.get_mixin('extra_position_embedding').position_embeddings(position_plus-(512+400))
|
495 |
+
),
|
496 |
+
dim=-2
|
497 |
+
)
|
498 |
+
return position_embeddings
|
499 |
+
|
500 |
+
def attention_forward(self, hidden_states, mask, layer_id, **kw_args):
|
501 |
+
# mask.shape=[bs, 1, 1, 64]
|
502 |
+
if not self.mask_initialized:
|
503 |
+
self.attention_mask_local_sequential = self.attention_mask_local_sequential.to(device=hidden_states.device, dtype=hidden_states.dtype)
|
504 |
+
self.attention_mask_local_interp = self.attention_mask_local_interp.to(device=hidden_states.device, dtype=hidden_states.dtype)
|
505 |
+
self.mask_initialized = True
|
506 |
+
|
507 |
+
attn_module = self.transformer.layers[layer_id].attention
|
508 |
+
hidden_size = hidden_states.shape[-1]
|
509 |
+
bs = hidden_states.shape[0]
|
510 |
+
|
511 |
+
# base model qkv
|
512 |
+
mixed_raw_layer = attn_module.query_key_value(hidden_states)
|
513 |
+
q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer, 3)
|
514 |
+
dropout_fn = self.transformer.layers[layer_id].attention.attention_dropout if self.training else None
|
515 |
+
|
516 |
+
attention_mask_local = self.attention_mask_local_sequential if self.mode_sequential else self.attention_mask_local_interp
|
517 |
+
context_text, context_frame_local_text = attention_localframe_and_text(
|
518 |
+
q0, k0, v0,
|
519 |
+
attention_mask_totxt=mask,
|
520 |
+
attention_mask_local=attention_mask_local,
|
521 |
+
n_head=attn_module.num_attention_heads_per_partition,
|
522 |
+
text_len=self.layout[0],
|
523 |
+
frame_len=self.layout[1]-self.layout[0],
|
524 |
+
frame_num=(self.layout[2]-self.layout[0])//(self.layout[1]-self.layout[0]),
|
525 |
+
attention_dropout=dropout_fn,
|
526 |
+
layer_id=layer_id,
|
527 |
+
)
|
528 |
+
|
529 |
+
context_frame_swin = self.get_mixin('attention_plus').attention_extra(
|
530 |
+
hidden_states[:, self.layout[0]:], layer_id, dropout_fn,
|
531 |
+
text_hidden_state=hidden_states[:, :self.layout[0]],
|
532 |
+
text_attn_mask=mask[..., 0, :],
|
533 |
+
mode_sequential=self.mode_sequential)
|
534 |
+
|
535 |
+
attn_distrib = torch.sigmoid(self.get_mixin('attention_plus').attn_distribution[layer_id])
|
536 |
+
attn_distrib = attn_distrib.unsqueeze(0).unsqueeze(0)
|
537 |
+
|
538 |
+
output_text = attn_module.dense(context_text)
|
539 |
+
output_frame = torch.mul(attn_module.dense(context_frame_local_text), attn_distrib)\
|
540 |
+
+torch.mul(self.get_mixin('attention_plus').dense[layer_id](context_frame_swin), 1-attn_distrib)
|
541 |
+
output = torch.cat((output_text, output_frame), dim=-2)
|
542 |
+
|
543 |
+
return output
|
src/videogen_hub/pipelines/cogvideo/cogvideo_src/pretrain_cogvideo.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- encoding: utf-8 -*-
|
2 |
+
'''
|
3 |
+
@File : pretrain_cogvideo.py
|
4 |
+
@Time : 2021/10/06 00:58:32
|
5 |
+
@Author : Wenyi Hong
|
6 |
+
@Contact : [email protected]
|
7 |
+
'''
|
8 |
+
|
9 |
+
# here put the import lib
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
import math
|
13 |
+
import random
|
14 |
+
import torch
|
15 |
+
import argparse
|
16 |
+
import numpy as np
|
17 |
+
from videogen_hub.depend.icetk import icetk as tokenizer
|
18 |
+
tokenizer.add_special_tokens(['<start_of_image>', '<start_of_english>', '<start_of_chinese>'])
|
19 |
+
|
20 |
+
from models.cogvideo_model import CogVideoModel
|
21 |
+
from SwissArmyTransformer import mpu, get_args
|
22 |
+
from SwissArmyTransformer.training.deepspeed_training import training_main
|
23 |
+
from SwissArmyTransformer.data_utils import BinaryDataset
|
24 |
+
|
25 |
+
def get_masks_and_position_ids_video(data, attention_mask_totxt=None, args=None):
|
26 |
+
# Extract batch size and sequence length.
|
27 |
+
batch_size, seq_length = data.size()
|
28 |
+
assert attention_mask_totxt is not None
|
29 |
+
layout = args.layout
|
30 |
+
assert seq_length == layout[-1]
|
31 |
+
n_pads = layout[0] - attention_mask_totxt.sum(dim=-1).long()
|
32 |
+
frame_len = layout[1]-layout[0]
|
33 |
+
position_ids = torch.zeros(batch_size, layout[2], dtype=torch.long,
|
34 |
+
device=data.device)
|
35 |
+
for i in range(batch_size):
|
36 |
+
torch.arange(layout[0] - n_pads[i], out=position_ids[i, n_pads[i]:layout[0]],
|
37 |
+
dtype=torch.long, device=data.device)
|
38 |
+
torch.arange(512, 512+layout[2]-layout[0],
|
39 |
+
out=position_ids[i, layout[0]:], dtype=torch.long, device=data.device)
|
40 |
+
return position_ids
|
41 |
+
|
42 |
+
|
43 |
+
def get_batch(data_iterator, args, timers):
|
44 |
+
# Items and their type.
|
45 |
+
keys = ['text', 'loss_mask', 'attention_mask_totxt']
|
46 |
+
datatype = torch.int64
|
47 |
+
|
48 |
+
# Broadcast data.
|
49 |
+
timers('data loader').start()
|
50 |
+
if data_iterator is not None:
|
51 |
+
data = next(data_iterator)
|
52 |
+
else:
|
53 |
+
data = None
|
54 |
+
timers('data loader').stop()
|
55 |
+
|
56 |
+
data_b = mpu.broadcast_data(keys, data, datatype)
|
57 |
+
# Unpack.
|
58 |
+
tokens_ = data_b['text'].long()
|
59 |
+
loss_mask = data_b['loss_mask'].float()
|
60 |
+
attention_mask_totxt = data_b['attention_mask_totxt'].float()
|
61 |
+
|
62 |
+
labels = tokens_[:, 1:].clone().contiguous()
|
63 |
+
loss_mask = loss_mask[:, 1:].contiguous()
|
64 |
+
tokens = tokens_[:, :-1].clone().contiguous()
|
65 |
+
|
66 |
+
for idx in range(args.layout[0], args.layout[2], 400):
|
67 |
+
tokens[:, idx] = tokenizer['<start_of_image>']
|
68 |
+
# Get the masks and postition ids.
|
69 |
+
position_ids = get_masks_and_position_ids_video(
|
70 |
+
tokens,
|
71 |
+
attention_mask_totxt=attention_mask_totxt,
|
72 |
+
args=args
|
73 |
+
)
|
74 |
+
attention_mask_totxt = attention_mask_totxt.unsqueeze(1).unsqueeze(1)
|
75 |
+
# Convert
|
76 |
+
if args.fp16:
|
77 |
+
attention_mask_totxt = attention_mask_totxt.half()
|
78 |
+
return tokens, labels, loss_mask, attention_mask_totxt, position_ids
|
79 |
+
|
80 |
+
|
81 |
+
def forward_step(data_iterator, model, args, timers):
|
82 |
+
"""Forward step."""
|
83 |
+
|
84 |
+
# Get the batch.
|
85 |
+
timers('batch generator').start()
|
86 |
+
tokens, labels, loss_mask, attention_mask_totxt, position_ids = get_batch(
|
87 |
+
data_iterator, args, timers)
|
88 |
+
timers('batch generator').stop()
|
89 |
+
|
90 |
+
# Forward model.
|
91 |
+
logits, *mems = model(tokens, position_ids, attention_mask_totxt)
|
92 |
+
# ======= hyper params =======#
|
93 |
+
perframe_len = 400
|
94 |
+
text_len=64
|
95 |
+
frame_num = 5
|
96 |
+
logits_img_tokens = logits[:, text_len:, :tokenizer.num_image_tokens].float().contiguous()
|
97 |
+
losses = mpu.vocab_parallel_cross_entropy(logits_img_tokens, labels[:, text_len:])
|
98 |
+
# scaling loss mask
|
99 |
+
loss_mask = loss_mask[:, text_len:].reshape(-1)
|
100 |
+
|
101 |
+
losses_1d = losses.reshape(-1) * loss_mask
|
102 |
+
loss = torch.sum(losses_1d) / loss_mask.sum()
|
103 |
+
# ===================== Log partial losses ======================== #
|
104 |
+
log_loss_dict = {}
|
105 |
+
bs = losses.shape[0]
|
106 |
+
|
107 |
+
if args.cogvideo_stage == 1:
|
108 |
+
for i in range(frame_num):
|
109 |
+
log_loss_dict[f'AR_f{i}_loss'] = losses[:, i*perframe_len:(i+1)*perframe_len].contiguous().reshape(-1).detach().sum() / max((perframe_len*bs), 1)
|
110 |
+
else:
|
111 |
+
for i in range(1, frame_num-1):
|
112 |
+
log_loss_dict[f'ITP_f{i}_loss'] = losses[:, i*perframe_len:(i+1)*perframe_len].contiguous().reshape(-1).detach().sum() / max((perframe_len*bs), 1)
|
113 |
+
|
114 |
+
# ===================== END OF BLOCK ======================= #
|
115 |
+
return loss, log_loss_dict
|
116 |
+
|
117 |
+
|
118 |
+
def create_dataset_function(path, args):
|
119 |
+
dataset_layout = [64, 464, 2064]
|
120 |
+
input_layout = [64, 464, 2064]
|
121 |
+
# frame_num = 6
|
122 |
+
# frame_interval = 2 # DEBUG!!!
|
123 |
+
def process_fn(row):
|
124 |
+
row = row.astype(np.int64)
|
125 |
+
text = row[:dataset_layout[0]]
|
126 |
+
frames = row[dataset_layout[0]:]
|
127 |
+
|
128 |
+
if text[0] == tokenizer['<pad>']:
|
129 |
+
text = text[1:] # due to our way of data processing
|
130 |
+
if args.cogvideo_stage == 1:
|
131 |
+
text, loss_mask, frames = make_text_video_generation(text, frames)
|
132 |
+
else:
|
133 |
+
text, loss_mask, frames = mask_video_frame_interpolation(text, frames)
|
134 |
+
|
135 |
+
n_pad = input_layout[0] - len(text)
|
136 |
+
parts = [
|
137 |
+
np.array([tokenizer['<pad>']] * n_pad, dtype=np.int64),
|
138 |
+
text,
|
139 |
+
np.array([tokenizer['<start_of_image>']], dtype=np.int64),
|
140 |
+
frames,
|
141 |
+
]
|
142 |
+
ret = np.concatenate(parts, axis=0)
|
143 |
+
|
144 |
+
attention_mask_totxt = np.array([0] * n_pad + [1] * (input_layout[0]-n_pad))
|
145 |
+
return {'text': ret,
|
146 |
+
'loss_mask': loss_mask,
|
147 |
+
'attention_mask_totxt': attention_mask_totxt,
|
148 |
+
}
|
149 |
+
return BinaryDataset(path, process_fn, length_per_sample=dataset_layout[-1])
|
150 |
+
|
151 |
+
def make_text_video_generation(text, frames):
|
152 |
+
input_layout = [64, 464, 2064]
|
153 |
+
text = text[text!= tokenizer['<pad>']][:input_layout[0]] # dataset format: 1.0秒<n>{text}<pad><pad> ...
|
154 |
+
loss_mask = np.array([0] * (input_layout[1]+1) + [1] * (input_layout[2] - input_layout[1])) # 按照input的,之后loss_mask会左移一位
|
155 |
+
return text, loss_mask, frames
|
156 |
+
|
157 |
+
def mask_video_frame_interpolation(text, frames):
|
158 |
+
input_layout = [64, 464, 2064]
|
159 |
+
frame_len = input_layout[1]-input_layout[0]
|
160 |
+
# text format: <pad> 1.0秒 <n> {text} <pad> <pad>
|
161 |
+
text = text[text!= tokenizer['<pad>']][:input_layout[0]]
|
162 |
+
loss_mask = np.array([0] * (input_layout[1]+1)
|
163 |
+
+ [1] * (input_layout[1]-input_layout[0])
|
164 |
+
+ [0] * (input_layout[1]-input_layout[0])
|
165 |
+
+ [1] * (input_layout[1]-input_layout[0])
|
166 |
+
+ [0] * (input_layout[1]-input_layout[0]) )# 按照input的,之后loss_mask会左移一位
|
167 |
+
|
168 |
+
return text, loss_mask, frames
|
169 |
+
|
170 |
+
|
171 |
+
|
172 |
+
if __name__ == '__main__':
|
173 |
+
py_parser = argparse.ArgumentParser(add_help=False)
|
174 |
+
py_parser.add_argument('--txt-loss-scale', type=float, default=1)
|
175 |
+
CogVideoModel.add_model_specific_args(py_parser)
|
176 |
+
|
177 |
+
known, args_list = py_parser.parse_known_args()
|
178 |
+
|
179 |
+
args = get_args(args_list)
|
180 |
+
args = argparse.Namespace(**vars(args), **vars(known))
|
181 |
+
|
182 |
+
args.layout = [int(x) for x in args.layout.split(',')]
|
183 |
+
|
184 |
+
training_main(args, model_cls=CogVideoModel, forward_step_function=forward_step, create_dataset_function=create_dataset_function)
|
src/videogen_hub/pipelines/cogvideo/cogvideo_src/requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
SwissArmyTransformer==0.2.9
|
2 |
+
icetk
|
3 |
+
gifmaker
|
4 |
+
torchvision
|
src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/__init__.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- encoding: utf-8 -*-
|
2 |
+
'''
|
3 |
+
@File : __init__.py
|
4 |
+
@Time : 2022/03/02 13:57:09
|
5 |
+
@Author : Ming Ding
|
6 |
+
@Contact : [email protected]
|
7 |
+
'''
|
8 |
+
|
9 |
+
# here put the import lib
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
import math
|
13 |
+
import random
|
14 |
+
|
15 |
+
from .direct_sr import DirectSuperResolution
|
16 |
+
from .iterative_sr import IterativeSuperResolution
|
17 |
+
from .sr_group import SRGroup
|
src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/cluster_label2.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b87880fdbe89670f12844377b9cf97a9733b1f54e3a9b73cbb9835084c4e02ec
|
3 |
+
size 160128
|
src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/direct_sr.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- encoding: utf-8 -*-
|
2 |
+
'''
|
3 |
+
@File : direct_sr.py
|
4 |
+
@Time : 2022/03/02 13:58:11
|
5 |
+
@Author : Ming Ding
|
6 |
+
@Contact : [email protected]
|
7 |
+
'''
|
8 |
+
|
9 |
+
# here put the import lib
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
import math
|
13 |
+
import random
|
14 |
+
import torch
|
15 |
+
|
16 |
+
# -*- encoding: utf-8 -*-
|
17 |
+
'''
|
18 |
+
@File : inference_cogview2.py
|
19 |
+
@Time : 2021/10/10 16:31:34
|
20 |
+
@Author : Ming Ding
|
21 |
+
@Contact : [email protected]
|
22 |
+
'''
|
23 |
+
|
24 |
+
# here put the import lib
|
25 |
+
import os
|
26 |
+
import sys
|
27 |
+
import math
|
28 |
+
import random
|
29 |
+
from PIL import ImageEnhance, Image
|
30 |
+
|
31 |
+
import torch
|
32 |
+
import argparse
|
33 |
+
from torchvision import transforms
|
34 |
+
|
35 |
+
from SwissArmyTransformer import get_args
|
36 |
+
from SwissArmyTransformer.training.model_io import load_checkpoint
|
37 |
+
from .dsr_sampling import filling_sequence_dsr, IterativeEntfilterStrategy
|
38 |
+
from SwissArmyTransformer.generation.utils import timed_name, save_multiple_images, generate_continually
|
39 |
+
|
40 |
+
from .dsr_model import DsrModel
|
41 |
+
|
42 |
+
from videogen_hub.depend.icetk import icetk as tokenizer
|
43 |
+
|
44 |
+
class DirectSuperResolution:
|
45 |
+
def __init__(self, args, path, max_bz=4, topk=6, onCUDA=False):
|
46 |
+
args.load = path
|
47 |
+
args.kernel_size = 5
|
48 |
+
args.kernel_size2 = 5
|
49 |
+
args.new_sequence_length = 4624
|
50 |
+
args.layout = [96,496,4096]
|
51 |
+
|
52 |
+
model = DsrModel(args)
|
53 |
+
if args.fp16:
|
54 |
+
model = model.half()
|
55 |
+
|
56 |
+
load_checkpoint(model, args) # on cpu
|
57 |
+
model.eval()
|
58 |
+
self.model = model
|
59 |
+
self.onCUDA = onCUDA
|
60 |
+
if onCUDA:
|
61 |
+
self.model = self.model.cuda()
|
62 |
+
|
63 |
+
invalid_slices = [slice(tokenizer.num_image_tokens, None)]
|
64 |
+
|
65 |
+
self.strategy = IterativeEntfilterStrategy(invalid_slices,
|
66 |
+
temperature=1.0, topk=topk) # temperature not used # Temperature Freezed Here!!
|
67 |
+
self.max_bz = max_bz
|
68 |
+
|
69 |
+
def __call__(self, text_tokens, image_tokens, enhance=False):
|
70 |
+
if len(text_tokens.shape) == 1:
|
71 |
+
text_tokens.unsqueeze_(0)
|
72 |
+
if len(image_tokens.shape) == 1:
|
73 |
+
image_tokens.unsqueeze_(0)
|
74 |
+
# ===================== Debug ======================== #
|
75 |
+
# new_image_tokens = []
|
76 |
+
# for small_img in image_tokens:
|
77 |
+
# decoded = tokenizer.decode(image_ids=small_img)
|
78 |
+
# decoded = torch.nn.functional.interpolate(decoded, size=(480, 480)).squeeze(0)
|
79 |
+
# ndarr = decoded.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
|
80 |
+
# image_pil_raw = ImageEnhance.Sharpness(Image.fromarray(ndarr))
|
81 |
+
# small_img2 = tokenizer.encode(image_pil=image_pil_raw.enhance(1.5), image_size=480).view(-1)
|
82 |
+
# new_image_tokens.append(small_img2)
|
83 |
+
# image_tokens = torch.stack(new_image_tokens)
|
84 |
+
# return image_tokens
|
85 |
+
# ===================== END OF BLOCK ======================= #
|
86 |
+
if enhance:
|
87 |
+
new_image_tokens = []
|
88 |
+
for small_img in image_tokens:
|
89 |
+
decoded = tokenizer.decode(image_ids=small_img).squeeze(0)
|
90 |
+
ndarr = decoded.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
|
91 |
+
image_pil_raw = ImageEnhance.Sharpness(Image.fromarray(ndarr))
|
92 |
+
small_img2 = tokenizer.encode(image_pil=image_pil_raw.enhance(1.), image_size=160).view(-1)
|
93 |
+
new_image_tokens.append(small_img2)
|
94 |
+
image_tokens = torch.stack(new_image_tokens)
|
95 |
+
|
96 |
+
seq = torch.cat((text_tokens,image_tokens), dim=1)
|
97 |
+
seq1 = torch.tensor([tokenizer['<start_of_image>']]*3601, device=image_tokens.device).unsqueeze(0).expand(text_tokens.shape[0], -1)
|
98 |
+
if not self.onCUDA:
|
99 |
+
print('Converting Dsr model...')
|
100 |
+
model = self.model.cuda()
|
101 |
+
else:
|
102 |
+
model = self.model
|
103 |
+
print('Direct super-resolution...')
|
104 |
+
output_list = []
|
105 |
+
for tim in range(max((text_tokens.shape[0]+self.max_bz-1) // self.max_bz, 1)):
|
106 |
+
output1 = filling_sequence_dsr(model,
|
107 |
+
seq[tim*self.max_bz:(tim+1)*self.max_bz],
|
108 |
+
seq1[tim*self.max_bz:(tim+1)*self.max_bz],
|
109 |
+
warmup_steps=1, block_hw=(1, 0),
|
110 |
+
strategy=self.strategy
|
111 |
+
)
|
112 |
+
output_list.extend(output1[1:])
|
113 |
+
if not self.onCUDA:
|
114 |
+
print('Moving back Dsr to cpu...')
|
115 |
+
model = model.cpu()
|
116 |
+
torch.cuda.empty_cache()
|
117 |
+
return torch.cat(output_list, dim=0)
|
src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/dsr_model.py
ADDED
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- encoding: utf-8 -*-
|
2 |
+
'''
|
3 |
+
@File : cuda2d_model.py
|
4 |
+
@Time : 2021/10/02 01:36:32
|
5 |
+
@Author : Ming Ding
|
6 |
+
@Contact : [email protected]
|
7 |
+
'''
|
8 |
+
|
9 |
+
# here put the import lib
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
import math
|
13 |
+
import random
|
14 |
+
import torch
|
15 |
+
import torch.nn.functional as F
|
16 |
+
|
17 |
+
|
18 |
+
from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin
|
19 |
+
|
20 |
+
from SwissArmyTransformer.model.transformer import split_tensor_along_last_dim, unscaled_init_method
|
21 |
+
from SwissArmyTransformer.mpu.utils import sqrt
|
22 |
+
from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker
|
23 |
+
from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear
|
24 |
+
|
25 |
+
class PositionEmbeddingMixin(BaseMixin):
|
26 |
+
def __init__(self, additional_sequence_length, hidden_size,
|
27 |
+
init_method_std=0.02, reinit_slice=slice(512, 512+400)
|
28 |
+
):
|
29 |
+
super(PositionEmbeddingMixin, self).__init__()
|
30 |
+
self.reinit_slice = reinit_slice
|
31 |
+
self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size)
|
32 |
+
torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
|
33 |
+
|
34 |
+
def reinit(self, parent_model=None):
|
35 |
+
old_weights = self.transformer.position_embeddings.weight.data[self.reinit_slice]
|
36 |
+
old_len, hidden_size = old_weights.shape
|
37 |
+
assert hidden_size == self.position_embeddings.weight.shape[-1]
|
38 |
+
old_edge, new_edge = sqrt(old_len), sqrt(self.position_embeddings.weight.shape[-2])
|
39 |
+
assert new_edge % old_edge == 0
|
40 |
+
self.position_embeddings.weight.data.view(new_edge // old_edge, old_edge, new_edge // old_edge, old_edge, hidden_size).copy_(old_weights.view(1, old_edge, 1, old_edge, hidden_size))
|
41 |
+
# self.position_embeddings.weight.data.view(-1, old_len, hidden_size).copy_(old_weights)
|
42 |
+
|
43 |
+
|
44 |
+
class AttentionMixin(BaseMixin):
|
45 |
+
def __init__(self, num_layers,
|
46 |
+
hidden_size,
|
47 |
+
init_method=unscaled_init_method(0.02),
|
48 |
+
output_layer_init_method=unscaled_init_method(0.02)
|
49 |
+
):
|
50 |
+
super(AttentionMixin, self).__init__()
|
51 |
+
self.num_layers = num_layers # replace attention in the LAST n layers
|
52 |
+
self.query_key_value = torch.nn.ModuleList(
|
53 |
+
[ColumnParallelLinear(hidden_size, 3 * hidden_size, stride=3,
|
54 |
+
gather_output=False, init_method=init_method)
|
55 |
+
for layer_id in range(num_layers)
|
56 |
+
])
|
57 |
+
self.dense = torch.nn.ModuleList(
|
58 |
+
[RowParallelLinear(hidden_size,
|
59 |
+
hidden_size,
|
60 |
+
input_is_parallel=True,
|
61 |
+
init_method=output_layer_init_method)
|
62 |
+
for layer_id in range(num_layers)
|
63 |
+
])
|
64 |
+
|
65 |
+
def reinit(self, parent_model=None):
|
66 |
+
start_layer = len(self.transformer.layers) - self.num_layers
|
67 |
+
assert start_layer >= 0
|
68 |
+
for layer_id in range(self.num_layers):
|
69 |
+
old_attention = self.transformer.layers[start_layer + layer_id].attention
|
70 |
+
self.query_key_value[layer_id].weight.data.copy_(old_attention.query_key_value.weight.data)
|
71 |
+
self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data)
|
72 |
+
self.dense[layer_id].weight.data.copy_(old_attention.dense.weight.data)
|
73 |
+
self.dense[layer_id].bias.data.copy_(old_attention.dense.bias.data)
|
74 |
+
|
75 |
+
class DsrModel(BaseModel):
|
76 |
+
def __init__(self, args, transformer=None):
|
77 |
+
super().__init__(args, transformer=transformer)
|
78 |
+
self.original_sequence_length = args.max_sequence_length
|
79 |
+
additional_seqlen = args.new_sequence_length - args.max_sequence_length
|
80 |
+
self.add_mixin('extra_position_embedding', PositionEmbeddingMixin(
|
81 |
+
additional_seqlen, args.hidden_size
|
82 |
+
))
|
83 |
+
self.add_mixin('attention_plus', AttentionMixin(
|
84 |
+
num_layers=args.num_layers,
|
85 |
+
hidden_size=args.hidden_size
|
86 |
+
))
|
87 |
+
self.layout = args.layout
|
88 |
+
# [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1] 4095 {layout[2]}
|
89 |
+
self.kernel_size = args.kernel_size
|
90 |
+
self.kernel_size2 = args.kernel_size2
|
91 |
+
self.log_attention_weights = None
|
92 |
+
|
93 |
+
def position_embedding_forward(self, position_ids, **kw_args):
|
94 |
+
position = position_ids[..., :self.layout[1]]
|
95 |
+
position_plus = position_ids[..., self.layout[1]:] - self.original_sequence_length
|
96 |
+
position_embeddings = torch.cat(
|
97 |
+
(
|
98 |
+
self.transformer.position_embeddings(position),
|
99 |
+
self.get_mixin('extra_position_embedding').position_embeddings(position_plus)
|
100 |
+
),
|
101 |
+
dim=-2
|
102 |
+
)
|
103 |
+
return position_embeddings
|
104 |
+
|
105 |
+
def attention_forward(self, hidden_states, mask,
|
106 |
+
layer_id=None, log_attention_weights=None, **kw_args):
|
107 |
+
attn_module = self.transformer.layers[layer_id].attention
|
108 |
+
# attention_plus on all layers
|
109 |
+
query_key_value_plus = self.get_mixin('attention_plus').query_key_value[layer_id]
|
110 |
+
dense_plus = self.get_mixin('attention_plus').dense[layer_id]
|
111 |
+
# split two parts
|
112 |
+
hidden_states_plus = hidden_states[:, self.layout[1]:]
|
113 |
+
hidden_states = hidden_states[:, :self.layout[1]]
|
114 |
+
# base model qkv
|
115 |
+
mixed_raw_layer = attn_module.query_key_value(hidden_states)
|
116 |
+
q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer, 3)
|
117 |
+
# cuda2d model qkv
|
118 |
+
mixed_raw_layer = query_key_value_plus(hidden_states_plus)
|
119 |
+
q1, k1, v1 = split_tensor_along_last_dim(mixed_raw_layer, 3)
|
120 |
+
|
121 |
+
dropout_fn = attn_module.attention_dropout if self.training else None
|
122 |
+
|
123 |
+
# cuda2d attention
|
124 |
+
context_layer0, context_layer1 = sparse_attention_2d_light(
|
125 |
+
q0, k0, v0,
|
126 |
+
q1, k1, v1,
|
127 |
+
mask,
|
128 |
+
n_head=attn_module.num_attention_heads_per_partition,
|
129 |
+
text_len=self.layout[0],
|
130 |
+
kernel_size=self.kernel_size,
|
131 |
+
kernel_size2=self.kernel_size2,
|
132 |
+
attention_dropout=dropout_fn,
|
133 |
+
log_attention_weights=log_attention_weights,
|
134 |
+
add_scalar=(kw_args['add_scalar'] if 'add_scalar' in kw_args else 0)
|
135 |
+
)
|
136 |
+
|
137 |
+
output_0 = attn_module.dense(context_layer0)
|
138 |
+
output_1 = dense_plus(context_layer1)
|
139 |
+
output = torch.cat((output_0, output_1), dim=1)
|
140 |
+
|
141 |
+
return output
|
142 |
+
|
143 |
+
def final_forward(self, logits, **kwargs):
|
144 |
+
logits_parallel = logits
|
145 |
+
logits_parallel = torch.nn.functional.linear(logits_parallel.float(), self.transformer.word_embeddings.weight[:20000].float())
|
146 |
+
# logits_parallel = torch.nn.functional.linear(logits_parallel, self.transformer.word_embeddings.weight[:20000])
|
147 |
+
return logits_parallel
|
148 |
+
|
149 |
+
def disable_untrainable_params(self):
|
150 |
+
self.transformer.requires_grad_(False)
|
151 |
+
|
152 |
+
@classmethod
|
153 |
+
def add_model_specific_args(cls, parser):
|
154 |
+
group = parser.add_argument_group('Cuda2dModel', 'cuda2d model configurations')
|
155 |
+
group.add_argument("--kernel-size", type=int, default=5)
|
156 |
+
group.add_argument("--kernel-size2", type=int, default=5)
|
157 |
+
group.add_argument("--layout", type=str, default='96,496,4096')
|
158 |
+
group.add_argument("--new-sequence-length", type=int, default=4096)
|
159 |
+
return parser
|
160 |
+
|
161 |
+
def sparse_attention_2d_light(q0, k0, v0, q1, k1, v1, attention_mask, n_head, text_len, kernel_size=9, kernel_size2=7, attention_dropout=None, log_attention_weights = None, add_scalar=0, **kwargs):
|
162 |
+
'''
|
163 |
+
q0, k0, v0: [batch_size, 1088, hidden_size]
|
164 |
+
q1, k1, v1: [batch_size, 4096, h2]
|
165 |
+
n_head: int
|
166 |
+
attention_mask: [batch_size, 1088, 1088]
|
167 |
+
'''
|
168 |
+
from SwissArmyTransformer.ops.local_attention_function import f_similar, f_weighting
|
169 |
+
|
170 |
+
b, s0, h0 = q0.shape
|
171 |
+
b, s1, h1 = q1.shape
|
172 |
+
h, l0, l1 = h0 // n_head, sqrt(s0-text_len), sqrt(s1)
|
173 |
+
|
174 |
+
q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
|
175 |
+
v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
|
176 |
+
k0T = k0.reshape(b, s0, n_head, h).permute(0, 2, 3, 1)
|
177 |
+
|
178 |
+
# standard attention for level 0
|
179 |
+
attention_scores = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T)
|
180 |
+
|
181 |
+
if log_attention_weights is not None:
|
182 |
+
attention_scores += log_attention_weights
|
183 |
+
attention_scores = torch.mul(attention_scores, attention_mask) - \
|
184 |
+
10000.0 * (1.0 - attention_mask)
|
185 |
+
|
186 |
+
attention_probs0 = F.softmax(attention_scores, dim=-1)
|
187 |
+
|
188 |
+
# local attention for level 1
|
189 |
+
q1 = (q1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1) / math.sqrt(h1//n_head)).contiguous().view(b*n_head, h1//n_head, l1, l1)
|
190 |
+
k1 = k1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1)
|
191 |
+
v1 = v1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1)
|
192 |
+
# scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, True)
|
193 |
+
scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, False)
|
194 |
+
|
195 |
+
# cross attention
|
196 |
+
k0T = k0T[..., -l0**2:].reshape(b*n_head, h, l0, l0).contiguous()
|
197 |
+
scores_1_to_0 = f_similar(q1, k0T, kernel_size2, kernel_size2, False) # [b*n_head, l1, l1, field]
|
198 |
+
scores_1 = torch.cat(
|
199 |
+
(
|
200 |
+
scores_1_to_0.view(b*n_head, -1, scores_1_to_0.shape[3]) + add_scalar,
|
201 |
+
scores_1_to_1.view(b*n_head, -1, scores_1_to_1.shape[3])
|
202 |
+
),
|
203 |
+
dim=-1)
|
204 |
+
attention_probs1 = F.softmax(scores_1, dim=-1)
|
205 |
+
|
206 |
+
if attention_dropout is not None:
|
207 |
+
# with get_cuda_rng_tracker().fork():
|
208 |
+
attention_probs0 = attention_dropout(attention_probs0)
|
209 |
+
attention_probs1 = attention_dropout(attention_probs1)
|
210 |
+
|
211 |
+
# weighting for level 0
|
212 |
+
context0 = torch.matmul(attention_probs0, v0) # [b, n_head, s0, h]
|
213 |
+
# weighting for level 1
|
214 |
+
probs_1_to_1 = attention_probs1[:, :, -scores_1_to_1.shape[3]:].view_as(scores_1_to_1)
|
215 |
+
# context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size*2-1, kernel_size, True)
|
216 |
+
context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size*2-1, kernel_size, False)
|
217 |
+
|
218 |
+
context1 = context1_to_1.view(b, n_head * h, l1**2)
|
219 |
+
# weighting for cross attention
|
220 |
+
probs_1_to_0 = attention_probs1[:, :, :scores_1_to_0.shape[3]].view_as(scores_1_to_0)
|
221 |
+
v0_part = v0[:, :, -l0**2:].transpose(-1, -2).contiguous().view(b*n_head, h, l0, l0)
|
222 |
+
context1_to_0 = f_weighting(v0_part, probs_1_to_0.contiguous(), kernel_size2, kernel_size2, False)
|
223 |
+
context1_to_0 = context1_to_0.view(b, n_head * h, l1**2)
|
224 |
+
context1 = context1 + context1_to_0
|
225 |
+
return context0.transpose(1, 2).reshape(b, s0, h0), context1.transpose(-1, -2)
|
src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/dsr_sampling.py
ADDED
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- encoding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
@File : cuda2d_sampling.py
|
4 |
+
@Time : 2021/10/09 00:46:04
|
5 |
+
@Author : Ming Ding
|
6 |
+
@Contact : [email protected]
|
7 |
+
"""
|
8 |
+
|
9 |
+
# here put the import lib
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
import math
|
13 |
+
import random
|
14 |
+
from cv2 import reduce
|
15 |
+
import torch
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn.functional as F
|
19 |
+
import numpy as np
|
20 |
+
|
21 |
+
|
22 |
+
def top_k_logits_(logits, top_k=0, filter_value=-float("Inf")):
|
23 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
24 |
+
logits[indices_to_remove] = filter_value
|
25 |
+
return logits
|
26 |
+
|
27 |
+
|
28 |
+
class IterativeEntfilterStrategy:
|
29 |
+
def __init__(self, invalid_slices=[], temperature=1.0, topk=6):
|
30 |
+
self.invalid_slices = invalid_slices
|
31 |
+
self.temperature = temperature
|
32 |
+
self.topk = topk
|
33 |
+
device = "cpu"
|
34 |
+
if torch.cuda.is_available():
|
35 |
+
device = "cuda"
|
36 |
+
self.cluster_labels = torch.tensor(
|
37 |
+
np.load("cluster_label2.npy"), device=device, dtype=torch.long
|
38 |
+
)
|
39 |
+
|
40 |
+
def forward(
|
41 |
+
self,
|
42 |
+
logits_,
|
43 |
+
tokens,
|
44 |
+
temperature=None,
|
45 |
+
entfilter=None,
|
46 |
+
filter_topk=5,
|
47 |
+
temperature2=None,
|
48 |
+
):
|
49 |
+
# In interative strategy, logits are of shape [batch_size, seq_length, hidden_size]
|
50 |
+
if temperature is None:
|
51 |
+
temperature = self.temperature
|
52 |
+
|
53 |
+
logits = logits_.float() / temperature
|
54 |
+
for invalid_slice in self.invalid_slices:
|
55 |
+
logits[..., invalid_slice] = -float("Inf")
|
56 |
+
logits = logits.view(-1, logits.shape[-1])
|
57 |
+
|
58 |
+
rprobs = F.softmax(logits.float(), dim=-1)
|
59 |
+
c = self.cluster_labels.expand(*rprobs.shape)
|
60 |
+
cprobs = torch.zeros(logits.shape[0], 500, device=logits.device).scatter_add_(
|
61 |
+
1, c, rprobs
|
62 |
+
)
|
63 |
+
|
64 |
+
best_scores, best_clusters = cprobs.topk(self.topk)
|
65 |
+
bz = logits.shape[0]
|
66 |
+
best_scores = best_scores / best_scores.sum(dim=-1, keepdim=True)
|
67 |
+
sampled_ids = torch.multinomial(best_scores, num_samples=1)
|
68 |
+
selected_clusters = torch.gather(best_clusters, dim=1, index=sampled_ids)
|
69 |
+
selected_mask = (
|
70 |
+
self.cluster_labels.unsqueeze(0).expand(bz, -1) != selected_clusters
|
71 |
+
) # cluster_labels [1, 20000] \in [0,500)
|
72 |
+
logits[selected_mask] = -65504
|
73 |
+
# for i in range(bz):
|
74 |
+
# selected_cluster = best_clusters[i][torch.multinomial(best_scores[i] / best_scores[i].sum(), num_samples=1)]
|
75 |
+
# logits[i, self.cluster_labels != selected_cluster] = -65504
|
76 |
+
|
77 |
+
# logits = top_k_logits(logits, self.topk, self.top_p)
|
78 |
+
probs = F.softmax(
|
79 |
+
logits.float() / 0.6, dim=-1
|
80 |
+
) # float is essetial, due to a bug in Pytorch
|
81 |
+
pred = torch.multinomial(probs, num_samples=1).view(*logits_.shape[:2])
|
82 |
+
|
83 |
+
assert tokens.shape[1] == pred.shape[1] + 1
|
84 |
+
tokens = torch.cat((tokens[:, :1], pred), dim=1)
|
85 |
+
return tokens
|
86 |
+
|
87 |
+
|
88 |
+
def filling_sequence_dsr(
|
89 |
+
model,
|
90 |
+
seq0,
|
91 |
+
seq1,
|
92 |
+
warmup_steps=3,
|
93 |
+
block_hw=(4, 4),
|
94 |
+
strategy=IterativeEntfilterStrategy(topk=10),
|
95 |
+
):
|
96 |
+
"""
|
97 |
+
seq: [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1]
|
98 |
+
4095 {layout[2]} final_token.
|
99 |
+
Attention:
|
100 |
+
The sampling temperature are changing, temporally we hard code them here.
|
101 |
+
The temperature in the strategy is not used.
|
102 |
+
"""
|
103 |
+
assert hasattr(model, "layout")
|
104 |
+
layout = model.layout
|
105 |
+
assert (
|
106 |
+
len(seq0.shape) == 2 and len(seq1.shape) == 2 and seq0.shape[0] == seq1.shape[0]
|
107 |
+
)
|
108 |
+
assert len(layout) == 3
|
109 |
+
assert seq1.shape[1] == layout[-1] - layout[-2] + 1
|
110 |
+
assert (seq1 >= 0).all() and (seq0 >= 0).all()
|
111 |
+
device = seq0.device
|
112 |
+
# concat and pad sequences
|
113 |
+
batch_size = seq0.shape[0]
|
114 |
+
n_pad = layout[1] - seq0.shape[1]
|
115 |
+
assert n_pad > 0, "You should truncate long input before filling."
|
116 |
+
seq = torch.cat(
|
117 |
+
(
|
118 |
+
torch.tensor([0] * n_pad, device=device, dtype=seq0.dtype)
|
119 |
+
.unsqueeze(0)
|
120 |
+
.expand(batch_size, n_pad),
|
121 |
+
seq0,
|
122 |
+
seq1,
|
123 |
+
),
|
124 |
+
dim=1,
|
125 |
+
) # [b, layout[-1]+1]
|
126 |
+
assert seq.shape[1] == layout[-1] + 1
|
127 |
+
|
128 |
+
# build initial tokens, attention_mask, and position_ids
|
129 |
+
tokens = seq.clone()
|
130 |
+
attention_mask = torch.ones(layout[1], layout[1]).to(device)
|
131 |
+
attention_mask[: layout[0], layout[0] :] = 0
|
132 |
+
attention_mask[n_pad:, :n_pad] = 0
|
133 |
+
attention_mask = attention_mask.type_as(next(model.parameters())) # if fp16
|
134 |
+
position_ids = torch.cat(
|
135 |
+
(
|
136 |
+
torch.zeros(n_pad, dtype=torch.long),
|
137 |
+
torch.arange(0, layout[0] - n_pad),
|
138 |
+
torch.arange(513, 513 + layout[1] - layout[0]),
|
139 |
+
torch.arange(1024, 1024 + layout[2] - layout[1]),
|
140 |
+
)
|
141 |
+
).to(device)
|
142 |
+
log_attention_weights = torch.zeros(layout[1], layout[1], device=device).type_as(
|
143 |
+
next(model.parameters())
|
144 |
+
)
|
145 |
+
log_attention_weights[layout[0] :, n_pad : layout[0]] = 0.0
|
146 |
+
|
147 |
+
# prepare for interation
|
148 |
+
unfixed = tokens < 0 # just init an all-False tensor
|
149 |
+
unfixed[:, -layout[-1] + layout[-2] :] = True
|
150 |
+
|
151 |
+
ll, rr = block_hw
|
152 |
+
edge_len = int(math.sqrt(layout[-1] - layout[-2]) + 1e-4)
|
153 |
+
num_steps = warmup_steps + ll - 1 + rr
|
154 |
+
# interative refining
|
155 |
+
|
156 |
+
# unfixed[..., -(layout[-1] - layout[-2]):].view(
|
157 |
+
# batch_size, edge_len//ll, ll, edge_len//rr, rr)[:, :, :, :, -1] = False
|
158 |
+
|
159 |
+
ret = []
|
160 |
+
ret.append(tokens[:, layout[-2] + 1 :].clone())
|
161 |
+
for step_cnt in range(1, num_steps + 1):
|
162 |
+
if step_cnt <= warmup_steps:
|
163 |
+
logits, *_dump = model(
|
164 |
+
tokens[:, :-1],
|
165 |
+
position_ids,
|
166 |
+
attention_mask,
|
167 |
+
log_attention_weights=log_attention_weights,
|
168 |
+
)
|
169 |
+
real_temp = 1.0
|
170 |
+
new_tokens = strategy.forward(logits, tokens, real_temp)
|
171 |
+
tokens[unfixed] = new_tokens[unfixed]
|
172 |
+
else:
|
173 |
+
logits, *_dump = model(
|
174 |
+
tokens[:, :-1],
|
175 |
+
position_ids,
|
176 |
+
attention_mask,
|
177 |
+
log_attention_weights=log_attention_weights,
|
178 |
+
)
|
179 |
+
real_temp = 1.0
|
180 |
+
new_tokens = strategy.forward(
|
181 |
+
logits,
|
182 |
+
tokens,
|
183 |
+
real_temp,
|
184 |
+
entfilter=1.3,
|
185 |
+
filter_topk=5,
|
186 |
+
temperature2=0.6,
|
187 |
+
)
|
188 |
+
# tokens[unfixed] = new_tokens[unfixed]
|
189 |
+
# fixed tokens (update unfixed)
|
190 |
+
unfixed2 = tokens > 10000000
|
191 |
+
for x in range(min(ll, step_cnt - warmup_steps)):
|
192 |
+
y = step_cnt - warmup_steps - x - 1
|
193 |
+
if y < rr:
|
194 |
+
unfixed[..., -(layout[-1] - layout[-2]) :].view(
|
195 |
+
batch_size, edge_len // ll, ll, edge_len // rr, rr
|
196 |
+
)[:, :, x, :, y] = False
|
197 |
+
unfixed2[..., -(layout[-1] - layout[-2]) :].view(
|
198 |
+
batch_size, edge_len // ll, ll, edge_len // rr, rr
|
199 |
+
)[:, :, x, :, y] = True
|
200 |
+
tokens[unfixed2] = new_tokens[unfixed2]
|
201 |
+
|
202 |
+
ret.append(tokens[:, layout[-2] + 1 :].clone())
|
203 |
+
|
204 |
+
return ret
|
src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/iterative_sr.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- encoding: utf-8 -*-
|
2 |
+
'''
|
3 |
+
@File : iterative_sr.py
|
4 |
+
@Time : 2022/03/02 15:57:45
|
5 |
+
@Author : Ming Ding
|
6 |
+
@Contact : [email protected]
|
7 |
+
'''
|
8 |
+
|
9 |
+
# here put the import lib
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
import math
|
13 |
+
import random
|
14 |
+
|
15 |
+
# here put the import lib
|
16 |
+
import os
|
17 |
+
import sys
|
18 |
+
import math
|
19 |
+
import random
|
20 |
+
from PIL import ImageEnhance, Image
|
21 |
+
|
22 |
+
import torch
|
23 |
+
import argparse
|
24 |
+
from torchvision import transforms
|
25 |
+
|
26 |
+
from SwissArmyTransformer.training.model_io import load_checkpoint
|
27 |
+
from SwissArmyTransformer import get_args
|
28 |
+
from .itersr_sampling import filling_sequence_itersr, IterativeEntfilterStrategy
|
29 |
+
from SwissArmyTransformer.generation.utils import timed_name, save_multiple_images, generate_continually
|
30 |
+
|
31 |
+
from .itersr_model import ItersrModel
|
32 |
+
|
33 |
+
from videogen_hub.depend.icetk import icetk as tokenizer
|
34 |
+
|
35 |
+
class IterativeSuperResolution:
|
36 |
+
def __init__(self, args, path, max_bz=4, shared_transformer=None):
|
37 |
+
args.load = path
|
38 |
+
args.kernel_size = 5
|
39 |
+
args.kernel_size2 = 5
|
40 |
+
args.new_sequence_length = 4624
|
41 |
+
args.layout = [16,3616]
|
42 |
+
|
43 |
+
model = ItersrModel(args, transformer=shared_transformer)
|
44 |
+
if args.fp16:
|
45 |
+
model = model.half()
|
46 |
+
|
47 |
+
load_checkpoint(model, args) # on cpu
|
48 |
+
model.eval()
|
49 |
+
self.model = model.cuda()
|
50 |
+
|
51 |
+
# save cpu weights
|
52 |
+
self.saved_weights = dict((k,v.cpu())
|
53 |
+
for k, v in model.named_parameters()
|
54 |
+
if 'transformer' in k
|
55 |
+
)
|
56 |
+
|
57 |
+
invalid_slices = [slice(tokenizer.num_image_tokens, None)]
|
58 |
+
|
59 |
+
self.strategy = IterativeEntfilterStrategy(invalid_slices,
|
60 |
+
temperature=args.temp_all_itersr, topk=args.topk_itersr)
|
61 |
+
self.max_bz = max_bz
|
62 |
+
|
63 |
+
def _restore_transformer_from_cpu(self, non_blocking=False):
|
64 |
+
for k, v in self.model.named_parameters():
|
65 |
+
if k in self.saved_weights:
|
66 |
+
v.copy_(self.saved_weights[k])
|
67 |
+
|
68 |
+
def __call__(self, text_tokens, image_tokens, enhance=False, input_mask=None):
|
69 |
+
if len(text_tokens.shape) == 1:
|
70 |
+
text_tokens.unsqueeze_(0)
|
71 |
+
text_tokens = text_tokens.clone()[..., :16]
|
72 |
+
if len(image_tokens.shape) == 1:
|
73 |
+
image_tokens.unsqueeze_(0)
|
74 |
+
if enhance:
|
75 |
+
new_image_tokens = []
|
76 |
+
for big_img in image_tokens:
|
77 |
+
decoded = tokenizer.decode(image_ids=big_img).squeeze(0)
|
78 |
+
ndarr = decoded.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
|
79 |
+
image_pil_raw = ImageEnhance.Sharpness(Image.fromarray(ndarr))
|
80 |
+
big_img2 = tokenizer.encode(image_pil=image_pil_raw.enhance(1.5), image_size=480).view(-1)
|
81 |
+
new_image_tokens.append(big_img2)
|
82 |
+
image_tokens = torch.stack(new_image_tokens)
|
83 |
+
print('Converting Itersr model...')
|
84 |
+
self._restore_transformer_from_cpu()
|
85 |
+
model = self.model
|
86 |
+
print('iterative super-resolution...')
|
87 |
+
output_list = []
|
88 |
+
for tim in range(max(text_tokens.shape[0] // self.max_bz, 1)):
|
89 |
+
big_img = image_tokens[tim*self.max_bz:(tim+1)*self.max_bz]
|
90 |
+
text_seq = text_tokens[tim*self.max_bz:(tim+1)*self.max_bz]
|
91 |
+
mask_raw = torch.tensor(
|
92 |
+
[
|
93 |
+
-1, 0, 1, 2, 3, 4,
|
94 |
+
0, -1, 2, -1, -2, 5,
|
95 |
+
1, -2, 3, 4, 5, 6,
|
96 |
+
2, 3, 4, 5, -1, 1,
|
97 |
+
3, -1, -2, 0, -1, 2,
|
98 |
+
4, 5, 6, 1, 3, -2
|
99 |
+
]
|
100 |
+
).view(1, 6, 1, 6).expand(10, 6, 10, 6).reshape(-1).contiguous()
|
101 |
+
|
102 |
+
topks = [60, 40, 40, 40, 20, 20, 10]
|
103 |
+
|
104 |
+
for mask_ratio in range(1, 7):
|
105 |
+
self.strategy.topk = topks[mask_ratio]
|
106 |
+
mask = (mask_raw.to(big_img.device) >= mask_ratio)
|
107 |
+
if input_mask is not None:
|
108 |
+
mask = mask & input_mask
|
109 |
+
big_img.masked_fill_(mask, tokenizer['<start_of_image>'])
|
110 |
+
seq1 = big_img
|
111 |
+
output1 = filling_sequence_itersr(model, text_seq, seq1,
|
112 |
+
warmup_steps=1, block_hw=(1, 0),
|
113 |
+
strategy=self.strategy
|
114 |
+
)
|
115 |
+
big_img = output1
|
116 |
+
print(f'Iter {mask_ratio} times.')
|
117 |
+
output_list.append(output1.clone())
|
118 |
+
return torch.cat(output_list, dim=0)
|
src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/itersr_model.py
ADDED
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- encoding: utf-8 -*-
|
2 |
+
'''
|
3 |
+
@File : itersr_model.py
|
4 |
+
@Time : 2021/10/02 01:36:32
|
5 |
+
@Author : Ming Ding
|
6 |
+
@Contact : [email protected]
|
7 |
+
'''
|
8 |
+
|
9 |
+
# here put the import lib
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
import math
|
13 |
+
import random
|
14 |
+
import torch
|
15 |
+
import torch.nn.functional as F
|
16 |
+
|
17 |
+
|
18 |
+
from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin
|
19 |
+
|
20 |
+
from SwissArmyTransformer.mpu.utils import sqrt
|
21 |
+
from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker
|
22 |
+
from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear
|
23 |
+
from SwissArmyTransformer.model.transformer import unscaled_init_method, split_tensor_along_last_dim
|
24 |
+
|
25 |
+
class PositionEmbeddingMixin(BaseMixin):
|
26 |
+
def __init__(self, additional_sequence_length, hidden_size,
|
27 |
+
init_method_std=0.02, reinit_slice=slice(512, 512+400)
|
28 |
+
):
|
29 |
+
super(PositionEmbeddingMixin, self).__init__()
|
30 |
+
self.reinit_slice = reinit_slice
|
31 |
+
self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size)
|
32 |
+
torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
|
33 |
+
|
34 |
+
def reinit(self, parent_model=None):
|
35 |
+
old_weights = self.transformer.position_embeddings.weight.data[self.reinit_slice]
|
36 |
+
old_len, hidden_size = old_weights.shape
|
37 |
+
assert hidden_size == self.position_embeddings.weight.shape[-1]
|
38 |
+
old_edge, new_edge = sqrt(old_len), sqrt(self.position_embeddings.weight.shape[-2])
|
39 |
+
assert new_edge % old_edge == 0
|
40 |
+
self.position_embeddings.weight.data.view(new_edge // old_edge, old_edge, new_edge // old_edge, old_edge, hidden_size).copy_(old_weights.view(1, old_edge, 1, old_edge, hidden_size))
|
41 |
+
|
42 |
+
class ItersrModel(BaseModel):
|
43 |
+
def __init__(self, args, transformer=None):
|
44 |
+
super().__init__(args, transformer=transformer)
|
45 |
+
self.original_sequence_length = args.max_sequence_length
|
46 |
+
additional_seqlen = args.new_sequence_length - args.max_sequence_length
|
47 |
+
self.add_mixin('extra_position_embedding', PositionEmbeddingMixin(
|
48 |
+
additional_seqlen, args.hidden_size
|
49 |
+
))
|
50 |
+
# self.add_mixin('attention_plus', AttentionMixin(
|
51 |
+
# num_layers=args.num_layers,
|
52 |
+
# hidden_size=args.hidden_size
|
53 |
+
# ))
|
54 |
+
self.layout = args.layout
|
55 |
+
# [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1] 4095 {layout[2]}
|
56 |
+
self.kernel_size = args.kernel_size
|
57 |
+
self.kernel_size2 = args.kernel_size2
|
58 |
+
self.log_attention_weights = None
|
59 |
+
|
60 |
+
def position_embedding_forward(self, position_ids, **kw_args):
|
61 |
+
position = position_ids[..., :self.layout[0]]
|
62 |
+
position_plus = position_ids[..., self.layout[0]:] - self.original_sequence_length
|
63 |
+
position_embeddings = torch.cat(
|
64 |
+
(
|
65 |
+
self.transformer.position_embeddings(position),
|
66 |
+
self.get_mixin('extra_position_embedding').position_embeddings(position_plus)
|
67 |
+
),
|
68 |
+
dim=-2
|
69 |
+
)
|
70 |
+
return position_embeddings
|
71 |
+
|
72 |
+
def attention_forward(self, hidden_states, mask,
|
73 |
+
layer_id=None, log_attention_weights=None, **kw_args):
|
74 |
+
attn_module = self.transformer.layers[layer_id].attention
|
75 |
+
# base model qkv
|
76 |
+
mixed_raw_layer = attn_module.query_key_value(hidden_states)
|
77 |
+
q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer[:, :self.layout[0]], 3)
|
78 |
+
# cuda2d model qkv
|
79 |
+
q1, k1, v1 = split_tensor_along_last_dim(mixed_raw_layer[:, self.layout[0]:], 3)
|
80 |
+
|
81 |
+
dropout_fn = attn_module.attention_dropout if self.training else None
|
82 |
+
|
83 |
+
# cuda2d attention
|
84 |
+
context_layer = sparse_attention_2d_text(
|
85 |
+
q0, k0, v0,
|
86 |
+
q1, k1, v1,
|
87 |
+
mask,
|
88 |
+
n_head=attn_module.num_attention_heads_per_partition,
|
89 |
+
text_len=self.layout[0],
|
90 |
+
kernel_size=self.kernel_size,
|
91 |
+
attention_dropout=dropout_fn,
|
92 |
+
log_attention_weights=log_attention_weights,
|
93 |
+
)
|
94 |
+
|
95 |
+
output = attn_module.dense(context_layer)
|
96 |
+
|
97 |
+
return output
|
98 |
+
|
99 |
+
def final_forward(self, logits, **kwargs):
|
100 |
+
logits_parallel = logits
|
101 |
+
logits_parallel = torch.nn.functional.linear(logits_parallel, self.transformer.word_embeddings.weight[:20000]).float()
|
102 |
+
# logits_parallel = torch.nn.functional.linear(logits_parallel, self.transformer.word_embeddings.weight[:20000])
|
103 |
+
return logits_parallel
|
104 |
+
|
105 |
+
# def disable_untrainable_params(self):
|
106 |
+
# self.transformer.requires_grad_(False)
|
107 |
+
|
108 |
+
@classmethod
|
109 |
+
def add_model_specific_args(cls, parser):
|
110 |
+
group = parser.add_argument_group('Cuda2dModel', 'cuda2d model configurations')
|
111 |
+
group.add_argument("--kernel-size", type=int, default=5)
|
112 |
+
group.add_argument("--kernel-size2", type=int, default=5)
|
113 |
+
group.add_argument("--layout", type=str, default='16,3616')
|
114 |
+
group.add_argument("--new-sequence-length", type=int, default=4096)
|
115 |
+
return parser
|
116 |
+
|
117 |
+
def sparse_attention_2d_text(q0, k0, v0, q1, k1, v1, attention_mask, n_head, text_len, kernel_size=9, attention_dropout=None, log_attention_weights = None, **kwargs):
|
118 |
+
'''
|
119 |
+
q0, k0, v0: [batch_size, 16, hidden_size]
|
120 |
+
q1, k1, v1: [batch_size, 3600, hidden_size]
|
121 |
+
n_head: int
|
122 |
+
attention_mask: [batch_size, 16]
|
123 |
+
'''
|
124 |
+
from SwissArmyTransformer.ops.local_attention_function import f_similar, f_weighting
|
125 |
+
b, s0, h0 = q0.shape
|
126 |
+
b, s1, h1 = q1.shape
|
127 |
+
h, l1 = h0 // n_head, sqrt(s1)
|
128 |
+
assert attention_mask.shape[-1] == s0, f"Mask Shape: {attention_mask.shape}"
|
129 |
+
|
130 |
+
q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
|
131 |
+
v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
|
132 |
+
k0T = k0.reshape(b, s0, n_head, h).permute(0, 2, 3, 1)
|
133 |
+
|
134 |
+
# standard attention for level 0
|
135 |
+
attention_scores = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T)
|
136 |
+
|
137 |
+
attention_scores = torch.mul(attention_scores, attention_mask) - \
|
138 |
+
10000.0 * (1.0 - attention_mask)
|
139 |
+
|
140 |
+
attention_probs0 = F.softmax(attention_scores, dim=-1)
|
141 |
+
|
142 |
+
# local attention for level 1
|
143 |
+
q1 = (q1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1) / math.sqrt(h1//n_head)).contiguous().view(b*n_head, h1//n_head, l1, l1)
|
144 |
+
k1 = k1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1)
|
145 |
+
v1 = v1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1)
|
146 |
+
scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, False)
|
147 |
+
|
148 |
+
# cross attention
|
149 |
+
scores_1_to_0 = torch.matmul(q1.view(b, n_head, h, s1).transpose(-1, -2), k0T)
|
150 |
+
if log_attention_weights is not None:
|
151 |
+
scores_1_to_0 += log_attention_weights
|
152 |
+
scores_1_to_0 = torch.mul(scores_1_to_0, attention_mask) - \
|
153 |
+
10000.0 * (1.0 - attention_mask)
|
154 |
+
scores_1 = torch.cat(
|
155 |
+
(
|
156 |
+
scores_1_to_0.view(b*n_head, s1, s0),
|
157 |
+
scores_1_to_1.view(b*n_head, -1, scores_1_to_1.shape[3])
|
158 |
+
),
|
159 |
+
dim=-1)
|
160 |
+
attention_probs1 = F.softmax(scores_1, dim=-1)
|
161 |
+
|
162 |
+
if attention_dropout is not None:
|
163 |
+
with get_cuda_rng_tracker().fork():
|
164 |
+
attention_probs1 = attention_dropout(attention_probs1)
|
165 |
+
|
166 |
+
# weighting for level 0
|
167 |
+
context0 = torch.matmul(attention_probs0, v0) # [b, n_head, s0, h]
|
168 |
+
# weighting for level 1
|
169 |
+
probs_1_to_1 = attention_probs1[:, :, -scores_1_to_1.shape[3]:].view_as(scores_1_to_1)
|
170 |
+
context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size*2-1, kernel_size, False)
|
171 |
+
|
172 |
+
context1 = context1_to_1.view(b, n_head, h, l1**2)
|
173 |
+
# weighting for cross attention
|
174 |
+
probs_1_to_0 = attention_probs1[:, :, :scores_1_to_0.shape[3]].view(b, n_head, -1, scores_1_to_0.shape[3])
|
175 |
+
|
176 |
+
context1_to_0 = torch.matmul(probs_1_to_0, v0)
|
177 |
+
context1 = context1.transpose(-1, -2) + context1_to_0
|
178 |
+
|
179 |
+
output = torch.cat((context0, context1), dim=2).transpose(1, 2).reshape(b, s0+s1, h0)
|
180 |
+
|
181 |
+
return output
|
182 |
+
|
183 |
+
def sparse_attention_2d_notext(q0, k0, v0, q1, k1, v1, attention_mask, n_head, text_len, kernel_size=9, attention_dropout=None, log_attention_weights = None, **kwargs):
|
184 |
+
'''
|
185 |
+
q0, k0, v0: [batch_size, 16, hidden_size]
|
186 |
+
q1, k1, v1: [batch_size, 3600, hidden_size]
|
187 |
+
n_head: int
|
188 |
+
attention_mask: [batch_size, 16]
|
189 |
+
'''
|
190 |
+
from SwissArmyTransformer.mpu.local_attention_function import f_similar, f_weighting
|
191 |
+
b, s0, h0 = q0.shape
|
192 |
+
b, s1, h1 = q1.shape
|
193 |
+
h, l1 = h0 // n_head, sqrt(s1)
|
194 |
+
assert len(attention_mask.shape) == 4 and attention_mask.shape[-1] == s0, f"Mask Shape: {attention_mask.shape}"
|
195 |
+
|
196 |
+
q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
|
197 |
+
v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
|
198 |
+
k0T = k0.reshape(b, s0, n_head, h).permute(0, 2, 3, 1)
|
199 |
+
|
200 |
+
# standard attention for level 0
|
201 |
+
attention_scores = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T)
|
202 |
+
|
203 |
+
attention_scores = torch.mul(attention_scores, attention_mask) - \
|
204 |
+
10000.0 * (1.0 - attention_mask)
|
205 |
+
|
206 |
+
attention_probs0 = F.softmax(attention_scores, dim=-1)
|
207 |
+
|
208 |
+
# local attention for level 1
|
209 |
+
q1 = (q1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1) / math.sqrt(h1//n_head)).contiguous().view(b*n_head, h1//n_head, l1, l1)
|
210 |
+
k1 = k1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1)
|
211 |
+
v1 = v1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1)
|
212 |
+
scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, False)
|
213 |
+
|
214 |
+
attention_probs1 = F.softmax(scores_1_to_1, dim=-1)
|
215 |
+
|
216 |
+
if attention_dropout is not None:
|
217 |
+
with get_cuda_rng_tracker().fork():
|
218 |
+
attention_probs1 = attention_dropout(attention_probs1)
|
219 |
+
|
220 |
+
# weighting for level 0
|
221 |
+
context0 = torch.matmul(attention_probs0, v0) # [b, n_head, s0, h]
|
222 |
+
# weighting for level 1
|
223 |
+
probs_1_to_1 = attention_probs1
|
224 |
+
context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size*2-1, kernel_size, False)
|
225 |
+
|
226 |
+
context1 = context1_to_1.view(b, n_head, h, l1**2)
|
227 |
+
# weighting for cross attention
|
228 |
+
context1 = context1.transpose(-1, -2)
|
229 |
+
|
230 |
+
output = torch.cat((context0, context1), dim=2).transpose(1, 2).reshape(b, s0+s1, h0)
|
231 |
+
|
232 |
+
return output
|
src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/itersr_sampling.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- encoding: utf-8 -*-
|
2 |
+
'''
|
3 |
+
@File : itersr_sampling.py
|
4 |
+
@Time : 2022/03/03 14:24:28
|
5 |
+
@Author : Ming Ding
|
6 |
+
@Contact : [email protected]
|
7 |
+
'''
|
8 |
+
|
9 |
+
# here put the import lib
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
import math
|
13 |
+
import random
|
14 |
+
import numpy as np
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn.functional as F
|
18 |
+
from videogen_hub.depend.icetk import icetk as tokenizer
|
19 |
+
|
20 |
+
def top_k_logits_(logits, top_k=0, filter_value=-float('Inf')):
|
21 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
22 |
+
logits[indices_to_remove] = filter_value
|
23 |
+
return logits
|
24 |
+
|
25 |
+
# class IterativeEntfilterStrategy:
|
26 |
+
# def __init__(self, invalid_slices=[], temperature=1., topk=10):
|
27 |
+
# self.invalid_slices = invalid_slices
|
28 |
+
# self.temperature = temperature
|
29 |
+
# self.topk = topk
|
30 |
+
# self.cluster_labels = torch.tensor(np.load('cluster_label.npy'), device='cuda', dtype=torch.long)
|
31 |
+
|
32 |
+
|
33 |
+
# def forward(self, logits_, tokens, temperature=None, entfilter=None, filter_topk=5, temperature2=None):
|
34 |
+
# # In interative strategy, logits are of shape [batch_size, seq_length, hidden_size]
|
35 |
+
# if temperature is None:
|
36 |
+
# temperature = self.temperature
|
37 |
+
|
38 |
+
# logits = logits_.float() / temperature
|
39 |
+
# for invalid_slice in self.invalid_slices:
|
40 |
+
# logits[..., invalid_slice] = -float('Inf')
|
41 |
+
# logits = logits.view(-1, logits.shape[-1])
|
42 |
+
|
43 |
+
# rprobs = F.softmax(logits.float(), dim=-1)
|
44 |
+
# c = self.cluster_labels.expand(*rprobs.shape)
|
45 |
+
# cprobs = torch.zeros(logits.shape[0], 500, device=logits.device).scatter_add_(1, c, rprobs)
|
46 |
+
|
47 |
+
# best_scores, best_clusters = cprobs.topk(self.topk)
|
48 |
+
# bz = logits.shape[0]
|
49 |
+
# best_scores = best_scores / best_scores.sum(dim=-1, keepdim=True)
|
50 |
+
# sampled_ids = torch.multinomial(best_scores, num_samples=1)
|
51 |
+
# selected_clusters = torch.gather(best_clusters, dim=1, index=sampled_ids)
|
52 |
+
# selected_mask = (self.cluster_labels.unsqueeze(0).expand(bz, -1) != selected_clusters) # cluster_labels [1, 20000] \in [0,500)
|
53 |
+
# logits[selected_mask] = -65504
|
54 |
+
# # for i in range(bz):
|
55 |
+
# # selected_cluster = best_clusters[i][torch.multinomial(best_scores[i] / best_scores[i].sum(), num_samples=1)]
|
56 |
+
# # logits[i, self.cluster_labels != selected_cluster] = -65504
|
57 |
+
|
58 |
+
# # logits = top_k_logits(logits, self.topk, self.top_p)
|
59 |
+
# probs = F.softmax(logits.float(), dim=-1) # float is essetial, due to a bug in Pytorch
|
60 |
+
# pred = torch.multinomial(probs, num_samples=1).view(*logits_.shape[:2])
|
61 |
+
|
62 |
+
# assert tokens.shape[1] == pred.shape[1]
|
63 |
+
# tokens = pred
|
64 |
+
# return tokens
|
65 |
+
|
66 |
+
class IterativeEntfilterStrategy:
|
67 |
+
def __init__(self, invalid_slices=[], temperature=1., topk=10):
|
68 |
+
self.invalid_slices = invalid_slices
|
69 |
+
self.temperature = temperature
|
70 |
+
self.topk = topk
|
71 |
+
|
72 |
+
def forward(self, logits, tokens, temperature=None, entfilter=None, filter_topk=5, temperature2=None):
|
73 |
+
# In interative strategy, logits are of shape [batch_size, seq_length, hidden_size]
|
74 |
+
if temperature is None:
|
75 |
+
temperature = self.temperature
|
76 |
+
# check entropy filter
|
77 |
+
# if entfilter is not None:
|
78 |
+
# assert temperature2 is not None
|
79 |
+
# topraw = (torch.topk(logits, filter_topk, dim=-1)[0]).softmax(dim=-1)
|
80 |
+
# ent = -(topraw * topraw.log()).sum(dim=-1) # [batch_size, seq_length]
|
81 |
+
# temperature = torch.tensor([[[temperature - temperature2]]], device=logits.device).expand(*logits.shape[:2], 1) * (ent > entfilter).unsqueeze(-1) + temperature2
|
82 |
+
|
83 |
+
logits = logits.float() / temperature
|
84 |
+
for invalid_slice in self.invalid_slices:
|
85 |
+
logits[..., invalid_slice] = -float('Inf')
|
86 |
+
|
87 |
+
# debiased topk
|
88 |
+
# probs = F.softmax(logits, dim=-1)
|
89 |
+
# tk_value, tk_idx = torch.topk(probs, self.topk, dim=-1)
|
90 |
+
# pred = torch.multinomial(probs.view(-1, logits.shape[-1]), num_samples=1).view(*logits.shape[:2], 1)
|
91 |
+
# edge_idx = tk_idx[:, :, -1:]
|
92 |
+
# edge_value = tk_value[:, :, -1:]
|
93 |
+
# edge_mask = probs.gather(dim=-1, index=pred) < edge_value
|
94 |
+
# pred[edge_mask] = edge_idx[edge_mask] # replace outliers as the "filter_topk"-th token
|
95 |
+
# pred.squeeze_(-1) # [batch_size, seq_length]
|
96 |
+
|
97 |
+
top_k_logits_(logits, self.topk)
|
98 |
+
probs = F.softmax(logits, dim=-1)
|
99 |
+
pred = torch.multinomial(probs.view(-1, logits.shape[-1]), num_samples=1).view(*logits.shape[:2], 1)
|
100 |
+
pred.squeeze_(-1)
|
101 |
+
|
102 |
+
assert tokens.shape[1] == pred.shape[1]
|
103 |
+
tokens = pred
|
104 |
+
return tokens
|
105 |
+
|
106 |
+
def filling_sequence_itersr(
|
107 |
+
model,
|
108 |
+
seq0,
|
109 |
+
seq1,
|
110 |
+
warmup_steps=3,
|
111 |
+
block_hw=(4, 4),
|
112 |
+
strategy=IterativeEntfilterStrategy(topk=10),
|
113 |
+
):
|
114 |
+
'''
|
115 |
+
seq: [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1]
|
116 |
+
4095 {layout[2]} final_token.
|
117 |
+
Attention:
|
118 |
+
The sampling temperature are changing, temporally we hard code them here.
|
119 |
+
The temperature in the strategy is not used.
|
120 |
+
'''
|
121 |
+
assert hasattr(model, 'layout')
|
122 |
+
layout = model.layout
|
123 |
+
|
124 |
+
device = seq0.device
|
125 |
+
# concat and pad sequences
|
126 |
+
batch_size = seq0.shape[0]
|
127 |
+
n_pad = layout[0] - seq0.shape[1]
|
128 |
+
assert n_pad >= 0, "You should truncate long input before filling."
|
129 |
+
seq = torch.cat((
|
130 |
+
torch.tensor([0]*n_pad, device=device, dtype=seq0.dtype)
|
131 |
+
.unsqueeze(0).expand(batch_size, n_pad),
|
132 |
+
seq0, seq1), dim=1) # [b, layout[-1]+1]
|
133 |
+
assert seq.shape[1] == layout[-1]
|
134 |
+
|
135 |
+
# build initial tokens, attention_mask, and position_ids
|
136 |
+
tokens = seq.clone()
|
137 |
+
attention_mask = torch.ones(layout[0]).to(device)
|
138 |
+
attention_mask[:n_pad] = 0
|
139 |
+
attention_mask = attention_mask.unsqueeze(0).type_as(next(model.parameters())) # if fp16
|
140 |
+
position_ids = torch.cat((
|
141 |
+
torch.zeros(n_pad, dtype=torch.long),
|
142 |
+
torch.arange(0, layout[0] - n_pad),
|
143 |
+
torch.arange(1024, 1024+layout[1]-layout[0]))).to(device)
|
144 |
+
log_attention_weights = torch.zeros(layout[0], device=device).type_as(next(model.parameters()))
|
145 |
+
log_attention_weights[n_pad:layout[0]] = 0.
|
146 |
+
log_attention_weights = log_attention_weights.unsqueeze(0)
|
147 |
+
|
148 |
+
# prepare for interation
|
149 |
+
unfixed = (tokens == tokenizer['<start_of_image>'])
|
150 |
+
ll, rr = block_hw
|
151 |
+
edge_len = int(math.sqrt(layout[-1] - layout[-2]) + 1e-4)
|
152 |
+
num_steps = 1
|
153 |
+
# interative refining
|
154 |
+
|
155 |
+
# unfixed[..., -(layout[-1] - layout[-2]):].view(
|
156 |
+
# batch_size, edge_len//ll, ll, edge_len//rr, rr)[:, :, :, :, -1] = False
|
157 |
+
|
158 |
+
|
159 |
+
ret = []
|
160 |
+
# ret.append(tokens[:, layout[-2]:-1].clone())
|
161 |
+
for step_cnt in range(1, num_steps+1):
|
162 |
+
logits, *_dump = model(tokens, position_ids, attention_mask, log_attention_weights=log_attention_weights)
|
163 |
+
real_temp = 1.
|
164 |
+
new_tokens = strategy.forward(logits, tokens, real_temp)
|
165 |
+
tokens[unfixed] = new_tokens[unfixed]
|
166 |
+
|
167 |
+
ret.append(tokens[:, layout[-2]:].clone())
|
168 |
+
return torch.cat(ret, dim=0)
|
src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/sr_group.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- encoding: utf-8 -*-
|
2 |
+
'''
|
3 |
+
@File : sr_group.py
|
4 |
+
@Time : 2022/04/02 01:17:21
|
5 |
+
@Author : Ming Ding
|
6 |
+
@Contact : [email protected]
|
7 |
+
'''
|
8 |
+
|
9 |
+
# here put the import lib
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
import math
|
13 |
+
import random
|
14 |
+
|
15 |
+
import numpy as np
|
16 |
+
import torch
|
17 |
+
import torch.nn.functional as F
|
18 |
+
from SwissArmyTransformer.resources import auto_create
|
19 |
+
from .direct_sr import DirectSuperResolution
|
20 |
+
from .iterative_sr import IterativeSuperResolution
|
21 |
+
|
22 |
+
class SRGroup:
|
23 |
+
def __init__(self, args, home_path=None,):
|
24 |
+
dsr_path = auto_create('cogview2-dsr', path=home_path)
|
25 |
+
itersr_path = auto_create('cogview2-itersr', path=home_path)
|
26 |
+
dsr = DirectSuperResolution(args, dsr_path)
|
27 |
+
itersr = IterativeSuperResolution(args, itersr_path, shared_transformer=dsr.model.transformer)
|
28 |
+
self.dsr = dsr
|
29 |
+
self.itersr = itersr
|
30 |
+
|
31 |
+
def sr_base(self, img_tokens, txt_tokens):
|
32 |
+
assert img_tokens.shape[-1] == 400 and len(img_tokens.shape) == 2
|
33 |
+
batch_size = img_tokens.shape[0]
|
34 |
+
txt_len = txt_tokens.shape[-1]
|
35 |
+
if len(txt_tokens.shape) == 1:
|
36 |
+
txt_tokens = txt_tokens.unsqueeze(0).expand(batch_size, txt_len)
|
37 |
+
sred_tokens = self.dsr(txt_tokens, img_tokens)
|
38 |
+
iter_tokens = self.itersr(txt_tokens, sred_tokens[:, -3600:].clone())
|
39 |
+
return iter_tokens[-batch_size:]
|
40 |
+
|
41 |
+
# def sr_patch(self, img_tokens, txt_tokens):
|
42 |
+
# assert img_tokens.shape[-1] == 3600 and len(img_tokens.shape) == 2
|
43 |
+
# batch_size = img_tokens.shape[0] * 9
|
44 |
+
# txt_len = txt_tokens.shape[-1]
|
45 |
+
# if len(txt_tokens.shape) == 1:
|
46 |
+
# txt_tokens = txt_tokens.unsqueeze(0).expand(batch_size, txt_len)
|
47 |
+
# img_tokens = img_tokens.view(img_tokens.shape[0], 3, 20, 3, 20).permute(0, 1, 3, 2, 4).reshape(batch_size, 400)
|
48 |
+
# iter_tokens = self.sr_base(img_tokens, txt_tokens)
|
49 |
+
# return iter_tokens
|
src/videogen_hub/pipelines/consisti2v/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2024 TIGER Lab
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
src/videogen_hub/pipelines/consisti2v/__init__.py
ADDED
File without changes
|
src/videogen_hub/pipelines/consisti2v/configs/__init__.py
ADDED
File without changes
|
src/videogen_hub/pipelines/consisti2v/configs/inference/__init__.py
ADDED
File without changes
|
src/videogen_hub/pipelines/consisti2v/configs/inference/inference.yaml
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
output_dir: "samples/inference"
|
2 |
+
output_name: "i2v"
|
3 |
+
|
4 |
+
pretrained_model_path: "TIGER-Lab/ConsistI2V"
|
5 |
+
unet_path: null
|
6 |
+
unet_ckpt_prefix: "module."
|
7 |
+
pipeline_pretrained_path: null
|
8 |
+
|
9 |
+
sampling_kwargs:
|
10 |
+
height: 256
|
11 |
+
width: 256
|
12 |
+
n_frames: 16
|
13 |
+
steps: 50
|
14 |
+
ddim_eta: 0.0
|
15 |
+
guidance_scale_txt: 7.5
|
16 |
+
guidance_scale_img: 1.0
|
17 |
+
guidance_rescale: 0.0
|
18 |
+
num_videos_per_prompt: 1
|
19 |
+
frame_stride: 3
|
20 |
+
|
21 |
+
unet_additional_kwargs:
|
22 |
+
variant: null
|
23 |
+
n_temp_heads: 8
|
24 |
+
augment_temporal_attention: true
|
25 |
+
temp_pos_embedding: "rotary" # "rotary" or "sinusoidal"
|
26 |
+
first_frame_condition_mode: "concat"
|
27 |
+
use_frame_stride_condition: true
|
28 |
+
noise_sampling_method: "pyoco_mixed" # "vanilla" or "pyoco_mixed" or "pyoco_progressive"
|
29 |
+
noise_alpha: 1.0
|
30 |
+
|
31 |
+
noise_scheduler_kwargs:
|
32 |
+
beta_start: 0.00085
|
33 |
+
beta_end: 0.012
|
34 |
+
beta_schedule: "linear"
|
35 |
+
steps_offset: 1
|
36 |
+
clip_sample: false
|
37 |
+
rescale_betas_zero_snr: false # true if using zero terminal snr
|
38 |
+
timestep_spacing: "leading" # "trailing" if using zero terminal snr
|
39 |
+
prediction_type: "epsilon" # "v_prediction" if using zero terminal snr
|
40 |
+
|
41 |
+
frameinit_kwargs:
|
42 |
+
enable: true
|
43 |
+
camera_motion: null
|
44 |
+
noise_level: 850
|
45 |
+
filter_params:
|
46 |
+
method: 'gaussian'
|
47 |
+
d_s: 0.25
|
48 |
+
d_t: 0.25
|
src/videogen_hub/pipelines/consisti2v/configs/inference/inference_autoregress.yaml
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
output_dir: "samples/inference"
|
2 |
+
output_name: "long_video"
|
3 |
+
|
4 |
+
pretrained_model_path: "TIGER-Lab/ConsistI2V"
|
5 |
+
unet_path: null
|
6 |
+
unet_ckpt_prefix: "module."
|
7 |
+
pipeline_pretrained_path: null
|
8 |
+
|
9 |
+
sampling_kwargs:
|
10 |
+
height: 256
|
11 |
+
width: 256
|
12 |
+
n_frames: 16
|
13 |
+
steps: 50
|
14 |
+
ddim_eta: 0.0
|
15 |
+
guidance_scale_txt: 7.5
|
16 |
+
guidance_scale_img: 1.0
|
17 |
+
guidance_rescale: 0.0
|
18 |
+
num_videos_per_prompt: 1
|
19 |
+
frame_stride: 3
|
20 |
+
autoregress_steps: 3
|
21 |
+
|
22 |
+
unet_additional_kwargs:
|
23 |
+
variant: null
|
24 |
+
n_temp_heads: 8
|
25 |
+
augment_temporal_attention: true
|
26 |
+
temp_pos_embedding: "rotary" # "rotary" or "sinusoidal"
|
27 |
+
first_frame_condition_mode: "concat"
|
28 |
+
use_frame_stride_condition: true
|
29 |
+
noise_sampling_method: "pyoco_mixed" # "vanilla" or "pyoco_mixed" or "pyoco_progressive"
|
30 |
+
noise_alpha: 1.0
|
31 |
+
|
32 |
+
noise_scheduler_kwargs:
|
33 |
+
beta_start: 0.00085
|
34 |
+
beta_end: 0.012
|
35 |
+
beta_schedule: "linear"
|
36 |
+
steps_offset: 1
|
37 |
+
clip_sample: false
|
38 |
+
rescale_betas_zero_snr: false # true if using zero terminal snr
|
39 |
+
timestep_spacing: "leading" # "trailing" if using zero terminal snr
|
40 |
+
prediction_type: "epsilon" # "v_prediction" if using zero terminal snr
|
41 |
+
|
42 |
+
|
43 |
+
frameinit_kwargs:
|
44 |
+
enable: true
|
45 |
+
noise_level: 850
|
46 |
+
filter_params:
|
47 |
+
method: 'gaussian'
|
48 |
+
d_s: 0.25
|
49 |
+
d_t: 0.25
|
src/videogen_hub/pipelines/consisti2v/configs/prompts/__init__.py
ADDED
File without changes
|
src/videogen_hub/pipelines/consisti2v/configs/prompts/default.yaml
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
seeds: random
|
2 |
+
|
3 |
+
prompts:
|
4 |
+
- "timelapse at the snow land with aurora in the sky."
|
5 |
+
- "fireworks."
|
6 |
+
- "clown fish swimming through the coral reef."
|
7 |
+
- "melting ice cream dripping down the cone."
|
8 |
+
|
9 |
+
n_prompts:
|
10 |
+
- ""
|
11 |
+
|
12 |
+
path_to_first_frames:
|
13 |
+
- "assets/example/example_01.png"
|
14 |
+
- "assets/example/example_02.png"
|
15 |
+
- "assets/example/example_03.png"
|
16 |
+
- "assets/example/example_04.png"
|
src/videogen_hub/pipelines/consisti2v/configs/training/__init__.py
ADDED
File without changes
|
src/videogen_hub/pipelines/consisti2v/configs/training/training.yaml
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
output_dir: "checkpoints"
|
2 |
+
pretrained_model_path: "stabilityai/stable-diffusion-2-1-base"
|
3 |
+
|
4 |
+
noise_scheduler_kwargs:
|
5 |
+
num_train_timesteps: 1000
|
6 |
+
beta_start: 0.00085
|
7 |
+
beta_end: 0.012
|
8 |
+
beta_schedule: "linear"
|
9 |
+
steps_offset: 1
|
10 |
+
clip_sample: false
|
11 |
+
rescale_betas_zero_snr: false # true if using zero terminal snr
|
12 |
+
timestep_spacing: "leading" # "trailing" if using zero terminal snr
|
13 |
+
prediction_type: "epsilon" # "v_prediction" if using zero terminal snr
|
14 |
+
|
15 |
+
train_data:
|
16 |
+
dataset: "joint"
|
17 |
+
pexels_config:
|
18 |
+
enable: false
|
19 |
+
json_path: null
|
20 |
+
caption_json_path: null
|
21 |
+
video_folder: null
|
22 |
+
webvid_config:
|
23 |
+
enable: true
|
24 |
+
json_path: "/path/to/webvid/annotation"
|
25 |
+
video_folder: "/path/to/webvid/data"
|
26 |
+
sample_size: 256
|
27 |
+
sample_duration: null
|
28 |
+
sample_fps: null
|
29 |
+
sample_stride: [1, 5]
|
30 |
+
sample_n_frames: 16
|
31 |
+
|
32 |
+
validation_data:
|
33 |
+
prompts:
|
34 |
+
- "timelapse at the snow land with aurora in the sky."
|
35 |
+
- "fireworks."
|
36 |
+
- "clown fish swimming through the coral reef."
|
37 |
+
- "melting ice cream dripping down the cone."
|
38 |
+
|
39 |
+
path_to_first_frames:
|
40 |
+
- "assets/example/example_01.jpg"
|
41 |
+
- "assets/example/example_02.jpg"
|
42 |
+
- "assets/example/example_03.jpg"
|
43 |
+
- "assets/example/example_04.jpg"
|
44 |
+
|
45 |
+
num_inference_steps: 50
|
46 |
+
ddim_eta: 0.0
|
47 |
+
guidance_scale_txt: 7.5
|
48 |
+
guidance_scale_img: 1.0
|
49 |
+
guidance_rescale: 0.0
|
50 |
+
frame_stride: 3
|
51 |
+
|
52 |
+
trainable_modules:
|
53 |
+
- "all"
|
54 |
+
# - "conv3ds."
|
55 |
+
# - "tempo_attns."
|
56 |
+
|
57 |
+
resume_from_checkpoint: null
|
58 |
+
|
59 |
+
unet_additional_kwargs:
|
60 |
+
variant: null
|
61 |
+
n_temp_heads: 8
|
62 |
+
augment_temporal_attention: true
|
63 |
+
temp_pos_embedding: "rotary" # "rotary" or "sinusoidal"
|
64 |
+
first_frame_condition_mode: "concat"
|
65 |
+
use_frame_stride_condition: true
|
66 |
+
noise_sampling_method: "pyoco_mixed" # "vanilla" or "pyoco_mixed" or "pyoco_progressive"
|
67 |
+
noise_alpha: 1.0
|
68 |
+
|
69 |
+
cfg_random_null_text_ratio: 0.1
|
70 |
+
cfg_random_null_img_ratio: 0.1
|
71 |
+
|
72 |
+
use_ema: false
|
73 |
+
ema_decay: 0.9999
|
74 |
+
|
75 |
+
learning_rate: 5.e-5
|
76 |
+
train_batch_size: 3
|
77 |
+
gradient_accumulation_steps: 1
|
78 |
+
max_grad_norm: 0.5
|
79 |
+
|
80 |
+
max_train_epoch: -1
|
81 |
+
max_train_steps: 200000
|
82 |
+
checkpointing_epochs: -1
|
83 |
+
checkpointing_steps: 2000
|
84 |
+
validation_steps: 1000
|
85 |
+
|
86 |
+
seed: 42
|
87 |
+
mixed_precision: "bf16"
|
88 |
+
num_workers: 32
|
89 |
+
enable_xformers_memory_efficient_attention: true
|
90 |
+
|
91 |
+
is_image: false
|
92 |
+
is_debug: false
|
src/videogen_hub/pipelines/consisti2v/consisti2v/__init__.py
ADDED
File without changes
|
src/videogen_hub/pipelines/consisti2v/consisti2v/data/__init__.py
ADDED
File without changes
|
src/videogen_hub/pipelines/consisti2v/consisti2v/data/dataset.py
ADDED
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, io, csv, math, random
|
2 |
+
import json
|
3 |
+
import numpy as np
|
4 |
+
from einops import rearrange
|
5 |
+
from decord import VideoReader
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torchvision.transforms as transforms
|
9 |
+
from torch.utils.data.dataset import Dataset
|
10 |
+
|
11 |
+
from diffusers.utils import logging
|
12 |
+
|
13 |
+
logger = logging.get_logger(__name__)
|
14 |
+
|
15 |
+
class WebVid10M(Dataset):
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
json_path, video_folder=None,
|
19 |
+
sample_size=256, sample_stride=4, sample_n_frames=16,
|
20 |
+
is_image=False,
|
21 |
+
**kwargs,
|
22 |
+
):
|
23 |
+
logger.info(f"loading annotations from {json_path} ...")
|
24 |
+
with open(json_path, 'rb') as json_file:
|
25 |
+
json_list = list(json_file)
|
26 |
+
self.dataset = [json.loads(json_str) for json_str in json_list]
|
27 |
+
self.length = len(self.dataset)
|
28 |
+
logger.info(f"data scale: {self.length}")
|
29 |
+
|
30 |
+
self.video_folder = video_folder
|
31 |
+
self.sample_stride = sample_stride if isinstance(sample_stride, int) else tuple(sample_stride)
|
32 |
+
self.sample_n_frames = sample_n_frames
|
33 |
+
self.is_image = is_image
|
34 |
+
|
35 |
+
sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
|
36 |
+
self.pixel_transforms = transforms.Compose([
|
37 |
+
transforms.RandomHorizontalFlip(),
|
38 |
+
transforms.Resize(sample_size[0], antialias=None),
|
39 |
+
transforms.CenterCrop(sample_size),
|
40 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
41 |
+
])
|
42 |
+
|
43 |
+
def get_batch(self, idx):
|
44 |
+
video_dict = self.dataset[idx]
|
45 |
+
video_relative_path, name = video_dict['file'], video_dict['text']
|
46 |
+
|
47 |
+
if self.video_folder is not None:
|
48 |
+
if video_relative_path[0] == '/':
|
49 |
+
video_dir = os.path.join(self.video_folder, os.path.basename(video_relative_path))
|
50 |
+
else:
|
51 |
+
video_dir = os.path.join(self.video_folder, video_relative_path)
|
52 |
+
else:
|
53 |
+
video_dir = video_relative_path
|
54 |
+
video_reader = VideoReader(video_dir)
|
55 |
+
video_length = len(video_reader)
|
56 |
+
|
57 |
+
if not self.is_image:
|
58 |
+
if isinstance(self.sample_stride, int):
|
59 |
+
stride = self.sample_stride
|
60 |
+
elif isinstance(self.sample_stride, tuple):
|
61 |
+
stride = random.randint(self.sample_stride[0], self.sample_stride[1])
|
62 |
+
clip_length = min(video_length, (self.sample_n_frames - 1) * stride + 1)
|
63 |
+
start_idx = random.randint(0, video_length - clip_length)
|
64 |
+
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
|
65 |
+
else:
|
66 |
+
frame_difference = random.randint(2, self.sample_n_frames)
|
67 |
+
clip_length = min(video_length, (frame_difference - 1) * self.sample_stride + 1)
|
68 |
+
start_idx = random.randint(0, video_length - clip_length)
|
69 |
+
batch_index = [start_idx, start_idx + clip_length - 1]
|
70 |
+
|
71 |
+
pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
|
72 |
+
pixel_values = pixel_values / 255.
|
73 |
+
del video_reader
|
74 |
+
|
75 |
+
return pixel_values, name
|
76 |
+
|
77 |
+
def __len__(self):
|
78 |
+
return self.length
|
79 |
+
|
80 |
+
def __getitem__(self, idx):
|
81 |
+
while True:
|
82 |
+
try:
|
83 |
+
pixel_values, name = self.get_batch(idx)
|
84 |
+
break
|
85 |
+
|
86 |
+
except Exception as e:
|
87 |
+
idx = random.randint(0, self.length-1)
|
88 |
+
|
89 |
+
pixel_values = self.pixel_transforms(pixel_values)
|
90 |
+
sample = dict(pixel_values=pixel_values, text=name)
|
91 |
+
return sample
|
92 |
+
|
93 |
+
|
94 |
+
class Pexels(Dataset):
|
95 |
+
def __init__(
|
96 |
+
self,
|
97 |
+
json_path, caption_json_path, video_folder=None,
|
98 |
+
sample_size=256, sample_duration=1, sample_fps=8,
|
99 |
+
is_image=False,
|
100 |
+
**kwargs,
|
101 |
+
):
|
102 |
+
logger.info(f"loading captions from {caption_json_path} ...")
|
103 |
+
with open(caption_json_path, 'rb') as caption_json_file:
|
104 |
+
caption_json_list = list(caption_json_file)
|
105 |
+
self.caption_dict = {json.loads(json_str)['id']: json.loads(json_str)['text'] for json_str in caption_json_list}
|
106 |
+
|
107 |
+
logger.info(f"loading annotations from {json_path} ...")
|
108 |
+
with open(json_path, 'rb') as json_file:
|
109 |
+
json_list = list(json_file)
|
110 |
+
dataset = [json.loads(json_str) for json_str in json_list]
|
111 |
+
|
112 |
+
self.dataset = []
|
113 |
+
for data in dataset:
|
114 |
+
data['text'] = self.caption_dict[data['id']]
|
115 |
+
if data['height'] / data['width'] < 0.625:
|
116 |
+
self.dataset.append(data)
|
117 |
+
self.length = len(self.dataset)
|
118 |
+
logger.info(f"data scale: {self.length}")
|
119 |
+
|
120 |
+
self.video_folder = video_folder
|
121 |
+
self.sample_duration = sample_duration
|
122 |
+
self.sample_fps = sample_fps
|
123 |
+
self.sample_n_frames = sample_duration * sample_fps
|
124 |
+
self.is_image = is_image
|
125 |
+
|
126 |
+
sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
|
127 |
+
self.pixel_transforms = transforms.Compose([
|
128 |
+
transforms.RandomHorizontalFlip(),
|
129 |
+
transforms.Resize(sample_size[0], antialias=None),
|
130 |
+
transforms.CenterCrop(sample_size),
|
131 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
132 |
+
])
|
133 |
+
|
134 |
+
def get_batch(self, idx):
|
135 |
+
video_dict = self.dataset[idx]
|
136 |
+
video_relative_path, name = video_dict['file'], video_dict['text']
|
137 |
+
fps = video_dict['fps']
|
138 |
+
|
139 |
+
if self.video_folder is not None:
|
140 |
+
if video_relative_path[0] == '/':
|
141 |
+
video_dir = os.path.join(self.video_folder, os.path.basename(video_relative_path))
|
142 |
+
else:
|
143 |
+
video_dir = os.path.join(self.video_folder, video_relative_path)
|
144 |
+
else:
|
145 |
+
video_dir = video_relative_path
|
146 |
+
video_reader = VideoReader(video_dir)
|
147 |
+
video_length = len(video_reader)
|
148 |
+
|
149 |
+
if not self.is_image:
|
150 |
+
clip_length = min(video_length, math.ceil(fps * self.sample_duration))
|
151 |
+
start_idx = random.randint(0, video_length - clip_length)
|
152 |
+
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
|
153 |
+
else:
|
154 |
+
frame_difference = random.randint(2, self.sample_n_frames)
|
155 |
+
sample_stride = math.ceil((fps * self.sample_duration) / (self.sample_n_frames - 1) - 1)
|
156 |
+
clip_length = min(video_length, (frame_difference - 1) * sample_stride + 1)
|
157 |
+
start_idx = random.randint(0, video_length - clip_length)
|
158 |
+
batch_index = [start_idx, start_idx + clip_length - 1]
|
159 |
+
|
160 |
+
pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
|
161 |
+
pixel_values = pixel_values / 255.
|
162 |
+
del video_reader
|
163 |
+
|
164 |
+
return pixel_values, name
|
165 |
+
|
166 |
+
def __len__(self):
|
167 |
+
return self.length
|
168 |
+
|
169 |
+
def __getitem__(self, idx):
|
170 |
+
while True:
|
171 |
+
try:
|
172 |
+
pixel_values, name = self.get_batch(idx)
|
173 |
+
break
|
174 |
+
|
175 |
+
except Exception as e:
|
176 |
+
idx = random.randint(0, self.length-1)
|
177 |
+
|
178 |
+
pixel_values = self.pixel_transforms(pixel_values)
|
179 |
+
sample = dict(pixel_values=pixel_values, text=name)
|
180 |
+
return sample
|
181 |
+
|
182 |
+
|
183 |
+
class JointDataset(Dataset):
|
184 |
+
def __init__(
|
185 |
+
self,
|
186 |
+
webvid_config, pexels_config,
|
187 |
+
sample_size=256,
|
188 |
+
sample_duration=None, sample_fps=None, sample_stride=None, sample_n_frames=None,
|
189 |
+
is_image=False,
|
190 |
+
**kwargs,
|
191 |
+
):
|
192 |
+
assert (sample_duration is None and sample_fps is None) or (sample_duration is not None and sample_fps is not None), "sample_duration and sample_fps should be both None or not None"
|
193 |
+
if sample_duration is not None and sample_fps is not None:
|
194 |
+
assert sample_stride is None, "when sample_duration and sample_fps are not None, sample_stride should be None"
|
195 |
+
if sample_stride is not None:
|
196 |
+
assert sample_fps is None and sample_duration is None, "when sample_stride is not None, sample_duration and sample_fps should be both None"
|
197 |
+
|
198 |
+
self.dataset = []
|
199 |
+
|
200 |
+
if pexels_config.enable:
|
201 |
+
logger.info(f"loading pexels dataset")
|
202 |
+
logger.info(f"loading captions from {pexels_config.caption_json_path} ...")
|
203 |
+
with open(pexels_config.caption_json_path, 'rb') as caption_json_file:
|
204 |
+
caption_json_list = list(caption_json_file)
|
205 |
+
self.caption_dict = {json.loads(json_str)['id']: json.loads(json_str)['text'] for json_str in caption_json_list}
|
206 |
+
|
207 |
+
logger.info(f"loading annotations from {pexels_config.json_path} ...")
|
208 |
+
with open(pexels_config.json_path, 'rb') as json_file:
|
209 |
+
json_list = list(json_file)
|
210 |
+
dataset = [json.loads(json_str) for json_str in json_list]
|
211 |
+
|
212 |
+
for data in dataset:
|
213 |
+
data['text'] = self.caption_dict[data['id']]
|
214 |
+
data['dataset'] = 'pexels'
|
215 |
+
if data['height'] / data['width'] < 0.625:
|
216 |
+
self.dataset.append(data)
|
217 |
+
|
218 |
+
if webvid_config.enable:
|
219 |
+
logger.info(f"loading webvid dataset")
|
220 |
+
logger.info(f"loading annotations from {webvid_config.json_path} ...")
|
221 |
+
with open(webvid_config.json_path, 'rb') as json_file:
|
222 |
+
json_list = list(json_file)
|
223 |
+
dataset = [json.loads(json_str) for json_str in json_list]
|
224 |
+
for data in dataset:
|
225 |
+
data['dataset'] = 'webvid'
|
226 |
+
self.dataset.extend(dataset)
|
227 |
+
|
228 |
+
self.length = len(self.dataset)
|
229 |
+
logger.info(f"data scale: {self.length}")
|
230 |
+
|
231 |
+
self.pexels_folder = pexels_config.video_folder
|
232 |
+
self.webvid_folder = webvid_config.video_folder
|
233 |
+
self.sample_duration = sample_duration
|
234 |
+
self.sample_fps = sample_fps
|
235 |
+
self.sample_n_frames = sample_duration * sample_fps if sample_n_frames is None else sample_n_frames
|
236 |
+
self.sample_stride = sample_stride if (sample_stride is None) or (sample_stride is not None and isinstance(sample_stride, int)) else tuple(sample_stride)
|
237 |
+
self.is_image = is_image
|
238 |
+
|
239 |
+
sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
|
240 |
+
self.pixel_transforms = transforms.Compose([
|
241 |
+
transforms.RandomHorizontalFlip(),
|
242 |
+
transforms.Resize(sample_size[0], antialias=None),
|
243 |
+
transforms.CenterCrop(sample_size),
|
244 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
245 |
+
])
|
246 |
+
|
247 |
+
def get_batch(self, idx):
|
248 |
+
video_dict = self.dataset[idx]
|
249 |
+
video_relative_path, name = video_dict['file'], video_dict['text']
|
250 |
+
|
251 |
+
if video_dict['dataset'] == 'pexels':
|
252 |
+
video_folder = self.pexels_folder
|
253 |
+
elif video_dict['dataset'] == 'webvid':
|
254 |
+
video_folder = self.webvid_folder
|
255 |
+
else:
|
256 |
+
raise NotImplementedError
|
257 |
+
|
258 |
+
if video_folder is not None:
|
259 |
+
if video_relative_path[0] == '/':
|
260 |
+
video_dir = os.path.join(video_folder, os.path.basename(video_relative_path))
|
261 |
+
else:
|
262 |
+
video_dir = os.path.join(video_folder, video_relative_path)
|
263 |
+
else:
|
264 |
+
video_dir = video_relative_path
|
265 |
+
video_reader = VideoReader(video_dir)
|
266 |
+
video_length = len(video_reader)
|
267 |
+
|
268 |
+
stride = None
|
269 |
+
if not self.is_image:
|
270 |
+
if self.sample_duration is not None:
|
271 |
+
fps = video_dict['fps']
|
272 |
+
clip_length = min(video_length, math.ceil(fps * self.sample_duration))
|
273 |
+
elif self.sample_stride is not None:
|
274 |
+
if isinstance(self.sample_stride, int):
|
275 |
+
stride = self.sample_stride
|
276 |
+
elif isinstance(self.sample_stride, tuple):
|
277 |
+
stride = random.randint(self.sample_stride[0], self.sample_stride[1])
|
278 |
+
clip_length = min(video_length, (self.sample_n_frames - 1) * stride + 1)
|
279 |
+
|
280 |
+
start_idx = random.randint(0, video_length - clip_length)
|
281 |
+
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
|
282 |
+
|
283 |
+
else:
|
284 |
+
frame_difference = random.randint(2, self.sample_n_frames)
|
285 |
+
if self.sample_duration is not None:
|
286 |
+
fps = video_dict['fps']
|
287 |
+
sample_stride = math.ceil((fps * self.sample_duration) / (self.sample_n_frames - 1) - 1)
|
288 |
+
elif self.sample_stride is not None:
|
289 |
+
sample_stride = self.sample_stride
|
290 |
+
|
291 |
+
clip_length = min(video_length, (frame_difference - 1) * sample_stride + 1)
|
292 |
+
start_idx = random.randint(0, video_length - clip_length)
|
293 |
+
batch_index = [start_idx, start_idx + clip_length - 1]
|
294 |
+
|
295 |
+
pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
|
296 |
+
pixel_values = pixel_values / 255.
|
297 |
+
del video_reader
|
298 |
+
|
299 |
+
return pixel_values, name, stride
|
300 |
+
|
301 |
+
def __len__(self):
|
302 |
+
return self.length
|
303 |
+
|
304 |
+
def __getitem__(self, idx):
|
305 |
+
while True:
|
306 |
+
try:
|
307 |
+
pixel_values, name, stride = self.get_batch(idx)
|
308 |
+
break
|
309 |
+
|
310 |
+
except Exception as e:
|
311 |
+
idx = random.randint(0, self.length-1)
|
312 |
+
|
313 |
+
pixel_values = self.pixel_transforms(pixel_values)
|
314 |
+
sample = dict(pixel_values=pixel_values, text=name, stride=stride)
|
315 |
+
return sample
|
src/videogen_hub/pipelines/consisti2v/consisti2v/models/__init__.py
ADDED
File without changes
|
src/videogen_hub/pipelines/consisti2v/consisti2v/models/rotary_embedding.py
ADDED
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from math import pi, log
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch.nn import Module, ModuleList
|
5 |
+
from torch.cuda.amp import autocast
|
6 |
+
from torch import nn, einsum, broadcast_tensors, Tensor
|
7 |
+
|
8 |
+
from einops import rearrange, repeat
|
9 |
+
|
10 |
+
from beartype import beartype
|
11 |
+
from beartype.typing import Literal, Union, Optional
|
12 |
+
|
13 |
+
# helper functions
|
14 |
+
|
15 |
+
def exists(val):
|
16 |
+
return val is not None
|
17 |
+
|
18 |
+
def default(val, d):
|
19 |
+
return val if exists(val) else d
|
20 |
+
|
21 |
+
# broadcat, as tortoise-tts was using it
|
22 |
+
|
23 |
+
def broadcat(tensors, dim = -1):
|
24 |
+
broadcasted_tensors = broadcast_tensors(*tensors)
|
25 |
+
return torch.cat(broadcasted_tensors, dim = dim)
|
26 |
+
|
27 |
+
# rotary embedding helper functions
|
28 |
+
|
29 |
+
def rotate_half(x):
|
30 |
+
x = rearrange(x, '... (d r) -> ... d r', r = 2)
|
31 |
+
x1, x2 = x.unbind(dim = -1)
|
32 |
+
x = torch.stack((-x2, x1), dim = -1)
|
33 |
+
return rearrange(x, '... d r -> ... (d r)')
|
34 |
+
|
35 |
+
@autocast(enabled = False)
|
36 |
+
def apply_rotary_emb(freqs, t, start_index = 0, scale = 1., seq_dim = -2):
|
37 |
+
if t.ndim == 3:
|
38 |
+
seq_len = t.shape[seq_dim]
|
39 |
+
freqs = freqs[-seq_len:].to(t)
|
40 |
+
|
41 |
+
rot_dim = freqs.shape[-1]
|
42 |
+
end_index = start_index + rot_dim
|
43 |
+
|
44 |
+
assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
|
45 |
+
|
46 |
+
t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
|
47 |
+
t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
|
48 |
+
return torch.cat((t_left, t, t_right), dim = -1)
|
49 |
+
|
50 |
+
# learned rotation helpers
|
51 |
+
|
52 |
+
def apply_learned_rotations(rotations, t, start_index = 0, freq_ranges = None):
|
53 |
+
if exists(freq_ranges):
|
54 |
+
rotations = einsum('..., f -> ... f', rotations, freq_ranges)
|
55 |
+
rotations = rearrange(rotations, '... r f -> ... (r f)')
|
56 |
+
|
57 |
+
rotations = repeat(rotations, '... n -> ... (n r)', r = 2)
|
58 |
+
return apply_rotary_emb(rotations, t, start_index = start_index)
|
59 |
+
|
60 |
+
# classes
|
61 |
+
|
62 |
+
class RotaryEmbedding(Module):
|
63 |
+
@beartype
|
64 |
+
def __init__(
|
65 |
+
self,
|
66 |
+
dim,
|
67 |
+
custom_freqs: Optional[Tensor] = None,
|
68 |
+
freqs_for: Union[
|
69 |
+
Literal['lang'],
|
70 |
+
Literal['pixel'],
|
71 |
+
Literal['constant']
|
72 |
+
] = 'lang',
|
73 |
+
theta = 10000,
|
74 |
+
max_freq = 10,
|
75 |
+
num_freqs = 1,
|
76 |
+
learned_freq = False,
|
77 |
+
use_xpos = False,
|
78 |
+
xpos_scale_base = 512,
|
79 |
+
interpolate_factor = 1.,
|
80 |
+
theta_rescale_factor = 1.,
|
81 |
+
seq_before_head_dim = False
|
82 |
+
):
|
83 |
+
super().__init__()
|
84 |
+
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
85 |
+
# has some connection to NTK literature
|
86 |
+
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
|
87 |
+
|
88 |
+
theta *= theta_rescale_factor ** (dim / (dim - 2))
|
89 |
+
|
90 |
+
self.freqs_for = freqs_for
|
91 |
+
|
92 |
+
if exists(custom_freqs):
|
93 |
+
freqs = custom_freqs
|
94 |
+
elif freqs_for == 'lang':
|
95 |
+
freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
|
96 |
+
elif freqs_for == 'pixel':
|
97 |
+
freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
|
98 |
+
elif freqs_for == 'constant':
|
99 |
+
freqs = torch.ones(num_freqs).float()
|
100 |
+
|
101 |
+
self.tmp_store('cached_freqs', None)
|
102 |
+
self.tmp_store('cached_scales', None)
|
103 |
+
|
104 |
+
self.freqs = nn.Parameter(freqs, requires_grad = learned_freq)
|
105 |
+
|
106 |
+
self.learned_freq = learned_freq
|
107 |
+
|
108 |
+
# dummy for device
|
109 |
+
|
110 |
+
self.tmp_store('dummy', torch.tensor(0))
|
111 |
+
|
112 |
+
# default sequence dimension
|
113 |
+
|
114 |
+
self.seq_before_head_dim = seq_before_head_dim
|
115 |
+
self.default_seq_dim = -3 if seq_before_head_dim else -2
|
116 |
+
|
117 |
+
# interpolation factors
|
118 |
+
|
119 |
+
assert interpolate_factor >= 1.
|
120 |
+
self.interpolate_factor = interpolate_factor
|
121 |
+
|
122 |
+
# xpos
|
123 |
+
|
124 |
+
self.use_xpos = use_xpos
|
125 |
+
if not use_xpos:
|
126 |
+
self.tmp_store('scale', None)
|
127 |
+
return
|
128 |
+
|
129 |
+
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
|
130 |
+
self.scale_base = xpos_scale_base
|
131 |
+
self.tmp_store('scale', scale)
|
132 |
+
|
133 |
+
@property
|
134 |
+
def device(self):
|
135 |
+
return self.dummy.device
|
136 |
+
|
137 |
+
def tmp_store(self, key, value):
|
138 |
+
self.register_buffer(key, value, persistent = False)
|
139 |
+
|
140 |
+
def get_seq_pos(self, seq_len, device, dtype, offset = 0):
|
141 |
+
return (torch.arange(seq_len, device = device, dtype = dtype) + offset) / self.interpolate_factor
|
142 |
+
|
143 |
+
def rotate_queries_or_keys(self, t, seq_dim = None, offset = 0, freq_seq_len = None, seq_pos = None):
|
144 |
+
seq_dim = default(seq_dim, self.default_seq_dim)
|
145 |
+
|
146 |
+
assert not self.use_xpos, 'you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings'
|
147 |
+
|
148 |
+
device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]
|
149 |
+
|
150 |
+
if exists(freq_seq_len):
|
151 |
+
assert freq_seq_len >= seq_len
|
152 |
+
seq_len = freq_seq_len
|
153 |
+
|
154 |
+
if seq_pos is None:
|
155 |
+
seq_pos = self.get_seq_pos(seq_len, device = device, dtype = dtype, offset = offset)
|
156 |
+
else:
|
157 |
+
assert seq_pos.shape[0] == seq_len
|
158 |
+
|
159 |
+
freqs = self.forward(seq_pos, seq_len = seq_len, offset = offset)
|
160 |
+
|
161 |
+
if seq_dim == -3:
|
162 |
+
freqs = rearrange(freqs, 'n d -> n 1 d')
|
163 |
+
|
164 |
+
return apply_rotary_emb(freqs, t, seq_dim = seq_dim)
|
165 |
+
|
166 |
+
def rotate_queries_with_cached_keys(self, q, k, seq_dim = None, offset = 0):
|
167 |
+
seq_dim = default(seq_dim, self.default_seq_dim)
|
168 |
+
|
169 |
+
q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
|
170 |
+
assert q_len <= k_len
|
171 |
+
rotated_q = self.rotate_queries_or_keys(q, seq_dim = seq_dim, freq_seq_len = k_len)
|
172 |
+
rotated_k = self.rotate_queries_or_keys(k, seq_dim = seq_dim)
|
173 |
+
|
174 |
+
rotated_q = rotated_q.type(q.dtype)
|
175 |
+
rotated_k = rotated_k.type(k.dtype)
|
176 |
+
|
177 |
+
return rotated_q, rotated_k
|
178 |
+
|
179 |
+
def rotate_queries_and_keys(self, q, k, seq_dim = None):
|
180 |
+
seq_dim = default(seq_dim, self.default_seq_dim)
|
181 |
+
|
182 |
+
assert self.use_xpos
|
183 |
+
device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]
|
184 |
+
|
185 |
+
seq = self.get_seq_pos(seq_len, dtype = dtype, device = device)
|
186 |
+
|
187 |
+
freqs = self.forward(seq, seq_len = seq_len)
|
188 |
+
scale = self.get_scale(seq, seq_len = seq_len).to(dtype)
|
189 |
+
|
190 |
+
if seq_dim == -3:
|
191 |
+
freqs = rearrange(freqs, 'n d -> n 1 d')
|
192 |
+
scale = rearrange(scale, 'n d -> n 1 d')
|
193 |
+
|
194 |
+
rotated_q = apply_rotary_emb(freqs, q, scale = scale, seq_dim = seq_dim)
|
195 |
+
rotated_k = apply_rotary_emb(freqs, k, scale = scale ** -1, seq_dim = seq_dim)
|
196 |
+
|
197 |
+
rotated_q = rotated_q.type(q.dtype)
|
198 |
+
rotated_k = rotated_k.type(k.dtype)
|
199 |
+
|
200 |
+
return rotated_q, rotated_k
|
201 |
+
|
202 |
+
@beartype
|
203 |
+
def get_scale(
|
204 |
+
self,
|
205 |
+
t: Tensor,
|
206 |
+
seq_len: Optional[int] = None,
|
207 |
+
offset = 0
|
208 |
+
):
|
209 |
+
assert self.use_xpos
|
210 |
+
|
211 |
+
should_cache = exists(seq_len)
|
212 |
+
|
213 |
+
if (
|
214 |
+
should_cache and \
|
215 |
+
exists(self.cached_scales) and \
|
216 |
+
(seq_len + offset) <= self.cached_scales.shape[0]
|
217 |
+
):
|
218 |
+
return self.cached_scales[offset:(offset + seq_len)]
|
219 |
+
|
220 |
+
scale = 1.
|
221 |
+
if self.use_xpos:
|
222 |
+
power = (t - len(t) // 2) / self.scale_base
|
223 |
+
scale = self.scale ** rearrange(power, 'n -> n 1')
|
224 |
+
scale = torch.cat((scale, scale), dim = -1)
|
225 |
+
|
226 |
+
if should_cache:
|
227 |
+
self.tmp_store('cached_scales', scale)
|
228 |
+
|
229 |
+
return scale
|
230 |
+
|
231 |
+
def get_axial_freqs(self, *dims):
|
232 |
+
Colon = slice(None)
|
233 |
+
all_freqs = []
|
234 |
+
|
235 |
+
for ind, dim in enumerate(dims):
|
236 |
+
if self.freqs_for == 'pixel':
|
237 |
+
pos = torch.linspace(-1, 1, steps = dim, device = self.device)
|
238 |
+
else:
|
239 |
+
pos = torch.arange(dim, device = self.device)
|
240 |
+
|
241 |
+
freqs = self.forward(pos, seq_len = dim)
|
242 |
+
|
243 |
+
all_axis = [None] * len(dims)
|
244 |
+
all_axis[ind] = Colon
|
245 |
+
|
246 |
+
new_axis_slice = (Ellipsis, *all_axis, Colon)
|
247 |
+
all_freqs.append(freqs[new_axis_slice])
|
248 |
+
|
249 |
+
all_freqs = broadcast_tensors(*all_freqs)
|
250 |
+
return torch.cat(all_freqs, dim = -1)
|
251 |
+
|
252 |
+
@autocast(enabled = False)
|
253 |
+
def forward(
|
254 |
+
self,
|
255 |
+
t: Tensor,
|
256 |
+
seq_len = None,
|
257 |
+
offset = 0
|
258 |
+
):
|
259 |
+
# should_cache = (
|
260 |
+
# not self.learned_freq and \
|
261 |
+
# exists(seq_len) and \
|
262 |
+
# self.freqs_for != 'pixel'
|
263 |
+
# )
|
264 |
+
|
265 |
+
# if (
|
266 |
+
# should_cache and \
|
267 |
+
# exists(self.cached_freqs) and \
|
268 |
+
# (offset + seq_len) <= self.cached_freqs.shape[0]
|
269 |
+
# ):
|
270 |
+
# return self.cached_freqs[offset:(offset + seq_len)].detach()
|
271 |
+
|
272 |
+
freqs = self.freqs
|
273 |
+
|
274 |
+
freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
|
275 |
+
freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
|
276 |
+
|
277 |
+
# if should_cache:
|
278 |
+
# self.tmp_store('cached_freqs', freqs.detach())
|
279 |
+
|
280 |
+
return freqs
|
src/videogen_hub/pipelines/consisti2v/consisti2v/models/videoldm_attention.py
ADDED
@@ -0,0 +1,809 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from importlib import import_module
|
2 |
+
from typing import Callable, Optional, Union
|
3 |
+
import math
|
4 |
+
|
5 |
+
from einops import rearrange, repeat
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from torch import nn
|
10 |
+
|
11 |
+
from diffusers.utils import deprecate, logging
|
12 |
+
from diffusers.utils.import_utils import is_xformers_available
|
13 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
14 |
+
from diffusers.models.lora import LoRACompatibleLinear, LoRALinearLayer
|
15 |
+
from diffusers.models.attention_processor import (
|
16 |
+
Attention,
|
17 |
+
AttnAddedKVProcessor,
|
18 |
+
AttnAddedKVProcessor2_0,
|
19 |
+
AttnProcessor,
|
20 |
+
AttnProcessor2_0,
|
21 |
+
SpatialNorm,
|
22 |
+
LORA_ATTENTION_PROCESSORS,
|
23 |
+
CustomDiffusionAttnProcessor,
|
24 |
+
CustomDiffusionXFormersAttnProcessor,
|
25 |
+
SlicedAttnAddedKVProcessor,
|
26 |
+
XFormersAttnAddedKVProcessor,
|
27 |
+
LoRAAttnAddedKVProcessor,
|
28 |
+
XFormersAttnProcessor,
|
29 |
+
LoRAXFormersAttnProcessor,
|
30 |
+
LoRAAttnProcessor,
|
31 |
+
LoRAAttnProcessor2_0,
|
32 |
+
SlicedAttnProcessor,
|
33 |
+
AttentionProcessor
|
34 |
+
)
|
35 |
+
|
36 |
+
from .rotary_embedding import RotaryEmbedding
|
37 |
+
|
38 |
+
|
39 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
40 |
+
|
41 |
+
|
42 |
+
if is_xformers_available():
|
43 |
+
import xformers
|
44 |
+
import xformers.ops
|
45 |
+
else:
|
46 |
+
xformers = None
|
47 |
+
|
48 |
+
@maybe_allow_in_graph
|
49 |
+
class ConditionalAttention(nn.Module):
|
50 |
+
r"""
|
51 |
+
A cross attention layer.
|
52 |
+
|
53 |
+
Parameters:
|
54 |
+
query_dim (`int`): The number of channels in the query.
|
55 |
+
cross_attention_dim (`int`, *optional*):
|
56 |
+
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
|
57 |
+
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
|
58 |
+
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
|
59 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
60 |
+
bias (`bool`, *optional*, defaults to False):
|
61 |
+
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
|
62 |
+
"""
|
63 |
+
|
64 |
+
def __init__(
|
65 |
+
self,
|
66 |
+
query_dim: int,
|
67 |
+
cross_attention_dim: Optional[int] = None,
|
68 |
+
heads: int = 8,
|
69 |
+
dim_head: int = 64,
|
70 |
+
dropout: float = 0.0,
|
71 |
+
bias=False,
|
72 |
+
upcast_attention: bool = False,
|
73 |
+
upcast_softmax: bool = False,
|
74 |
+
cross_attention_norm: Optional[str] = None,
|
75 |
+
cross_attention_norm_num_groups: int = 32,
|
76 |
+
added_kv_proj_dim: Optional[int] = None,
|
77 |
+
norm_num_groups: Optional[int] = None,
|
78 |
+
spatial_norm_dim: Optional[int] = None,
|
79 |
+
out_bias: bool = True,
|
80 |
+
scale_qk: bool = True,
|
81 |
+
only_cross_attention: bool = False,
|
82 |
+
eps: float = 1e-5,
|
83 |
+
rescale_output_factor: float = 1.0,
|
84 |
+
residual_connection: bool = False,
|
85 |
+
_from_deprecated_attn_block=False,
|
86 |
+
processor: Optional["AttnProcessor"] = None,
|
87 |
+
):
|
88 |
+
super().__init__()
|
89 |
+
self.inner_dim = dim_head * heads
|
90 |
+
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
91 |
+
self.upcast_attention = upcast_attention
|
92 |
+
self.upcast_softmax = upcast_softmax
|
93 |
+
self.rescale_output_factor = rescale_output_factor
|
94 |
+
self.residual_connection = residual_connection
|
95 |
+
self.dropout = dropout
|
96 |
+
|
97 |
+
# we make use of this private variable to know whether this class is loaded
|
98 |
+
# with an deprecated state dict so that we can convert it on the fly
|
99 |
+
self._from_deprecated_attn_block = _from_deprecated_attn_block
|
100 |
+
|
101 |
+
self.scale_qk = scale_qk
|
102 |
+
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
|
103 |
+
|
104 |
+
self.heads = heads
|
105 |
+
# for slice_size > 0 the attention score computation
|
106 |
+
# is split across the batch axis to save memory
|
107 |
+
# You can set slice_size with `set_attention_slice`
|
108 |
+
self.sliceable_head_dim = heads
|
109 |
+
|
110 |
+
self.added_kv_proj_dim = added_kv_proj_dim
|
111 |
+
self.only_cross_attention = only_cross_attention
|
112 |
+
|
113 |
+
if self.added_kv_proj_dim is None and self.only_cross_attention:
|
114 |
+
raise ValueError(
|
115 |
+
"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
|
116 |
+
)
|
117 |
+
|
118 |
+
if norm_num_groups is not None:
|
119 |
+
self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
|
120 |
+
else:
|
121 |
+
self.group_norm = None
|
122 |
+
|
123 |
+
if spatial_norm_dim is not None:
|
124 |
+
self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
|
125 |
+
else:
|
126 |
+
self.spatial_norm = None
|
127 |
+
|
128 |
+
if cross_attention_norm is None:
|
129 |
+
self.norm_cross = None
|
130 |
+
elif cross_attention_norm == "layer_norm":
|
131 |
+
self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
|
132 |
+
elif cross_attention_norm == "group_norm":
|
133 |
+
if self.added_kv_proj_dim is not None:
|
134 |
+
# The given `encoder_hidden_states` are initially of shape
|
135 |
+
# (batch_size, seq_len, added_kv_proj_dim) before being projected
|
136 |
+
# to (batch_size, seq_len, cross_attention_dim). The norm is applied
|
137 |
+
# before the projection, so we need to use `added_kv_proj_dim` as
|
138 |
+
# the number of channels for the group norm.
|
139 |
+
norm_cross_num_channels = added_kv_proj_dim
|
140 |
+
else:
|
141 |
+
norm_cross_num_channels = self.cross_attention_dim
|
142 |
+
|
143 |
+
self.norm_cross = nn.GroupNorm(
|
144 |
+
num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
|
145 |
+
)
|
146 |
+
else:
|
147 |
+
raise ValueError(
|
148 |
+
f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
|
149 |
+
)
|
150 |
+
|
151 |
+
self.to_q = LoRACompatibleLinear(query_dim, self.inner_dim, bias=bias)
|
152 |
+
|
153 |
+
if not self.only_cross_attention:
|
154 |
+
# only relevant for the `AddedKVProcessor` classes
|
155 |
+
self.to_k = LoRACompatibleLinear(self.cross_attention_dim, self.inner_dim, bias=bias)
|
156 |
+
self.to_v = LoRACompatibleLinear(self.cross_attention_dim, self.inner_dim, bias=bias)
|
157 |
+
else:
|
158 |
+
self.to_k = None
|
159 |
+
self.to_v = None
|
160 |
+
|
161 |
+
if self.added_kv_proj_dim is not None:
|
162 |
+
self.add_k_proj = LoRACompatibleLinear(added_kv_proj_dim, self.inner_dim)
|
163 |
+
self.add_v_proj = LoRACompatibleLinear(added_kv_proj_dim, self.inner_dim)
|
164 |
+
|
165 |
+
self.to_out = nn.ModuleList([])
|
166 |
+
self.to_out.append(LoRACompatibleLinear(self.inner_dim, query_dim, bias=out_bias))
|
167 |
+
self.to_out.append(nn.Dropout(dropout))
|
168 |
+
|
169 |
+
# set attention processor
|
170 |
+
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
|
171 |
+
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
172 |
+
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
|
173 |
+
if processor is None:
|
174 |
+
processor = (
|
175 |
+
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
|
176 |
+
)
|
177 |
+
self.set_processor(processor)
|
178 |
+
|
179 |
+
def set_use_memory_efficient_attention_xformers(
|
180 |
+
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
|
181 |
+
):
|
182 |
+
is_lora = hasattr(self, "processor") and isinstance(
|
183 |
+
self.processor,
|
184 |
+
LORA_ATTENTION_PROCESSORS,
|
185 |
+
)
|
186 |
+
is_custom_diffusion = hasattr(self, "processor") and isinstance(
|
187 |
+
self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
|
188 |
+
)
|
189 |
+
is_added_kv_processor = hasattr(self, "processor") and isinstance(
|
190 |
+
self.processor,
|
191 |
+
(
|
192 |
+
AttnAddedKVProcessor,
|
193 |
+
AttnAddedKVProcessor2_0,
|
194 |
+
SlicedAttnAddedKVProcessor,
|
195 |
+
XFormersAttnAddedKVProcessor,
|
196 |
+
LoRAAttnAddedKVProcessor,
|
197 |
+
),
|
198 |
+
)
|
199 |
+
|
200 |
+
if use_memory_efficient_attention_xformers:
|
201 |
+
if is_added_kv_processor and (is_lora or is_custom_diffusion):
|
202 |
+
raise NotImplementedError(
|
203 |
+
f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}"
|
204 |
+
)
|
205 |
+
if not is_xformers_available():
|
206 |
+
raise ModuleNotFoundError(
|
207 |
+
(
|
208 |
+
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
|
209 |
+
" xformers"
|
210 |
+
),
|
211 |
+
name="xformers",
|
212 |
+
)
|
213 |
+
elif not torch.cuda.is_available():
|
214 |
+
raise ValueError(
|
215 |
+
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
|
216 |
+
" only available for GPU "
|
217 |
+
)
|
218 |
+
else:
|
219 |
+
try:
|
220 |
+
# Make sure we can run the memory efficient attention
|
221 |
+
_ = xformers.ops.memory_efficient_attention(
|
222 |
+
torch.randn((1, 2, 40), device="cuda"),
|
223 |
+
torch.randn((1, 2, 40), device="cuda"),
|
224 |
+
torch.randn((1, 2, 40), device="cuda"),
|
225 |
+
)
|
226 |
+
except Exception as e:
|
227 |
+
raise e
|
228 |
+
|
229 |
+
if is_lora:
|
230 |
+
# TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
|
231 |
+
# variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
|
232 |
+
processor = LoRAXFormersAttnProcessor(
|
233 |
+
hidden_size=self.processor.hidden_size,
|
234 |
+
cross_attention_dim=self.processor.cross_attention_dim,
|
235 |
+
rank=self.processor.rank,
|
236 |
+
attention_op=attention_op,
|
237 |
+
)
|
238 |
+
processor.load_state_dict(self.processor.state_dict())
|
239 |
+
processor.to(self.processor.to_q_lora.up.weight.device)
|
240 |
+
elif is_custom_diffusion:
|
241 |
+
processor = CustomDiffusionXFormersAttnProcessor(
|
242 |
+
train_kv=self.processor.train_kv,
|
243 |
+
train_q_out=self.processor.train_q_out,
|
244 |
+
hidden_size=self.processor.hidden_size,
|
245 |
+
cross_attention_dim=self.processor.cross_attention_dim,
|
246 |
+
attention_op=attention_op,
|
247 |
+
)
|
248 |
+
processor.load_state_dict(self.processor.state_dict())
|
249 |
+
if hasattr(self.processor, "to_k_custom_diffusion"):
|
250 |
+
processor.to(self.processor.to_k_custom_diffusion.weight.device)
|
251 |
+
elif is_added_kv_processor:
|
252 |
+
# TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
|
253 |
+
# which uses this type of cross attention ONLY because the attention mask of format
|
254 |
+
# [0, ..., -10.000, ..., 0, ...,] is not supported
|
255 |
+
# throw warning
|
256 |
+
logger.info(
|
257 |
+
"Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
|
258 |
+
)
|
259 |
+
processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
|
260 |
+
else:
|
261 |
+
processor = XFormersAttnProcessor(attention_op=attention_op)
|
262 |
+
else:
|
263 |
+
if is_lora:
|
264 |
+
attn_processor_class = (
|
265 |
+
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
|
266 |
+
)
|
267 |
+
processor = attn_processor_class(
|
268 |
+
hidden_size=self.processor.hidden_size,
|
269 |
+
cross_attention_dim=self.processor.cross_attention_dim,
|
270 |
+
rank=self.processor.rank,
|
271 |
+
)
|
272 |
+
processor.load_state_dict(self.processor.state_dict())
|
273 |
+
processor.to(self.processor.to_q_lora.up.weight.device)
|
274 |
+
elif is_custom_diffusion:
|
275 |
+
processor = CustomDiffusionAttnProcessor(
|
276 |
+
train_kv=self.processor.train_kv,
|
277 |
+
train_q_out=self.processor.train_q_out,
|
278 |
+
hidden_size=self.processor.hidden_size,
|
279 |
+
cross_attention_dim=self.processor.cross_attention_dim,
|
280 |
+
)
|
281 |
+
processor.load_state_dict(self.processor.state_dict())
|
282 |
+
if hasattr(self.processor, "to_k_custom_diffusion"):
|
283 |
+
processor.to(self.processor.to_k_custom_diffusion.weight.device)
|
284 |
+
else:
|
285 |
+
# set attention processor
|
286 |
+
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
|
287 |
+
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
288 |
+
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
|
289 |
+
processor = (
|
290 |
+
AttnProcessor2_0()
|
291 |
+
if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
|
292 |
+
else AttnProcessor()
|
293 |
+
)
|
294 |
+
|
295 |
+
self.set_processor(processor)
|
296 |
+
|
297 |
+
def set_attention_slice(self, slice_size):
|
298 |
+
if slice_size is not None and slice_size > self.sliceable_head_dim:
|
299 |
+
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
|
300 |
+
|
301 |
+
if slice_size is not None and self.added_kv_proj_dim is not None:
|
302 |
+
processor = SlicedAttnAddedKVProcessor(slice_size)
|
303 |
+
elif slice_size is not None:
|
304 |
+
processor = SlicedAttnProcessor(slice_size)
|
305 |
+
elif self.added_kv_proj_dim is not None:
|
306 |
+
processor = AttnAddedKVProcessor()
|
307 |
+
else:
|
308 |
+
# set attention processor
|
309 |
+
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
|
310 |
+
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
311 |
+
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
|
312 |
+
processor = (
|
313 |
+
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
|
314 |
+
)
|
315 |
+
|
316 |
+
self.set_processor(processor)
|
317 |
+
|
318 |
+
def set_processor(self, processor: "AttnProcessor"):
|
319 |
+
if (
|
320 |
+
hasattr(self, "processor")
|
321 |
+
and not isinstance(processor, LORA_ATTENTION_PROCESSORS)
|
322 |
+
and self.to_q.lora_layer is not None
|
323 |
+
):
|
324 |
+
deprecate(
|
325 |
+
"set_processor to offload LoRA",
|
326 |
+
"0.26.0",
|
327 |
+
"In detail, removing LoRA layers via calling `set_processor` or `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.",
|
328 |
+
)
|
329 |
+
# (Patrick, Sayak) - this can be deprecated once PEFT LoRA integration is complete
|
330 |
+
# We need to remove all LoRA layers
|
331 |
+
for module in self.modules():
|
332 |
+
if hasattr(module, "set_lora_layer"):
|
333 |
+
module.set_lora_layer(None)
|
334 |
+
|
335 |
+
# if current processor is in `self._modules` and if passed `processor` is not, we need to
|
336 |
+
# pop `processor` from `self._modules`
|
337 |
+
if (
|
338 |
+
hasattr(self, "processor")
|
339 |
+
and isinstance(self.processor, torch.nn.Module)
|
340 |
+
and not isinstance(processor, torch.nn.Module)
|
341 |
+
):
|
342 |
+
logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
|
343 |
+
self._modules.pop("processor")
|
344 |
+
|
345 |
+
self.processor = processor
|
346 |
+
|
347 |
+
def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
|
348 |
+
if not return_deprecated_lora:
|
349 |
+
return self.processor
|
350 |
+
|
351 |
+
# TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible
|
352 |
+
# serialization format for LoRA Attention Processors. It should be deleted once the integration
|
353 |
+
# with PEFT is completed.
|
354 |
+
is_lora_activated = {
|
355 |
+
name: module.lora_layer is not None
|
356 |
+
for name, module in self.named_modules()
|
357 |
+
if hasattr(module, "lora_layer")
|
358 |
+
}
|
359 |
+
|
360 |
+
# 1. if no layer has a LoRA activated we can return the processor as usual
|
361 |
+
if not any(is_lora_activated.values()):
|
362 |
+
return self.processor
|
363 |
+
|
364 |
+
# If doesn't apply LoRA do `add_k_proj` or `add_v_proj`
|
365 |
+
is_lora_activated.pop("add_k_proj", None)
|
366 |
+
is_lora_activated.pop("add_v_proj", None)
|
367 |
+
# 2. else it is not posssible that only some layers have LoRA activated
|
368 |
+
if not all(is_lora_activated.values()):
|
369 |
+
raise ValueError(
|
370 |
+
f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
|
371 |
+
)
|
372 |
+
|
373 |
+
# 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
|
374 |
+
non_lora_processor_cls_name = self.processor.__class__.__name__
|
375 |
+
lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name)
|
376 |
+
|
377 |
+
hidden_size = self.inner_dim
|
378 |
+
|
379 |
+
# now create a LoRA attention processor from the LoRA layers
|
380 |
+
if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]:
|
381 |
+
kwargs = {
|
382 |
+
"cross_attention_dim": self.cross_attention_dim,
|
383 |
+
"rank": self.to_q.lora_layer.rank,
|
384 |
+
"network_alpha": self.to_q.lora_layer.network_alpha,
|
385 |
+
"q_rank": self.to_q.lora_layer.rank,
|
386 |
+
"q_hidden_size": self.to_q.lora_layer.out_features,
|
387 |
+
"k_rank": self.to_k.lora_layer.rank,
|
388 |
+
"k_hidden_size": self.to_k.lora_layer.out_features,
|
389 |
+
"v_rank": self.to_v.lora_layer.rank,
|
390 |
+
"v_hidden_size": self.to_v.lora_layer.out_features,
|
391 |
+
"out_rank": self.to_out[0].lora_layer.rank,
|
392 |
+
"out_hidden_size": self.to_out[0].lora_layer.out_features,
|
393 |
+
}
|
394 |
+
|
395 |
+
if hasattr(self.processor, "attention_op"):
|
396 |
+
kwargs["attention_op"] = self.prcoessor.attention_op
|
397 |
+
|
398 |
+
lora_processor = lora_processor_cls(hidden_size, **kwargs)
|
399 |
+
lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
|
400 |
+
lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
|
401 |
+
lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
|
402 |
+
lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
|
403 |
+
elif lora_processor_cls == LoRAAttnAddedKVProcessor:
|
404 |
+
lora_processor = lora_processor_cls(
|
405 |
+
hidden_size,
|
406 |
+
cross_attention_dim=self.add_k_proj.weight.shape[0],
|
407 |
+
rank=self.to_q.lora_layer.rank,
|
408 |
+
network_alpha=self.to_q.lora_layer.network_alpha,
|
409 |
+
)
|
410 |
+
lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
|
411 |
+
lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
|
412 |
+
lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
|
413 |
+
lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
|
414 |
+
|
415 |
+
# only save if used
|
416 |
+
if self.add_k_proj.lora_layer is not None:
|
417 |
+
lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict())
|
418 |
+
lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict())
|
419 |
+
else:
|
420 |
+
lora_processor.add_k_proj_lora = None
|
421 |
+
lora_processor.add_v_proj_lora = None
|
422 |
+
else:
|
423 |
+
raise ValueError(f"{lora_processor_cls} does not exist.")
|
424 |
+
|
425 |
+
return lora_processor
|
426 |
+
|
427 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
|
428 |
+
# The `Attention` class can call different attention processors / attention functions
|
429 |
+
# here we simply pass along all tensors to the selected processor class
|
430 |
+
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
|
431 |
+
return self.processor(
|
432 |
+
self,
|
433 |
+
hidden_states,
|
434 |
+
encoder_hidden_states=encoder_hidden_states,
|
435 |
+
attention_mask=attention_mask,
|
436 |
+
**cross_attention_kwargs,
|
437 |
+
)
|
438 |
+
|
439 |
+
def batch_to_head_dim(self, tensor):
|
440 |
+
head_size = self.heads
|
441 |
+
batch_size, seq_len, dim = tensor.shape
|
442 |
+
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
443 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
444 |
+
return tensor
|
445 |
+
|
446 |
+
def head_to_batch_dim(self, tensor, out_dim=3):
|
447 |
+
head_size = self.heads
|
448 |
+
batch_size, seq_len, dim = tensor.shape
|
449 |
+
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
450 |
+
tensor = tensor.permute(0, 2, 1, 3)
|
451 |
+
|
452 |
+
if out_dim == 3:
|
453 |
+
tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
|
454 |
+
|
455 |
+
return tensor
|
456 |
+
|
457 |
+
def get_attention_scores(self, query, key, attention_mask=None):
|
458 |
+
dtype = query.dtype
|
459 |
+
if self.upcast_attention:
|
460 |
+
query = query.float()
|
461 |
+
key = key.float()
|
462 |
+
|
463 |
+
if attention_mask is None:
|
464 |
+
baddbmm_input = torch.empty(
|
465 |
+
query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
|
466 |
+
)
|
467 |
+
beta = 0
|
468 |
+
else:
|
469 |
+
baddbmm_input = attention_mask
|
470 |
+
beta = 1
|
471 |
+
|
472 |
+
attention_scores = torch.baddbmm(
|
473 |
+
baddbmm_input,
|
474 |
+
query,
|
475 |
+
key.transpose(-1, -2),
|
476 |
+
beta=beta,
|
477 |
+
alpha=self.scale,
|
478 |
+
)
|
479 |
+
del baddbmm_input
|
480 |
+
|
481 |
+
if self.upcast_softmax:
|
482 |
+
attention_scores = attention_scores.float()
|
483 |
+
|
484 |
+
attention_probs = attention_scores.softmax(dim=-1)
|
485 |
+
del attention_scores
|
486 |
+
|
487 |
+
attention_probs = attention_probs.to(dtype)
|
488 |
+
|
489 |
+
return attention_probs
|
490 |
+
|
491 |
+
def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3):
|
492 |
+
if batch_size is None:
|
493 |
+
deprecate(
|
494 |
+
"batch_size=None",
|
495 |
+
"0.22.0",
|
496 |
+
(
|
497 |
+
"Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect"
|
498 |
+
" attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to"
|
499 |
+
" `prepare_attention_mask` when preparing the attention_mask."
|
500 |
+
),
|
501 |
+
)
|
502 |
+
batch_size = 1
|
503 |
+
|
504 |
+
head_size = self.heads
|
505 |
+
if attention_mask is None:
|
506 |
+
return attention_mask
|
507 |
+
|
508 |
+
current_length: int = attention_mask.shape[-1]
|
509 |
+
if current_length != target_length:
|
510 |
+
if attention_mask.device.type == "mps":
|
511 |
+
# HACK: MPS: Does not support padding by greater than dimension of input tensor.
|
512 |
+
# Instead, we can manually construct the padding tensor.
|
513 |
+
padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
|
514 |
+
padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
|
515 |
+
attention_mask = torch.cat([attention_mask, padding], dim=2)
|
516 |
+
else:
|
517 |
+
# TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
|
518 |
+
# we want to instead pad by (0, remaining_length), where remaining_length is:
|
519 |
+
# remaining_length: int = target_length - current_length
|
520 |
+
# TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
|
521 |
+
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
522 |
+
|
523 |
+
if out_dim == 3:
|
524 |
+
if attention_mask.shape[0] < batch_size * head_size:
|
525 |
+
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
|
526 |
+
elif out_dim == 4:
|
527 |
+
attention_mask = attention_mask.unsqueeze(1)
|
528 |
+
attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
|
529 |
+
|
530 |
+
return attention_mask
|
531 |
+
|
532 |
+
def norm_encoder_hidden_states(self, encoder_hidden_states):
|
533 |
+
assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
|
534 |
+
|
535 |
+
if isinstance(self.norm_cross, nn.LayerNorm):
|
536 |
+
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
|
537 |
+
elif isinstance(self.norm_cross, nn.GroupNorm):
|
538 |
+
# Group norm norms along the channels dimension and expects
|
539 |
+
# input to be in the shape of (N, C, *). In this case, we want
|
540 |
+
# to norm along the hidden dimension, so we need to move
|
541 |
+
# (batch_size, sequence_length, hidden_size) ->
|
542 |
+
# (batch_size, hidden_size, sequence_length)
|
543 |
+
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
|
544 |
+
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
|
545 |
+
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
|
546 |
+
else:
|
547 |
+
assert False
|
548 |
+
|
549 |
+
return encoder_hidden_states
|
550 |
+
|
551 |
+
|
552 |
+
class TemporalConditionalAttention(Attention):
|
553 |
+
def __init__(self, n_frames=8, rotary_emb=False, *args, **kwargs):
|
554 |
+
super().__init__(processor=RotaryEmbAttnProcessor2_0() if rotary_emb else None, *args, **kwargs)
|
555 |
+
|
556 |
+
if not rotary_emb:
|
557 |
+
self.pos_enc = PositionalEncoding(self.inner_dim)
|
558 |
+
else:
|
559 |
+
rotary_bias = RelativePositionBias(heads=kwargs['heads'], max_distance=32)
|
560 |
+
self.rotary_bias = rotary_bias
|
561 |
+
self.rotary_emb = RotaryEmbedding(self.inner_dim // 2)
|
562 |
+
|
563 |
+
self.use_rotary_emb = rotary_emb
|
564 |
+
self.n_frames = n_frames
|
565 |
+
|
566 |
+
def forward(
|
567 |
+
self,
|
568 |
+
hidden_states,
|
569 |
+
encoder_hidden_states=None,
|
570 |
+
attention_mask=None,
|
571 |
+
adjacent_slices=None,
|
572 |
+
**cross_attention_kwargs):
|
573 |
+
|
574 |
+
key_pos_idx = None
|
575 |
+
|
576 |
+
bt, hw, c = hidden_states.shape
|
577 |
+
hidden_states = rearrange(hidden_states, '(b t) hw c -> b hw t c', t=self.n_frames)
|
578 |
+
if not self.use_rotary_emb:
|
579 |
+
pos_embed = self.pos_enc(self.n_frames)
|
580 |
+
hidden_states = hidden_states + pos_embed
|
581 |
+
hidden_states = rearrange(hidden_states, 'b hw t c -> (b hw) t c')
|
582 |
+
|
583 |
+
if encoder_hidden_states is not None:
|
584 |
+
assert adjacent_slices is None
|
585 |
+
encoder_hidden_states = encoder_hidden_states[::self.n_frames]
|
586 |
+
encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b hw) n c', hw=hw)
|
587 |
+
|
588 |
+
if adjacent_slices is not None:
|
589 |
+
assert encoder_hidden_states is None
|
590 |
+
adjacent_slices = rearrange(adjacent_slices, 'b c h w n -> b (h w) n c')
|
591 |
+
if not self.use_rotary_emb:
|
592 |
+
first_frame_pos_embed = pos_embed[0:1, :]
|
593 |
+
adjacent_slices = adjacent_slices + first_frame_pos_embed
|
594 |
+
else:
|
595 |
+
pos_idx = torch.arange(self.n_frames, device=hidden_states.device, dtype=hidden_states.dtype)
|
596 |
+
first_frame_pos_pad = torch.zeros(adjacent_slices.shape[2], device=hidden_states.device, dtype=hidden_states.dtype)
|
597 |
+
key_pos_idx = torch.cat([pos_idx, first_frame_pos_pad], dim=0)
|
598 |
+
adjacent_slices = rearrange(adjacent_slices, 'b hw n c -> (b hw) n c')
|
599 |
+
encoder_hidden_states = torch.cat([hidden_states, adjacent_slices], dim=1)
|
600 |
+
|
601 |
+
if not self.use_rotary_emb:
|
602 |
+
out = self.processor(
|
603 |
+
self,
|
604 |
+
hidden_states,
|
605 |
+
encoder_hidden_states=encoder_hidden_states,
|
606 |
+
attention_mask=attention_mask,
|
607 |
+
**cross_attention_kwargs,
|
608 |
+
)
|
609 |
+
else:
|
610 |
+
out = self.processor(
|
611 |
+
self,
|
612 |
+
hidden_states,
|
613 |
+
encoder_hidden_states=encoder_hidden_states,
|
614 |
+
attention_mask=attention_mask,
|
615 |
+
key_pos_idx=key_pos_idx,
|
616 |
+
**cross_attention_kwargs,
|
617 |
+
)
|
618 |
+
|
619 |
+
out = rearrange(out, '(b hw) t c -> (b t) hw c', hw=hw)
|
620 |
+
|
621 |
+
return out
|
622 |
+
|
623 |
+
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers, attention_op=None):
|
624 |
+
if use_memory_efficient_attention_xformers:
|
625 |
+
try:
|
626 |
+
# Make sure we can run the memory efficient attention
|
627 |
+
_ = xformers.ops.memory_efficient_attention(
|
628 |
+
torch.randn((1, 2, 40), device="cuda"),
|
629 |
+
torch.randn((1, 2, 40), device="cuda"),
|
630 |
+
torch.randn((1, 2, 40), device="cuda"),
|
631 |
+
)
|
632 |
+
except Exception as e:
|
633 |
+
raise e
|
634 |
+
processor = XFormersAttnProcessor(attention_op=attention_op)
|
635 |
+
else:
|
636 |
+
processor = (
|
637 |
+
AttnProcessor2_0()
|
638 |
+
if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
|
639 |
+
else AttnProcessor()
|
640 |
+
)
|
641 |
+
self.set_processor(processor)
|
642 |
+
|
643 |
+
|
644 |
+
class PositionalEncoding(nn.Module):
|
645 |
+
def __init__(self, dim, max_pos=512):
|
646 |
+
super().__init__()
|
647 |
+
|
648 |
+
pos = torch.arange(max_pos)
|
649 |
+
|
650 |
+
freq = torch.arange(dim//2) / dim
|
651 |
+
freq = (freq * torch.tensor(10000).log()).exp()
|
652 |
+
|
653 |
+
x = rearrange(pos, 'L -> L 1') / freq
|
654 |
+
x = rearrange(x, 'L d -> L d 1')
|
655 |
+
|
656 |
+
pe = torch.cat((x.sin(), x.cos()), dim=-1)
|
657 |
+
self.pe = rearrange(pe, 'L d sc -> L (d sc)')
|
658 |
+
|
659 |
+
self.dummy = nn.Parameter(torch.rand(1))
|
660 |
+
|
661 |
+
def forward(self, length):
|
662 |
+
enc = self.pe[:length]
|
663 |
+
enc = enc.to(self.dummy.device, self.dummy.dtype)
|
664 |
+
return enc
|
665 |
+
|
666 |
+
|
667 |
+
# code taken from https://github.com/Vchitect/LaVie/blob/main/base/models/temporal_attention.py
|
668 |
+
class RelativePositionBias(nn.Module):
|
669 |
+
def __init__(
|
670 |
+
self,
|
671 |
+
heads=8,
|
672 |
+
num_buckets=32,
|
673 |
+
max_distance=128,
|
674 |
+
):
|
675 |
+
super().__init__()
|
676 |
+
self.num_buckets = num_buckets
|
677 |
+
self.max_distance = max_distance
|
678 |
+
self.relative_attention_bias = nn.Embedding(num_buckets, heads)
|
679 |
+
|
680 |
+
@staticmethod
|
681 |
+
def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
|
682 |
+
ret = 0
|
683 |
+
n = -relative_position
|
684 |
+
|
685 |
+
num_buckets //= 2
|
686 |
+
ret += (n < 0).long() * num_buckets
|
687 |
+
n = torch.abs(n)
|
688 |
+
|
689 |
+
max_exact = num_buckets // 2
|
690 |
+
is_small = n < max_exact
|
691 |
+
|
692 |
+
val_if_large = max_exact + (
|
693 |
+
torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
|
694 |
+
).long()
|
695 |
+
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
|
696 |
+
|
697 |
+
ret += torch.where(is_small, n, val_if_large)
|
698 |
+
return ret
|
699 |
+
|
700 |
+
def forward(self, qlen, klen, device, dtype):
|
701 |
+
q_pos = torch.arange(qlen, dtype = torch.long, device = device)
|
702 |
+
k_pos = torch.arange(klen, dtype = torch.long, device = device)
|
703 |
+
rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
|
704 |
+
rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
|
705 |
+
values = self.relative_attention_bias(rp_bucket)
|
706 |
+
values = values.to(device, dtype)
|
707 |
+
return rearrange(values, 'i j h -> h i j') # num_heads, num_frames, num_frames
|
708 |
+
|
709 |
+
|
710 |
+
class RotaryEmbAttnProcessor2_0:
|
711 |
+
r"""
|
712 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
713 |
+
Add rotary embedding support
|
714 |
+
"""
|
715 |
+
|
716 |
+
def __init__(self):
|
717 |
+
|
718 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
719 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
720 |
+
|
721 |
+
def __call__(
|
722 |
+
self,
|
723 |
+
attn: Attention,
|
724 |
+
hidden_states,
|
725 |
+
encoder_hidden_states=None,
|
726 |
+
attention_mask=None,
|
727 |
+
temb=None,
|
728 |
+
scale: float = 1.0,
|
729 |
+
key_pos_idx: Optional[torch.Tensor] = None,
|
730 |
+
):
|
731 |
+
assert attention_mask is None
|
732 |
+
residual = hidden_states
|
733 |
+
|
734 |
+
if attn.spatial_norm is not None:
|
735 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
736 |
+
|
737 |
+
input_ndim = hidden_states.ndim
|
738 |
+
|
739 |
+
if input_ndim == 4:
|
740 |
+
batch_size, channel, height, width = hidden_states.shape
|
741 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
742 |
+
|
743 |
+
batch_size, sequence_length, _ = (
|
744 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
745 |
+
)
|
746 |
+
|
747 |
+
# if attention_mask is not None:
|
748 |
+
# attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
749 |
+
# # scaled_dot_product_attention expects attention_mask shape to be
|
750 |
+
# # (batch, heads, source_length, target_length)
|
751 |
+
# attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
752 |
+
|
753 |
+
if attn.group_norm is not None:
|
754 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
755 |
+
|
756 |
+
query = attn.to_q(hidden_states, scale=scale)
|
757 |
+
|
758 |
+
if encoder_hidden_states is None:
|
759 |
+
encoder_hidden_states = hidden_states
|
760 |
+
elif attn.norm_cross:
|
761 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
762 |
+
|
763 |
+
qlen = hidden_states.shape[1]
|
764 |
+
klen = encoder_hidden_states.shape[1]
|
765 |
+
# currently only add bias for self attention. Relative distance doesn't make sense for cross attention.
|
766 |
+
# if qlen == klen:
|
767 |
+
# time_rel_pos_bias = attn.rotary_bias(qlen, klen, device=hidden_states.device, dtype=hidden_states.dtype)
|
768 |
+
# attention_mask = repeat(time_rel_pos_bias, "h d1 d2 -> b h d1 d2", b=batch_size)
|
769 |
+
|
770 |
+
key = attn.to_k(encoder_hidden_states, scale=scale)
|
771 |
+
value = attn.to_v(encoder_hidden_states, scale=scale)
|
772 |
+
|
773 |
+
query = attn.rotary_emb.rotate_queries_or_keys(query)
|
774 |
+
if qlen == klen:
|
775 |
+
key = attn.rotary_emb.rotate_queries_or_keys(key)
|
776 |
+
elif key_pos_idx is not None:
|
777 |
+
key = attn.rotary_emb.rotate_queries_or_keys(key, seq_pos=key_pos_idx)
|
778 |
+
|
779 |
+
inner_dim = key.shape[-1]
|
780 |
+
head_dim = inner_dim // attn.heads
|
781 |
+
|
782 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
783 |
+
|
784 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
785 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
786 |
+
|
787 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
788 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
789 |
+
hidden_states = F.scaled_dot_product_attention(
|
790 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
791 |
+
)
|
792 |
+
|
793 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
794 |
+
hidden_states = hidden_states.to(query.dtype)
|
795 |
+
|
796 |
+
# linear proj
|
797 |
+
hidden_states = attn.to_out[0](hidden_states, scale=scale)
|
798 |
+
# dropout
|
799 |
+
hidden_states = attn.to_out[1](hidden_states)
|
800 |
+
|
801 |
+
if input_ndim == 4:
|
802 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
803 |
+
|
804 |
+
if attn.residual_connection:
|
805 |
+
hidden_states = hidden_states + residual
|
806 |
+
|
807 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
808 |
+
|
809 |
+
return hidden_states
|
src/videogen_hub/pipelines/consisti2v/consisti2v/models/videoldm_transformer_blocks.py
ADDED
@@ -0,0 +1,564 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from https://github.com/huggingface/diffusers/blob/v0.21.0/src/diffusers/models/transformer_2d.py
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from typing import Any, Dict, Optional
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch import nn
|
8 |
+
|
9 |
+
from einops import rearrange, repeat
|
10 |
+
|
11 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
12 |
+
from diffusers.models.embeddings import ImagePositionalEmbeddings
|
13 |
+
from diffusers.utils import BaseOutput, deprecate
|
14 |
+
from diffusers.models.attention import AdaLayerNorm, AdaLayerNormZero, FeedForward, GatedSelfAttentionDense
|
15 |
+
from diffusers.models.embeddings import PatchEmbed
|
16 |
+
from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
|
17 |
+
from diffusers.models.modeling_utils import ModelMixin
|
18 |
+
from diffusers.models.transformer_2d import Transformer2DModelOutput
|
19 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
20 |
+
from diffusers.models.attention_processor import Attention
|
21 |
+
from diffusers.models.lora import LoRACompatibleLinear
|
22 |
+
|
23 |
+
from .videoldm_attention import ConditionalAttention, TemporalConditionalAttention
|
24 |
+
|
25 |
+
|
26 |
+
class Transformer2DConditionModel(ModelMixin, ConfigMixin):
|
27 |
+
@register_to_config
|
28 |
+
def __init__(
|
29 |
+
self,
|
30 |
+
num_attention_heads: int = 16,
|
31 |
+
attention_head_dim: int = 88,
|
32 |
+
in_channels: Optional[int] = None,
|
33 |
+
out_channels: Optional[int] = None,
|
34 |
+
num_layers: int = 1,
|
35 |
+
dropout: float = 0.0,
|
36 |
+
norm_num_groups: int = 32,
|
37 |
+
cross_attention_dim: Optional[int] = None,
|
38 |
+
attention_bias: bool = False,
|
39 |
+
sample_size: Optional[int] = None,
|
40 |
+
num_vector_embeds: Optional[int] = None,
|
41 |
+
patch_size: Optional[int] = None,
|
42 |
+
activation_fn: str = "geglu",
|
43 |
+
num_embeds_ada_norm: Optional[int] = None,
|
44 |
+
use_linear_projection: bool = False,
|
45 |
+
only_cross_attention: bool = False,
|
46 |
+
double_self_attention: bool = False,
|
47 |
+
upcast_attention: bool = False,
|
48 |
+
norm_type: str = "layer_norm",
|
49 |
+
norm_elementwise_affine: bool = True,
|
50 |
+
attention_type: str = "default",
|
51 |
+
# additional
|
52 |
+
n_frames: int = 8,
|
53 |
+
is_temporal: bool = False,
|
54 |
+
augment_temporal_attention: bool = False,
|
55 |
+
rotary_emb=False,
|
56 |
+
):
|
57 |
+
super().__init__()
|
58 |
+
self.use_linear_projection = use_linear_projection
|
59 |
+
self.num_attention_heads = num_attention_heads
|
60 |
+
self.attention_head_dim = attention_head_dim
|
61 |
+
inner_dim = num_attention_heads * attention_head_dim
|
62 |
+
|
63 |
+
# 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
|
64 |
+
# Define whether input is continuous or discrete depending on configuration
|
65 |
+
self.is_input_continuous = (in_channels is not None) and (patch_size is None)
|
66 |
+
self.is_input_vectorized = num_vector_embeds is not None
|
67 |
+
self.is_input_patches = in_channels is not None and patch_size is not None
|
68 |
+
|
69 |
+
if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
|
70 |
+
deprecation_message = (
|
71 |
+
f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
|
72 |
+
" incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
|
73 |
+
" Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
|
74 |
+
" results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
|
75 |
+
" would be very nice if you could open a Pull request for the `transformer/config.json` file"
|
76 |
+
)
|
77 |
+
deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
|
78 |
+
norm_type = "ada_norm"
|
79 |
+
|
80 |
+
if self.is_input_continuous and self.is_input_vectorized:
|
81 |
+
raise ValueError(
|
82 |
+
f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
|
83 |
+
" sure that either `in_channels` or `num_vector_embeds` is None."
|
84 |
+
)
|
85 |
+
elif self.is_input_vectorized and self.is_input_patches:
|
86 |
+
raise ValueError(
|
87 |
+
f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
|
88 |
+
" sure that either `num_vector_embeds` or `num_patches` is None."
|
89 |
+
)
|
90 |
+
elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
|
91 |
+
raise ValueError(
|
92 |
+
f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
|
93 |
+
f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
|
94 |
+
)
|
95 |
+
|
96 |
+
# 2. Define input layers
|
97 |
+
if self.is_input_continuous:
|
98 |
+
self.in_channels = in_channels
|
99 |
+
|
100 |
+
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
101 |
+
if use_linear_projection:
|
102 |
+
self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
|
103 |
+
else:
|
104 |
+
self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
105 |
+
elif self.is_input_vectorized:
|
106 |
+
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
|
107 |
+
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
|
108 |
+
|
109 |
+
self.height = sample_size
|
110 |
+
self.width = sample_size
|
111 |
+
self.num_vector_embeds = num_vector_embeds
|
112 |
+
self.num_latent_pixels = self.height * self.width
|
113 |
+
|
114 |
+
self.latent_image_embedding = ImagePositionalEmbeddings(
|
115 |
+
num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
|
116 |
+
)
|
117 |
+
elif self.is_input_patches:
|
118 |
+
assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
|
119 |
+
|
120 |
+
self.height = sample_size
|
121 |
+
self.width = sample_size
|
122 |
+
|
123 |
+
self.patch_size = patch_size
|
124 |
+
self.pos_embed = PatchEmbed(
|
125 |
+
height=sample_size,
|
126 |
+
width=sample_size,
|
127 |
+
patch_size=patch_size,
|
128 |
+
in_channels=in_channels,
|
129 |
+
embed_dim=inner_dim,
|
130 |
+
)
|
131 |
+
|
132 |
+
# 3. Define transformers blocks
|
133 |
+
self.transformer_blocks = nn.ModuleList(
|
134 |
+
[
|
135 |
+
BasicConditionalTransformerBlock(
|
136 |
+
inner_dim,
|
137 |
+
num_attention_heads,
|
138 |
+
attention_head_dim,
|
139 |
+
dropout=dropout,
|
140 |
+
cross_attention_dim=cross_attention_dim,
|
141 |
+
activation_fn=activation_fn,
|
142 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
143 |
+
attention_bias=attention_bias,
|
144 |
+
only_cross_attention=only_cross_attention,
|
145 |
+
double_self_attention=double_self_attention,
|
146 |
+
upcast_attention=upcast_attention,
|
147 |
+
norm_type=norm_type,
|
148 |
+
norm_elementwise_affine=norm_elementwise_affine,
|
149 |
+
attention_type=attention_type,
|
150 |
+
# additional
|
151 |
+
n_frames=n_frames,
|
152 |
+
is_temporal=is_temporal,
|
153 |
+
augment_temporal_attention=augment_temporal_attention,
|
154 |
+
rotary_emb=rotary_emb,
|
155 |
+
)
|
156 |
+
for d in range(num_layers)
|
157 |
+
]
|
158 |
+
)
|
159 |
+
|
160 |
+
# 4. Define output layers
|
161 |
+
self.out_channels = in_channels if out_channels is None else out_channels
|
162 |
+
if self.is_input_continuous:
|
163 |
+
# TODO: should use out_channels for continuous projections
|
164 |
+
if use_linear_projection:
|
165 |
+
self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
|
166 |
+
else:
|
167 |
+
self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
168 |
+
elif self.is_input_vectorized:
|
169 |
+
self.norm_out = nn.LayerNorm(inner_dim)
|
170 |
+
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
|
171 |
+
elif self.is_input_patches:
|
172 |
+
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
|
173 |
+
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
|
174 |
+
self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
|
175 |
+
|
176 |
+
self.alpha = None
|
177 |
+
if is_temporal:
|
178 |
+
self.alpha = nn.Parameter(torch.ones(1))
|
179 |
+
|
180 |
+
self.gradient_checkpointing = False
|
181 |
+
|
182 |
+
def forward(
|
183 |
+
self,
|
184 |
+
hidden_states: torch.Tensor,
|
185 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
186 |
+
timestep: Optional[torch.LongTensor] = None,
|
187 |
+
class_labels: Optional[torch.LongTensor] = None,
|
188 |
+
cross_attention_kwargs: Dict[str, Any] = None,
|
189 |
+
attention_mask: Optional[torch.Tensor] = None,
|
190 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
191 |
+
return_dict: bool = True,
|
192 |
+
condition_on_first_frame: bool = False,
|
193 |
+
):
|
194 |
+
input_states = hidden_states
|
195 |
+
input_height, input_width = hidden_states.shape[-2:]
|
196 |
+
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
|
197 |
+
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
|
198 |
+
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
|
199 |
+
# expects mask of shape:
|
200 |
+
# [batch, key_tokens]
|
201 |
+
# adds singleton query_tokens dimension:
|
202 |
+
# [batch, 1, key_tokens]
|
203 |
+
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
204 |
+
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
205 |
+
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
206 |
+
if attention_mask is not None and attention_mask.ndim == 2:
|
207 |
+
# assume that mask is expressed as:
|
208 |
+
# (1 = keep, 0 = discard)
|
209 |
+
# convert mask into a bias that can be added to attention scores:
|
210 |
+
# (keep = +0, discard = -10000.0)
|
211 |
+
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
212 |
+
attention_mask = attention_mask.unsqueeze(1)
|
213 |
+
|
214 |
+
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
215 |
+
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
216 |
+
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
|
217 |
+
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
218 |
+
|
219 |
+
# Retrieve lora scale.
|
220 |
+
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
221 |
+
|
222 |
+
# 1. Input
|
223 |
+
if self.is_input_continuous:
|
224 |
+
batch, _, height, width = hidden_states.shape
|
225 |
+
residual = hidden_states
|
226 |
+
|
227 |
+
hidden_states = self.norm(hidden_states)
|
228 |
+
if not self.use_linear_projection:
|
229 |
+
hidden_states = self.proj_in(hidden_states, lora_scale)
|
230 |
+
inner_dim = hidden_states.shape[1]
|
231 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
232 |
+
else:
|
233 |
+
inner_dim = hidden_states.shape[1]
|
234 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
235 |
+
hidden_states = self.proj_in(hidden_states, scale=lora_scale)
|
236 |
+
|
237 |
+
elif self.is_input_vectorized:
|
238 |
+
hidden_states = self.latent_image_embedding(hidden_states)
|
239 |
+
elif self.is_input_patches:
|
240 |
+
hidden_states = self.pos_embed(hidden_states)
|
241 |
+
|
242 |
+
# 2. Blocks
|
243 |
+
for block in self.transformer_blocks:
|
244 |
+
if self.training and self.gradient_checkpointing:
|
245 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
246 |
+
block,
|
247 |
+
hidden_states,
|
248 |
+
attention_mask,
|
249 |
+
encoder_hidden_states,
|
250 |
+
encoder_attention_mask,
|
251 |
+
timestep,
|
252 |
+
cross_attention_kwargs,
|
253 |
+
class_labels,
|
254 |
+
use_reentrant=False,
|
255 |
+
)
|
256 |
+
else:
|
257 |
+
hidden_states = block(
|
258 |
+
hidden_states,
|
259 |
+
attention_mask=attention_mask,
|
260 |
+
encoder_hidden_states=encoder_hidden_states,
|
261 |
+
encoder_attention_mask=encoder_attention_mask,
|
262 |
+
timestep=timestep,
|
263 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
264 |
+
class_labels=class_labels,
|
265 |
+
# additional
|
266 |
+
condition_on_first_frame=condition_on_first_frame,
|
267 |
+
input_height=input_height,
|
268 |
+
input_width=input_width,
|
269 |
+
)
|
270 |
+
|
271 |
+
# 3. Output
|
272 |
+
if self.is_input_continuous:
|
273 |
+
if not self.use_linear_projection:
|
274 |
+
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
275 |
+
hidden_states = self.proj_out(hidden_states, scale=lora_scale)
|
276 |
+
else:
|
277 |
+
hidden_states = self.proj_out(hidden_states, scale=lora_scale)
|
278 |
+
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
279 |
+
|
280 |
+
output = hidden_states + residual
|
281 |
+
elif self.is_input_vectorized:
|
282 |
+
hidden_states = self.norm_out(hidden_states)
|
283 |
+
logits = self.out(hidden_states)
|
284 |
+
# (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
|
285 |
+
logits = logits.permute(0, 2, 1)
|
286 |
+
|
287 |
+
# log(p(x_0))
|
288 |
+
output = F.log_softmax(logits.double(), dim=1).float()
|
289 |
+
elif self.is_input_patches:
|
290 |
+
# TODO: cleanup!
|
291 |
+
conditioning = self.transformer_blocks[0].norm1.emb(
|
292 |
+
timestep, class_labels, hidden_dtype=hidden_states.dtype
|
293 |
+
)
|
294 |
+
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
|
295 |
+
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
|
296 |
+
hidden_states = self.proj_out_2(hidden_states)
|
297 |
+
|
298 |
+
# unpatchify
|
299 |
+
height = width = int(hidden_states.shape[1] ** 0.5)
|
300 |
+
hidden_states = hidden_states.reshape(
|
301 |
+
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
|
302 |
+
)
|
303 |
+
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
|
304 |
+
output = hidden_states.reshape(
|
305 |
+
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
|
306 |
+
)
|
307 |
+
|
308 |
+
if self.alpha is not None:
|
309 |
+
with torch.no_grad():
|
310 |
+
self.alpha.clamp_(0, 1)
|
311 |
+
|
312 |
+
output = self.alpha * input_states + (1 - self.alpha) * output
|
313 |
+
|
314 |
+
if not return_dict:
|
315 |
+
return (output,)
|
316 |
+
|
317 |
+
return Transformer2DModelOutput(sample=output)
|
318 |
+
|
319 |
+
|
320 |
+
@maybe_allow_in_graph
|
321 |
+
class BasicConditionalTransformerBlock(nn.Module):
|
322 |
+
""" transformer block with first frame conditioning """
|
323 |
+
def __init__(
|
324 |
+
self,
|
325 |
+
dim: int,
|
326 |
+
num_attention_heads: int,
|
327 |
+
attention_head_dim: int,
|
328 |
+
dropout=0.0,
|
329 |
+
cross_attention_dim: Optional[int] = None,
|
330 |
+
activation_fn: str = "geglu",
|
331 |
+
num_embeds_ada_norm: Optional[int] = None,
|
332 |
+
attention_bias: bool = False,
|
333 |
+
only_cross_attention: bool = False,
|
334 |
+
double_self_attention: bool = False,
|
335 |
+
upcast_attention: bool = False,
|
336 |
+
norm_elementwise_affine: bool = True,
|
337 |
+
norm_type: str = "layer_norm",
|
338 |
+
final_dropout: bool = False,
|
339 |
+
attention_type: str = "default",
|
340 |
+
# additional
|
341 |
+
n_frames: int = 8,
|
342 |
+
is_temporal: bool = False,
|
343 |
+
augment_temporal_attention: bool = False,
|
344 |
+
rotary_emb=False,
|
345 |
+
):
|
346 |
+
super().__init__()
|
347 |
+
self.n_frames = n_frames
|
348 |
+
self.only_cross_attention = only_cross_attention
|
349 |
+
self.augment_temporal_attention = augment_temporal_attention
|
350 |
+
self.is_temporal = is_temporal
|
351 |
+
|
352 |
+
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
|
353 |
+
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
|
354 |
+
|
355 |
+
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
|
356 |
+
raise ValueError(
|
357 |
+
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
|
358 |
+
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
|
359 |
+
)
|
360 |
+
|
361 |
+
# Define 3 blocks. Each block has its own normalization layer.
|
362 |
+
# 1. Self-Attn
|
363 |
+
if self.use_ada_layer_norm:
|
364 |
+
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
365 |
+
elif self.use_ada_layer_norm_zero:
|
366 |
+
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
|
367 |
+
else:
|
368 |
+
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
369 |
+
|
370 |
+
if not is_temporal:
|
371 |
+
self.attn1 = ConditionalAttention(
|
372 |
+
query_dim=dim,
|
373 |
+
heads=num_attention_heads,
|
374 |
+
dim_head=attention_head_dim,
|
375 |
+
dropout=dropout,
|
376 |
+
bias=attention_bias,
|
377 |
+
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
378 |
+
upcast_attention=upcast_attention,
|
379 |
+
)
|
380 |
+
else:
|
381 |
+
self.attn1 = TemporalConditionalAttention(
|
382 |
+
query_dim=dim,
|
383 |
+
heads=num_attention_heads,
|
384 |
+
dim_head=attention_head_dim,
|
385 |
+
dropout=dropout,
|
386 |
+
bias=attention_bias,
|
387 |
+
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
388 |
+
upcast_attention=upcast_attention,
|
389 |
+
# additional
|
390 |
+
n_frames=n_frames,
|
391 |
+
rotary_emb=rotary_emb,
|
392 |
+
)
|
393 |
+
|
394 |
+
# 2. Cross-Attn
|
395 |
+
if cross_attention_dim is not None or double_self_attention:
|
396 |
+
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
397 |
+
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
398 |
+
# the second cross attention block.
|
399 |
+
self.norm2 = (
|
400 |
+
AdaLayerNorm(dim, num_embeds_ada_norm)
|
401 |
+
if self.use_ada_layer_norm
|
402 |
+
else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
403 |
+
)
|
404 |
+
if not is_temporal:
|
405 |
+
self.attn2 = ConditionalAttention(
|
406 |
+
query_dim=dim,
|
407 |
+
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
|
408 |
+
heads=num_attention_heads,
|
409 |
+
dim_head=attention_head_dim,
|
410 |
+
dropout=dropout,
|
411 |
+
bias=attention_bias,
|
412 |
+
upcast_attention=upcast_attention,
|
413 |
+
) # is self-attn if encoder_hidden_states is none
|
414 |
+
else:
|
415 |
+
self.attn2 = TemporalConditionalAttention(
|
416 |
+
query_dim=dim,
|
417 |
+
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
|
418 |
+
heads=num_attention_heads,
|
419 |
+
dim_head=attention_head_dim,
|
420 |
+
dropout=dropout,
|
421 |
+
bias=attention_bias,
|
422 |
+
upcast_attention=upcast_attention,
|
423 |
+
# additional
|
424 |
+
n_frames=n_frames,
|
425 |
+
rotary_emb=rotary_emb,
|
426 |
+
)
|
427 |
+
else:
|
428 |
+
self.norm2 = None
|
429 |
+
self.attn2 = None
|
430 |
+
|
431 |
+
# 3. Feed-forward
|
432 |
+
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
433 |
+
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
|
434 |
+
|
435 |
+
# 4. Fuser
|
436 |
+
if attention_type == "gated" or attention_type == "gated-text-image":
|
437 |
+
self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
|
438 |
+
|
439 |
+
# let chunk size default to None
|
440 |
+
self._chunk_size = None
|
441 |
+
self._chunk_dim = 0
|
442 |
+
|
443 |
+
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
|
444 |
+
# Sets chunk feed-forward
|
445 |
+
self._chunk_size = chunk_size
|
446 |
+
self._chunk_dim = dim
|
447 |
+
|
448 |
+
def forward(
|
449 |
+
self,
|
450 |
+
hidden_states: torch.FloatTensor,
|
451 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
452 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
453 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
454 |
+
timestep: Optional[torch.LongTensor] = None,
|
455 |
+
cross_attention_kwargs: Dict[str, Any] = None,
|
456 |
+
class_labels: Optional[torch.LongTensor] = None,
|
457 |
+
condition_on_first_frame: bool = False,
|
458 |
+
input_height: Optional[int] = None,
|
459 |
+
input_width: Optional[int] = None,
|
460 |
+
):
|
461 |
+
# Notice that normalization is always applied before the real computation in the following blocks.
|
462 |
+
# 0. Self-Attention
|
463 |
+
if self.use_ada_layer_norm:
|
464 |
+
norm_hidden_states = self.norm1(hidden_states, timestep)
|
465 |
+
elif self.use_ada_layer_norm_zero:
|
466 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
467 |
+
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
|
468 |
+
)
|
469 |
+
else:
|
470 |
+
norm_hidden_states = self.norm1(hidden_states)
|
471 |
+
|
472 |
+
# 1. Retrieve lora scale.
|
473 |
+
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
474 |
+
|
475 |
+
# 2. Prepare GLIGEN inputs
|
476 |
+
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
|
477 |
+
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
|
478 |
+
|
479 |
+
if condition_on_first_frame:
|
480 |
+
first_frame_hidden_states = rearrange(norm_hidden_states, '(b f) d h -> b f d h', f=self.n_frames)[:, 0, :, :]
|
481 |
+
first_frame_hidden_states = repeat(first_frame_hidden_states, 'b d h -> b f d h', f=self.n_frames)
|
482 |
+
first_frame_hidden_states = rearrange(first_frame_hidden_states, 'b f d h -> (b f) d h')
|
483 |
+
first_frame_concat_hidden_states = torch.cat((norm_hidden_states, first_frame_hidden_states), dim=1)
|
484 |
+
attn_output = self.attn1(
|
485 |
+
norm_hidden_states,
|
486 |
+
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else first_frame_concat_hidden_states,
|
487 |
+
attention_mask=attention_mask,
|
488 |
+
**cross_attention_kwargs,
|
489 |
+
)
|
490 |
+
elif self.is_temporal and self.augment_temporal_attention:
|
491 |
+
first_frame_hidden_states = rearrange(norm_hidden_states, '(b f) d h -> b f d h', f=self.n_frames)[:, 0, :, :]
|
492 |
+
first_frame_hidden_states = rearrange(first_frame_hidden_states, 'b (h w) c -> b h w c', h=input_height, w=input_width)
|
493 |
+
first_frame_hidden_states = first_frame_hidden_states.permute(0, 3, 1, 2)
|
494 |
+
padded_first_frame = torch.nn.functional.pad(first_frame_hidden_states, (1, 1, 1, 1), "replicate")
|
495 |
+
first_frame_windows = padded_first_frame.unfold(2, 3, 1).unfold(3, 3, 1)
|
496 |
+
mask = torch.tensor([[1, 1, 1], [1, 0, 1], [1, 1, 1]], dtype=torch.bool)
|
497 |
+
adjacent_slices = first_frame_windows[:, :, :, :, mask]
|
498 |
+
attn_output = self.attn1(
|
499 |
+
norm_hidden_states,
|
500 |
+
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
501 |
+
attention_mask=attention_mask,
|
502 |
+
adjacent_slices=adjacent_slices,
|
503 |
+
**cross_attention_kwargs,
|
504 |
+
)
|
505 |
+
else:
|
506 |
+
attn_output = self.attn1(
|
507 |
+
norm_hidden_states,
|
508 |
+
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
509 |
+
attention_mask=attention_mask,
|
510 |
+
**cross_attention_kwargs,
|
511 |
+
)
|
512 |
+
if self.use_ada_layer_norm_zero:
|
513 |
+
attn_output = gate_msa.unsqueeze(1) * attn_output
|
514 |
+
hidden_states = attn_output + hidden_states
|
515 |
+
|
516 |
+
# 2.5 GLIGEN Control
|
517 |
+
if gligen_kwargs is not None:
|
518 |
+
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
|
519 |
+
# 2.5 ends
|
520 |
+
|
521 |
+
# 3. Cross-Attention
|
522 |
+
if self.attn2 is not None:
|
523 |
+
norm_hidden_states = (
|
524 |
+
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
|
525 |
+
)
|
526 |
+
|
527 |
+
attn_output = self.attn2(
|
528 |
+
norm_hidden_states,
|
529 |
+
encoder_hidden_states=encoder_hidden_states,
|
530 |
+
attention_mask=encoder_attention_mask,
|
531 |
+
**cross_attention_kwargs,
|
532 |
+
)
|
533 |
+
hidden_states = attn_output + hidden_states
|
534 |
+
|
535 |
+
# 4. Feed-forward
|
536 |
+
norm_hidden_states = self.norm3(hidden_states)
|
537 |
+
|
538 |
+
if self.use_ada_layer_norm_zero:
|
539 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
540 |
+
|
541 |
+
if self._chunk_size is not None:
|
542 |
+
# "feed_forward_chunk_size" can be used to save memory
|
543 |
+
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
|
544 |
+
raise ValueError(
|
545 |
+
f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
|
546 |
+
)
|
547 |
+
|
548 |
+
num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
|
549 |
+
ff_output = torch.cat(
|
550 |
+
[
|
551 |
+
self.ff(hid_slice, scale=lora_scale)
|
552 |
+
for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
|
553 |
+
],
|
554 |
+
dim=self._chunk_dim,
|
555 |
+
)
|
556 |
+
else:
|
557 |
+
ff_output = self.ff(norm_hidden_states, scale=lora_scale)
|
558 |
+
|
559 |
+
if self.use_ada_layer_norm_zero:
|
560 |
+
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
561 |
+
|
562 |
+
hidden_states = ff_output + hidden_states
|
563 |
+
|
564 |
+
return hidden_states
|
src/videogen_hub/pipelines/consisti2v/consisti2v/models/videoldm_unet.py
ADDED
@@ -0,0 +1,1371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
from typing import Optional, Tuple, Union, Dict, List, Any
|
4 |
+
from einops import rearrange, repeat
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
from diffusers.loaders import UNet2DConditionLoadersMixin
|
9 |
+
from diffusers.models import ModelMixin
|
10 |
+
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
|
11 |
+
from diffusers.models.unet_2d_blocks import UNetMidBlock2DCrossAttn, UNetMidBlock2DSimpleCrossAttn
|
12 |
+
from diffusers.models.embeddings import (
|
13 |
+
GaussianFourierProjection,
|
14 |
+
ImageHintTimeEmbedding,
|
15 |
+
ImageProjection,
|
16 |
+
ImageTimeEmbedding,
|
17 |
+
PositionNet,
|
18 |
+
TextImageProjection,
|
19 |
+
TextImageTimeEmbedding,
|
20 |
+
TextTimeEmbedding,
|
21 |
+
TimestepEmbedding,
|
22 |
+
Timesteps,
|
23 |
+
)
|
24 |
+
from diffusers.models.attention_processor import (
|
25 |
+
ADDED_KV_ATTENTION_PROCESSORS,
|
26 |
+
CROSS_ATTENTION_PROCESSORS,
|
27 |
+
AttentionProcessor,
|
28 |
+
AttnAddedKVProcessor,
|
29 |
+
AttnProcessor,
|
30 |
+
)
|
31 |
+
from diffusers.models.activations import get_activation
|
32 |
+
from diffusers.configuration_utils import register_to_config, ConfigMixin
|
33 |
+
from diffusers.models.modeling_utils import load_state_dict, load_model_dict_into_meta
|
34 |
+
from diffusers.utils import (
|
35 |
+
CONFIG_NAME,
|
36 |
+
DIFFUSERS_CACHE,
|
37 |
+
FLAX_WEIGHTS_NAME,
|
38 |
+
HF_HUB_OFFLINE,
|
39 |
+
SAFETENSORS_WEIGHTS_NAME,
|
40 |
+
WEIGHTS_NAME,
|
41 |
+
_add_variant,
|
42 |
+
_get_model_file,
|
43 |
+
deprecate,
|
44 |
+
is_accelerate_available,
|
45 |
+
is_torch_version,
|
46 |
+
logging,
|
47 |
+
)
|
48 |
+
from diffusers import __version__
|
49 |
+
|
50 |
+
if is_torch_version(">=", "1.9.0"):
|
51 |
+
_LOW_CPU_MEM_USAGE_DEFAULT = True
|
52 |
+
else:
|
53 |
+
_LOW_CPU_MEM_USAGE_DEFAULT = False
|
54 |
+
|
55 |
+
|
56 |
+
if is_accelerate_available():
|
57 |
+
import accelerate
|
58 |
+
from accelerate.utils import set_module_tensor_to_device
|
59 |
+
from accelerate.utils.versions import is_torch_version
|
60 |
+
|
61 |
+
|
62 |
+
|
63 |
+
from .videoldm_unet_blocks import get_down_block, get_up_block, VideoLDMUNetMidBlock2DCrossAttn
|
64 |
+
|
65 |
+
logger = logging.get_logger(__name__)
|
66 |
+
|
67 |
+
|
68 |
+
class VideoLDMUNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
69 |
+
_supports_gradient_checkpointing = True
|
70 |
+
@register_to_config
|
71 |
+
def __init__(
|
72 |
+
self,
|
73 |
+
sample_size: Optional[int] = None,
|
74 |
+
in_channels: int = 4,
|
75 |
+
out_channels: int = 4,
|
76 |
+
center_input_sample: bool = False,
|
77 |
+
flip_sin_to_cos: bool = True,
|
78 |
+
freq_shift: int = 0,
|
79 |
+
down_block_types: Tuple[str] = (
|
80 |
+
"CrossAttnDownBlock2D", # -> VideoLDMDownBlock
|
81 |
+
"CrossAttnDownBlock2D", # -> VideoLDMDownBlock
|
82 |
+
"CrossAttnDownBlock2D", # -> VideoLDMDownBlock
|
83 |
+
"DownBlock2D",
|
84 |
+
),
|
85 |
+
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
86 |
+
up_block_types: Tuple[str] = (
|
87 |
+
"UpBlock2D",
|
88 |
+
"CrossAttnUpBlock2D", # -> VideoLDMUpBlock
|
89 |
+
"CrossAttnUpBlock2D", # -> VideoLDMUpBlock
|
90 |
+
"CrossAttnUpBlock2D", # -> VideoLDMUpBlock
|
91 |
+
),
|
92 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
93 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
94 |
+
layers_per_block: Union[int, Tuple[int]] = 2,
|
95 |
+
downsample_padding: int = 1,
|
96 |
+
mid_block_scale_factor: float = 1,
|
97 |
+
dropout: float = 0.0,
|
98 |
+
act_fn: str = "silu",
|
99 |
+
norm_num_groups: Optional[int] = 32,
|
100 |
+
norm_eps: float = 1e-5,
|
101 |
+
cross_attention_dim: Union[int, Tuple[int]] = 1280,
|
102 |
+
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
103 |
+
encoder_hid_dim: Optional[int] = None,
|
104 |
+
encoder_hid_dim_type: Optional[str] = None,
|
105 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
106 |
+
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
|
107 |
+
dual_cross_attention: bool = False,
|
108 |
+
use_linear_projection: bool = False,
|
109 |
+
class_embed_type: Optional[str] = None,
|
110 |
+
addition_embed_type: Optional[str] = None,
|
111 |
+
addition_time_embed_dim: Optional[int] = None,
|
112 |
+
num_class_embeds: Optional[int] = None,
|
113 |
+
upcast_attention: bool = False,
|
114 |
+
resnet_time_scale_shift: str = "default",
|
115 |
+
resnet_skip_time_act: bool = False,
|
116 |
+
resnet_out_scale_factor: int = 1.0,
|
117 |
+
time_embedding_type: str = "positional",
|
118 |
+
time_embedding_dim: Optional[int] = None,
|
119 |
+
time_embedding_act_fn: Optional[str] = None,
|
120 |
+
timestep_post_act: Optional[str] = None,
|
121 |
+
time_cond_proj_dim: Optional[int] = None,
|
122 |
+
conv_in_kernel: int = 3,
|
123 |
+
conv_out_kernel: int = 3,
|
124 |
+
projection_class_embeddings_input_dim: Optional[int] = None,
|
125 |
+
attention_type: str = "default",
|
126 |
+
class_embeddings_concat: bool = False,
|
127 |
+
mid_block_only_cross_attention: Optional[bool] = None,
|
128 |
+
cross_attention_norm: Optional[str] = None,
|
129 |
+
addition_embed_type_num_heads=64,
|
130 |
+
# additional
|
131 |
+
use_temporal: bool = True,
|
132 |
+
n_frames: int = 8,
|
133 |
+
n_temp_heads: int = 8,
|
134 |
+
first_frame_condition_mode: str = "none",
|
135 |
+
augment_temporal_attention: bool = False,
|
136 |
+
temp_pos_embedding: str = "sinusoidal",
|
137 |
+
use_frame_stride_condition: bool = False,
|
138 |
+
):
|
139 |
+
super().__init__()
|
140 |
+
|
141 |
+
rotary_emb = False
|
142 |
+
if temp_pos_embedding == "rotary":
|
143 |
+
# from rotary_embedding_torch import RotaryEmbedding
|
144 |
+
# rotary_emb = RotaryEmbedding(32)
|
145 |
+
# self.rotary_emb = rotary_emb
|
146 |
+
rotary_emb = True
|
147 |
+
self.rotary_emb = rotary_emb
|
148 |
+
|
149 |
+
self.use_temporal = use_temporal
|
150 |
+
self.augment_temporal_attention = augment_temporal_attention
|
151 |
+
|
152 |
+
assert first_frame_condition_mode in ["none", "concat", "conv2d", "input_only"], f"first_frame_condition_mode: {first_frame_condition_mode} must be one of ['none', 'concat', 'conv2d', 'input_only']"
|
153 |
+
self.first_frame_condition_mode = first_frame_condition_mode
|
154 |
+
latent_channels = in_channels
|
155 |
+
|
156 |
+
self.sample_size = sample_size
|
157 |
+
|
158 |
+
if num_attention_heads is not None:
|
159 |
+
raise ValueError(
|
160 |
+
"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
|
161 |
+
)
|
162 |
+
|
163 |
+
num_attention_heads = num_attention_heads or attention_head_dim
|
164 |
+
|
165 |
+
# Check inputs
|
166 |
+
if len(down_block_types) != len(up_block_types):
|
167 |
+
raise ValueError(
|
168 |
+
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
169 |
+
)
|
170 |
+
|
171 |
+
if len(block_out_channels) != len(down_block_types):
|
172 |
+
raise ValueError(
|
173 |
+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
174 |
+
)
|
175 |
+
|
176 |
+
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
177 |
+
raise ValueError(
|
178 |
+
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
179 |
+
)
|
180 |
+
|
181 |
+
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
182 |
+
raise ValueError(
|
183 |
+
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
184 |
+
)
|
185 |
+
|
186 |
+
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
|
187 |
+
raise ValueError(
|
188 |
+
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
|
189 |
+
)
|
190 |
+
|
191 |
+
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
|
192 |
+
raise ValueError(
|
193 |
+
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
|
194 |
+
)
|
195 |
+
|
196 |
+
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
|
197 |
+
raise ValueError(
|
198 |
+
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
|
199 |
+
)
|
200 |
+
|
201 |
+
# input
|
202 |
+
conv_in_padding = (conv_in_kernel - 1) // 2
|
203 |
+
self.conv_in = nn.Conv2d(
|
204 |
+
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
205 |
+
)
|
206 |
+
|
207 |
+
# time
|
208 |
+
if time_embedding_type == "fourier":
|
209 |
+
time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
|
210 |
+
if time_embed_dim % 2 != 0:
|
211 |
+
raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
|
212 |
+
self.time_proj = GaussianFourierProjection(
|
213 |
+
time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
|
214 |
+
)
|
215 |
+
timestep_input_dim = time_embed_dim
|
216 |
+
elif time_embedding_type == "positional":
|
217 |
+
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
|
218 |
+
|
219 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
220 |
+
timestep_input_dim = block_out_channels[0]
|
221 |
+
else:
|
222 |
+
raise ValueError(
|
223 |
+
f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
|
224 |
+
)
|
225 |
+
|
226 |
+
self.time_embedding = TimestepEmbedding(
|
227 |
+
timestep_input_dim,
|
228 |
+
time_embed_dim,
|
229 |
+
act_fn=act_fn,
|
230 |
+
post_act_fn=timestep_post_act,
|
231 |
+
cond_proj_dim=time_cond_proj_dim,
|
232 |
+
)
|
233 |
+
|
234 |
+
self.use_frame_stride_condition = use_frame_stride_condition
|
235 |
+
if self.use_frame_stride_condition:
|
236 |
+
self.frame_stride_embedding = TimestepEmbedding(
|
237 |
+
timestep_input_dim,
|
238 |
+
time_embed_dim,
|
239 |
+
act_fn=act_fn,
|
240 |
+
post_act_fn=timestep_post_act,
|
241 |
+
cond_proj_dim=time_cond_proj_dim,
|
242 |
+
)
|
243 |
+
# zero init
|
244 |
+
nn.init.zeros_(self.frame_stride_embedding.linear_2.weight)
|
245 |
+
nn.init.zeros_(self.frame_stride_embedding.linear_2.bias)
|
246 |
+
|
247 |
+
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
|
248 |
+
encoder_hid_dim_type = "text_proj"
|
249 |
+
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
|
250 |
+
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
|
251 |
+
|
252 |
+
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
|
253 |
+
raise ValueError(
|
254 |
+
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
|
255 |
+
)
|
256 |
+
|
257 |
+
if encoder_hid_dim_type == "text_proj":
|
258 |
+
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
|
259 |
+
elif encoder_hid_dim_type == "text_image_proj":
|
260 |
+
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
261 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
262 |
+
# case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
|
263 |
+
self.encoder_hid_proj = TextImageProjection(
|
264 |
+
text_embed_dim=encoder_hid_dim,
|
265 |
+
image_embed_dim=cross_attention_dim,
|
266 |
+
cross_attention_dim=cross_attention_dim,
|
267 |
+
)
|
268 |
+
elif encoder_hid_dim_type == "image_proj":
|
269 |
+
# Kandinsky 2.2
|
270 |
+
self.encoder_hid_proj = ImageProjection(
|
271 |
+
image_embed_dim=encoder_hid_dim,
|
272 |
+
cross_attention_dim=cross_attention_dim,
|
273 |
+
)
|
274 |
+
elif encoder_hid_dim_type is not None:
|
275 |
+
raise ValueError(
|
276 |
+
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
|
277 |
+
)
|
278 |
+
else:
|
279 |
+
self.encoder_hid_proj = None
|
280 |
+
|
281 |
+
# class embedding
|
282 |
+
if class_embed_type is None and num_class_embeds is not None:
|
283 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
284 |
+
elif class_embed_type == "timestep":
|
285 |
+
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
|
286 |
+
elif class_embed_type == "identity":
|
287 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
288 |
+
elif class_embed_type == "projection":
|
289 |
+
if projection_class_embeddings_input_dim is None:
|
290 |
+
raise ValueError(
|
291 |
+
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
292 |
+
)
|
293 |
+
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
294 |
+
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
295 |
+
# 2. it projects from an arbitrary input dimension.
|
296 |
+
#
|
297 |
+
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
298 |
+
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
299 |
+
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
300 |
+
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
301 |
+
elif class_embed_type == "simple_projection":
|
302 |
+
if projection_class_embeddings_input_dim is None:
|
303 |
+
raise ValueError(
|
304 |
+
"`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
|
305 |
+
)
|
306 |
+
self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
|
307 |
+
else:
|
308 |
+
self.class_embedding = None
|
309 |
+
|
310 |
+
if addition_embed_type == "text":
|
311 |
+
if encoder_hid_dim is not None:
|
312 |
+
text_time_embedding_from_dim = encoder_hid_dim
|
313 |
+
else:
|
314 |
+
text_time_embedding_from_dim = cross_attention_dim
|
315 |
+
|
316 |
+
self.add_embedding = TextTimeEmbedding(
|
317 |
+
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
|
318 |
+
)
|
319 |
+
elif addition_embed_type == "text_image":
|
320 |
+
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
321 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
322 |
+
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
|
323 |
+
self.add_embedding = TextImageTimeEmbedding(
|
324 |
+
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
|
325 |
+
)
|
326 |
+
elif addition_embed_type == "text_time":
|
327 |
+
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
|
328 |
+
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
329 |
+
elif addition_embed_type == "image":
|
330 |
+
# Kandinsky 2.2
|
331 |
+
self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
|
332 |
+
elif addition_embed_type == "image_hint":
|
333 |
+
# Kandinsky 2.2 ControlNet
|
334 |
+
self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
|
335 |
+
elif addition_embed_type is not None:
|
336 |
+
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
|
337 |
+
|
338 |
+
if time_embedding_act_fn is None:
|
339 |
+
self.time_embed_act = None
|
340 |
+
else:
|
341 |
+
self.time_embed_act = get_activation(time_embedding_act_fn)
|
342 |
+
|
343 |
+
self.down_blocks = nn.ModuleList([])
|
344 |
+
self.up_blocks = nn.ModuleList([])
|
345 |
+
|
346 |
+
if isinstance(only_cross_attention, bool):
|
347 |
+
if mid_block_only_cross_attention is None:
|
348 |
+
mid_block_only_cross_attention = only_cross_attention
|
349 |
+
|
350 |
+
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
351 |
+
|
352 |
+
if mid_block_only_cross_attention is None:
|
353 |
+
mid_block_only_cross_attention = False
|
354 |
+
|
355 |
+
if isinstance(num_attention_heads, int):
|
356 |
+
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
357 |
+
|
358 |
+
if isinstance(attention_head_dim, int):
|
359 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
360 |
+
|
361 |
+
if isinstance(cross_attention_dim, int):
|
362 |
+
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
|
363 |
+
|
364 |
+
if isinstance(layers_per_block, int):
|
365 |
+
layers_per_block = [layers_per_block] * len(down_block_types)
|
366 |
+
|
367 |
+
if isinstance(transformer_layers_per_block, int):
|
368 |
+
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
369 |
+
|
370 |
+
if class_embeddings_concat:
|
371 |
+
# The time embeddings are concatenated with the class embeddings. The dimension of the
|
372 |
+
# time embeddings passed to the down, middle, and up blocks is twice the dimension of the
|
373 |
+
# regular time embeddings
|
374 |
+
blocks_time_embed_dim = time_embed_dim * 2
|
375 |
+
else:
|
376 |
+
blocks_time_embed_dim = time_embed_dim
|
377 |
+
# down
|
378 |
+
output_channel = block_out_channels[0]
|
379 |
+
for i, down_block_type in enumerate(down_block_types):
|
380 |
+
input_channel = output_channel
|
381 |
+
output_channel = block_out_channels[i]
|
382 |
+
is_final_block = i == len(block_out_channels) - 1
|
383 |
+
|
384 |
+
down_block = get_down_block(
|
385 |
+
down_block_type,
|
386 |
+
num_layers=layers_per_block[i],
|
387 |
+
transformer_layers_per_block=transformer_layers_per_block[i],
|
388 |
+
in_channels=input_channel,
|
389 |
+
out_channels=output_channel,
|
390 |
+
temb_channels=blocks_time_embed_dim,
|
391 |
+
add_downsample=not is_final_block,
|
392 |
+
resnet_eps=norm_eps,
|
393 |
+
resnet_act_fn=act_fn,
|
394 |
+
resnet_groups=norm_num_groups,
|
395 |
+
cross_attention_dim=cross_attention_dim[i],
|
396 |
+
num_attention_heads=num_attention_heads[i],
|
397 |
+
downsample_padding=downsample_padding,
|
398 |
+
dual_cross_attention=dual_cross_attention,
|
399 |
+
use_linear_projection=use_linear_projection,
|
400 |
+
only_cross_attention=only_cross_attention[i],
|
401 |
+
upcast_attention=upcast_attention,
|
402 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
403 |
+
attention_type=attention_type,
|
404 |
+
resnet_skip_time_act=resnet_skip_time_act,
|
405 |
+
resnet_out_scale_factor=resnet_out_scale_factor,
|
406 |
+
cross_attention_norm=cross_attention_norm,
|
407 |
+
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
408 |
+
dropout=dropout,
|
409 |
+
# additional
|
410 |
+
use_temporal=use_temporal,
|
411 |
+
augment_temporal_attention=augment_temporal_attention,
|
412 |
+
n_frames=n_frames,
|
413 |
+
n_temp_heads=n_temp_heads,
|
414 |
+
first_frame_condition_mode=first_frame_condition_mode,
|
415 |
+
latent_channels=latent_channels,
|
416 |
+
rotary_emb=rotary_emb,
|
417 |
+
)
|
418 |
+
self.down_blocks.append(down_block)
|
419 |
+
|
420 |
+
# mid
|
421 |
+
if mid_block_type == "UNetMidBlock2DCrossAttn":
|
422 |
+
self.mid_block = VideoLDMUNetMidBlock2DCrossAttn(
|
423 |
+
transformer_layers_per_block=transformer_layers_per_block[-1],
|
424 |
+
in_channels=block_out_channels[-1],
|
425 |
+
temb_channels=blocks_time_embed_dim,
|
426 |
+
dropout=dropout,
|
427 |
+
resnet_eps=norm_eps,
|
428 |
+
resnet_act_fn=act_fn,
|
429 |
+
output_scale_factor=mid_block_scale_factor,
|
430 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
431 |
+
cross_attention_dim=cross_attention_dim[-1],
|
432 |
+
num_attention_heads=num_attention_heads[-1],
|
433 |
+
resnet_groups=norm_num_groups,
|
434 |
+
dual_cross_attention=dual_cross_attention,
|
435 |
+
use_linear_projection=use_linear_projection,
|
436 |
+
upcast_attention=upcast_attention,
|
437 |
+
attention_type=attention_type,
|
438 |
+
# additional
|
439 |
+
use_temporal=use_temporal,
|
440 |
+
n_frames=n_frames,
|
441 |
+
first_frame_condition_mode=first_frame_condition_mode,
|
442 |
+
latent_channels=latent_channels,
|
443 |
+
)
|
444 |
+
elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
|
445 |
+
self.mid_block = UNetMidBlock2DSimpleCrossAttn(
|
446 |
+
in_channels=block_out_channels[-1],
|
447 |
+
temb_channels=blocks_time_embed_dim,
|
448 |
+
dropout=dropout,
|
449 |
+
resnet_eps=norm_eps,
|
450 |
+
resnet_act_fn=act_fn,
|
451 |
+
output_scale_factor=mid_block_scale_factor,
|
452 |
+
cross_attention_dim=cross_attention_dim[-1],
|
453 |
+
attention_head_dim=attention_head_dim[-1],
|
454 |
+
resnet_groups=norm_num_groups,
|
455 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
456 |
+
skip_time_act=resnet_skip_time_act,
|
457 |
+
only_cross_attention=mid_block_only_cross_attention,
|
458 |
+
cross_attention_norm=cross_attention_norm,
|
459 |
+
)
|
460 |
+
elif mid_block_type is None:
|
461 |
+
self.mid_block = None
|
462 |
+
else:
|
463 |
+
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
|
464 |
+
|
465 |
+
# count how many layers upsample the images
|
466 |
+
self.num_upsamplers = 0
|
467 |
+
|
468 |
+
# up
|
469 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
470 |
+
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
471 |
+
reversed_layers_per_block = list(reversed(layers_per_block))
|
472 |
+
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
|
473 |
+
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
|
474 |
+
only_cross_attention = list(reversed(only_cross_attention))
|
475 |
+
|
476 |
+
output_channel = reversed_block_out_channels[0]
|
477 |
+
for i, up_block_type in enumerate(up_block_types):
|
478 |
+
is_final_block = i == len(block_out_channels) - 1
|
479 |
+
|
480 |
+
prev_output_channel = output_channel
|
481 |
+
output_channel = reversed_block_out_channels[i]
|
482 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
483 |
+
|
484 |
+
# add upsample block for all BUT final layer
|
485 |
+
if not is_final_block:
|
486 |
+
add_upsample = True
|
487 |
+
self.num_upsamplers += 1
|
488 |
+
else:
|
489 |
+
add_upsample = False
|
490 |
+
|
491 |
+
up_block = get_up_block(
|
492 |
+
up_block_type,
|
493 |
+
num_layers=reversed_layers_per_block[i] + 1,
|
494 |
+
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
|
495 |
+
in_channels=input_channel,
|
496 |
+
out_channels=output_channel,
|
497 |
+
prev_output_channel=prev_output_channel,
|
498 |
+
temb_channels=blocks_time_embed_dim,
|
499 |
+
add_upsample=add_upsample,
|
500 |
+
resnet_eps=norm_eps,
|
501 |
+
resnet_act_fn=act_fn,
|
502 |
+
resnet_groups=norm_num_groups,
|
503 |
+
cross_attention_dim=reversed_cross_attention_dim[i],
|
504 |
+
num_attention_heads=reversed_num_attention_heads[i],
|
505 |
+
dual_cross_attention=dual_cross_attention,
|
506 |
+
use_linear_projection=use_linear_projection,
|
507 |
+
only_cross_attention=only_cross_attention[i],
|
508 |
+
upcast_attention=upcast_attention,
|
509 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
510 |
+
attention_type=attention_type,
|
511 |
+
resnet_skip_time_act=resnet_skip_time_act,
|
512 |
+
resnet_out_scale_factor=resnet_out_scale_factor,
|
513 |
+
cross_attention_norm=cross_attention_norm,
|
514 |
+
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
515 |
+
dropout=dropout,
|
516 |
+
# additional
|
517 |
+
use_temporal=use_temporal,
|
518 |
+
augment_temporal_attention=augment_temporal_attention,
|
519 |
+
n_frames=n_frames,
|
520 |
+
n_temp_heads=n_temp_heads,
|
521 |
+
first_frame_condition_mode=first_frame_condition_mode,
|
522 |
+
latent_channels=latent_channels,
|
523 |
+
rotary_emb=rotary_emb,
|
524 |
+
)
|
525 |
+
self.up_blocks.append(up_block)
|
526 |
+
prev_output_channel = output_channel
|
527 |
+
|
528 |
+
# out
|
529 |
+
if norm_num_groups is not None:
|
530 |
+
self.conv_norm_out = nn.GroupNorm(
|
531 |
+
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
|
532 |
+
)
|
533 |
+
|
534 |
+
self.conv_act = get_activation(act_fn)
|
535 |
+
|
536 |
+
else:
|
537 |
+
self.conv_norm_out = None
|
538 |
+
self.conv_act = None
|
539 |
+
|
540 |
+
conv_out_padding = (conv_out_kernel - 1) // 2
|
541 |
+
self.conv_out = nn.Conv2d(
|
542 |
+
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
|
543 |
+
)
|
544 |
+
|
545 |
+
@property
|
546 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
547 |
+
r"""
|
548 |
+
Returns:
|
549 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
550 |
+
indexed by its weight name.
|
551 |
+
"""
|
552 |
+
# set recursively
|
553 |
+
processors = {}
|
554 |
+
|
555 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
556 |
+
if hasattr(module, "get_processor"):
|
557 |
+
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
558 |
+
|
559 |
+
for sub_name, child in module.named_children():
|
560 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
561 |
+
|
562 |
+
return processors
|
563 |
+
|
564 |
+
for name, module in self.named_children():
|
565 |
+
fn_recursive_add_processors(name, module, processors)
|
566 |
+
|
567 |
+
return processors
|
568 |
+
|
569 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
570 |
+
r"""
|
571 |
+
Sets the attention processor to use to compute attention.
|
572 |
+
|
573 |
+
Parameters:
|
574 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
575 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
576 |
+
for **all** `Attention` layers.
|
577 |
+
|
578 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
579 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
580 |
+
|
581 |
+
"""
|
582 |
+
count = len(self.attn_processors.keys())
|
583 |
+
|
584 |
+
if isinstance(processor, dict) and len(processor) != count:
|
585 |
+
raise ValueError(
|
586 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
587 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
588 |
+
)
|
589 |
+
|
590 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
591 |
+
if hasattr(module, "set_processor"):
|
592 |
+
if not isinstance(processor, dict):
|
593 |
+
module.set_processor(processor)
|
594 |
+
else:
|
595 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
596 |
+
|
597 |
+
for sub_name, child in module.named_children():
|
598 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
599 |
+
|
600 |
+
for name, module in self.named_children():
|
601 |
+
fn_recursive_attn_processor(name, module, processor)
|
602 |
+
|
603 |
+
def set_default_attn_processor(self):
|
604 |
+
"""
|
605 |
+
Disables custom attention processors and sets the default attention implementation.
|
606 |
+
"""
|
607 |
+
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
608 |
+
processor = AttnAddedKVProcessor()
|
609 |
+
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
610 |
+
processor = AttnProcessor()
|
611 |
+
else:
|
612 |
+
raise ValueError(
|
613 |
+
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
614 |
+
)
|
615 |
+
|
616 |
+
self.set_attn_processor(processor)
|
617 |
+
|
618 |
+
def set_attention_slice(self, slice_size):
|
619 |
+
r"""
|
620 |
+
Enable sliced attention computation.
|
621 |
+
|
622 |
+
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
623 |
+
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
624 |
+
|
625 |
+
Args:
|
626 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
627 |
+
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
628 |
+
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
629 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
630 |
+
must be a multiple of `slice_size`.
|
631 |
+
"""
|
632 |
+
sliceable_head_dims = []
|
633 |
+
|
634 |
+
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
635 |
+
if hasattr(module, "set_attention_slice"):
|
636 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
637 |
+
|
638 |
+
for child in module.children():
|
639 |
+
fn_recursive_retrieve_sliceable_dims(child)
|
640 |
+
|
641 |
+
# retrieve number of attention layers
|
642 |
+
for module in self.children():
|
643 |
+
fn_recursive_retrieve_sliceable_dims(module)
|
644 |
+
|
645 |
+
num_sliceable_layers = len(sliceable_head_dims)
|
646 |
+
|
647 |
+
if slice_size == "auto":
|
648 |
+
# half the attention head size is usually a good trade-off between
|
649 |
+
# speed and memory
|
650 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
651 |
+
elif slice_size == "max":
|
652 |
+
# make smallest slice possible
|
653 |
+
slice_size = num_sliceable_layers * [1]
|
654 |
+
|
655 |
+
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
656 |
+
|
657 |
+
if len(slice_size) != len(sliceable_head_dims):
|
658 |
+
raise ValueError(
|
659 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
660 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
661 |
+
)
|
662 |
+
|
663 |
+
for i in range(len(slice_size)):
|
664 |
+
size = slice_size[i]
|
665 |
+
dim = sliceable_head_dims[i]
|
666 |
+
if size is not None and size > dim:
|
667 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
668 |
+
|
669 |
+
# Recursively walk through all the children.
|
670 |
+
# Any children which exposes the set_attention_slice method
|
671 |
+
# gets the message
|
672 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
673 |
+
if hasattr(module, "set_attention_slice"):
|
674 |
+
module.set_attention_slice(slice_size.pop())
|
675 |
+
|
676 |
+
for child in module.children():
|
677 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
678 |
+
|
679 |
+
reversed_slice_size = list(reversed(slice_size))
|
680 |
+
for module in self.children():
|
681 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
682 |
+
|
683 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
684 |
+
if hasattr(module, "gradient_checkpointing"):
|
685 |
+
module.gradient_checkpointing = value
|
686 |
+
|
687 |
+
def forward(
|
688 |
+
self,
|
689 |
+
sample: torch.FloatTensor,
|
690 |
+
timestep: Union[torch.Tensor, float, int],
|
691 |
+
encoder_hidden_states: torch.Tensor,
|
692 |
+
class_labels: Optional[torch.Tensor] = None,
|
693 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
694 |
+
attention_mask: Optional[torch.Tensor] = None,
|
695 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
696 |
+
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
697 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
698 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
699 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
700 |
+
return_dict: bool = True,
|
701 |
+
# additional
|
702 |
+
first_frame_latents: Optional[torch.Tensor] = None,
|
703 |
+
frame_stride: Optional[Union[torch.Tensor, float, int]] = None,
|
704 |
+
) -> Union[UNet2DConditionOutput, Tuple]:
|
705 |
+
# reshape video data
|
706 |
+
assert sample.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={sample.dim()}."
|
707 |
+
video_length = sample.shape[2]
|
708 |
+
|
709 |
+
if first_frame_latents is not None:
|
710 |
+
assert self.config.first_frame_condition_mode != "none", "first_frame_latents is not None, but first_frame_condition_mode is 'none'."
|
711 |
+
|
712 |
+
if self.config.first_frame_condition_mode != "none":
|
713 |
+
sample = torch.cat([first_frame_latents, sample], dim=2)
|
714 |
+
video_length += 1
|
715 |
+
|
716 |
+
# copy conditioning embeddings for cross attention
|
717 |
+
if encoder_hidden_states is not None:
|
718 |
+
encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
|
719 |
+
|
720 |
+
sample = rearrange(sample, "b c f h w -> (b f) c h w")
|
721 |
+
|
722 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
723 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
|
724 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
725 |
+
# on the fly if necessary.
|
726 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
727 |
+
|
728 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
729 |
+
forward_upsample_size = False
|
730 |
+
upsample_size = None
|
731 |
+
|
732 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
733 |
+
logger.info("Forward upsample size to force interpolation output size.")
|
734 |
+
forward_upsample_size = True
|
735 |
+
|
736 |
+
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
|
737 |
+
# expects mask of shape:
|
738 |
+
# [batch, key_tokens]
|
739 |
+
# adds singleton query_tokens dimension:
|
740 |
+
# [batch, 1, key_tokens]
|
741 |
+
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
742 |
+
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
743 |
+
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
744 |
+
if attention_mask is not None:
|
745 |
+
# assume that mask is expressed as:
|
746 |
+
# (1 = keep, 0 = discard)
|
747 |
+
# convert mask into a bias that can be added to attention scores:
|
748 |
+
# (keep = +0, discard = -10000.0)
|
749 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
750 |
+
attention_mask = attention_mask.unsqueeze(1)
|
751 |
+
|
752 |
+
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
753 |
+
if encoder_attention_mask is not None:
|
754 |
+
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
|
755 |
+
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
756 |
+
|
757 |
+
# 0. center input if necessary
|
758 |
+
if self.config.center_input_sample:
|
759 |
+
sample = 2 * sample - 1.0
|
760 |
+
|
761 |
+
# 1. time
|
762 |
+
timesteps = timestep
|
763 |
+
if not torch.is_tensor(timesteps):
|
764 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
765 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
766 |
+
is_mps = sample.device.type == "mps"
|
767 |
+
if isinstance(timestep, float):
|
768 |
+
dtype = torch.float32 if is_mps else torch.float64
|
769 |
+
else:
|
770 |
+
dtype = torch.int32 if is_mps else torch.int64
|
771 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
772 |
+
elif len(timesteps.shape) == 0:
|
773 |
+
timesteps = timesteps[None].to(sample.device)
|
774 |
+
|
775 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
776 |
+
timesteps = timesteps.expand(sample.shape[0])
|
777 |
+
|
778 |
+
t_emb = self.time_proj(timesteps)
|
779 |
+
|
780 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
781 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
782 |
+
# there might be better ways to encapsulate this.
|
783 |
+
t_emb = t_emb.to(dtype=sample.dtype)
|
784 |
+
|
785 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
786 |
+
|
787 |
+
if self.use_frame_stride_condition:
|
788 |
+
if not torch.is_tensor(frame_stride):
|
789 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
790 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
791 |
+
is_mps = sample.device.type == "mps"
|
792 |
+
if isinstance(timestep, float):
|
793 |
+
dtype = torch.float32 if is_mps else torch.float64
|
794 |
+
else:
|
795 |
+
dtype = torch.int32 if is_mps else torch.int64
|
796 |
+
frame_stride = torch.tensor([frame_stride], dtype=dtype, device=sample.device)
|
797 |
+
elif len(frame_stride.shape) == 0:
|
798 |
+
frame_stride = frame_stride[None].to(sample.device)
|
799 |
+
|
800 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
801 |
+
frame_stride = frame_stride.expand(sample.shape[0])
|
802 |
+
|
803 |
+
fs_emb = self.time_proj(frame_stride)
|
804 |
+
|
805 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
806 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
807 |
+
# there might be better ways to encapsulate this.
|
808 |
+
fs_emb = fs_emb.to(dtype=sample.dtype)
|
809 |
+
|
810 |
+
fs_emb = self.frame_stride_embedding(fs_emb, timestep_cond)
|
811 |
+
emb = emb + fs_emb
|
812 |
+
|
813 |
+
aug_emb = None
|
814 |
+
|
815 |
+
if self.class_embedding is not None:
|
816 |
+
if class_labels is None:
|
817 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
818 |
+
|
819 |
+
if self.config.class_embed_type == "timestep":
|
820 |
+
class_labels = self.time_proj(class_labels)
|
821 |
+
|
822 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
823 |
+
# there might be better ways to encapsulate this.
|
824 |
+
class_labels = class_labels.to(dtype=sample.dtype)
|
825 |
+
|
826 |
+
class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
|
827 |
+
|
828 |
+
if self.config.class_embeddings_concat:
|
829 |
+
emb = torch.cat([emb, class_emb], dim=-1)
|
830 |
+
else:
|
831 |
+
emb = emb + class_emb
|
832 |
+
|
833 |
+
if self.config.addition_embed_type == "text":
|
834 |
+
aug_emb = self.add_embedding(encoder_hidden_states)
|
835 |
+
elif self.config.addition_embed_type == "text_image":
|
836 |
+
# Kandinsky 2.1 - style
|
837 |
+
if "image_embeds" not in added_cond_kwargs:
|
838 |
+
raise ValueError(
|
839 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
840 |
+
)
|
841 |
+
|
842 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
843 |
+
text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
|
844 |
+
aug_emb = self.add_embedding(text_embs, image_embs)
|
845 |
+
elif self.config.addition_embed_type == "text_time":
|
846 |
+
# SDXL - style
|
847 |
+
if "text_embeds" not in added_cond_kwargs:
|
848 |
+
raise ValueError(
|
849 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
850 |
+
)
|
851 |
+
text_embeds = added_cond_kwargs.get("text_embeds")
|
852 |
+
if "time_ids" not in added_cond_kwargs:
|
853 |
+
raise ValueError(
|
854 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
855 |
+
)
|
856 |
+
time_ids = added_cond_kwargs.get("time_ids")
|
857 |
+
time_embeds = self.add_time_proj(time_ids.flatten())
|
858 |
+
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
859 |
+
|
860 |
+
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
861 |
+
add_embeds = add_embeds.to(emb.dtype)
|
862 |
+
aug_emb = self.add_embedding(add_embeds)
|
863 |
+
elif self.config.addition_embed_type == "image":
|
864 |
+
# Kandinsky 2.2 - style
|
865 |
+
if "image_embeds" not in added_cond_kwargs:
|
866 |
+
raise ValueError(
|
867 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
868 |
+
)
|
869 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
870 |
+
aug_emb = self.add_embedding(image_embs)
|
871 |
+
elif self.config.addition_embed_type == "image_hint":
|
872 |
+
# Kandinsky 2.2 - style
|
873 |
+
if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
|
874 |
+
raise ValueError(
|
875 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
|
876 |
+
)
|
877 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
878 |
+
hint = added_cond_kwargs.get("hint")
|
879 |
+
aug_emb, hint = self.add_embedding(image_embs, hint)
|
880 |
+
sample = torch.cat([sample, hint], dim=1)
|
881 |
+
|
882 |
+
emb = emb + aug_emb if aug_emb is not None else emb
|
883 |
+
|
884 |
+
if self.time_embed_act is not None:
|
885 |
+
emb = self.time_embed_act(emb)
|
886 |
+
|
887 |
+
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
|
888 |
+
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
|
889 |
+
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
|
890 |
+
# Kadinsky 2.1 - style
|
891 |
+
if "image_embeds" not in added_cond_kwargs:
|
892 |
+
raise ValueError(
|
893 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
894 |
+
)
|
895 |
+
|
896 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
897 |
+
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
|
898 |
+
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
|
899 |
+
# Kandinsky 2.2 - style
|
900 |
+
if "image_embeds" not in added_cond_kwargs:
|
901 |
+
raise ValueError(
|
902 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
903 |
+
)
|
904 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
905 |
+
encoder_hidden_states = self.encoder_hid_proj(image_embeds)
|
906 |
+
# 2. pre-process
|
907 |
+
sample = self.conv_in(sample)
|
908 |
+
|
909 |
+
# 2.5 GLIGEN position net
|
910 |
+
if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
|
911 |
+
cross_attention_kwargs = cross_attention_kwargs.copy()
|
912 |
+
gligen_args = cross_attention_kwargs.pop("gligen")
|
913 |
+
cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
|
914 |
+
|
915 |
+
# 3. down
|
916 |
+
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
917 |
+
|
918 |
+
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
|
919 |
+
is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
|
920 |
+
|
921 |
+
down_block_res_samples = (sample,)
|
922 |
+
for downsample_block in self.down_blocks:
|
923 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
924 |
+
# For t2i-adapter CrossAttnDownBlock2D
|
925 |
+
additional_residuals = {}
|
926 |
+
if is_adapter and len(down_block_additional_residuals) > 0:
|
927 |
+
additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
|
928 |
+
|
929 |
+
sample, res_samples = downsample_block(
|
930 |
+
hidden_states=sample,
|
931 |
+
temb=emb,
|
932 |
+
encoder_hidden_states=encoder_hidden_states,
|
933 |
+
attention_mask=attention_mask,
|
934 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
935 |
+
encoder_attention_mask=encoder_attention_mask,
|
936 |
+
first_frame_latents=first_frame_latents,
|
937 |
+
**additional_residuals,
|
938 |
+
)
|
939 |
+
else:
|
940 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale, first_frame_latents=first_frame_latents,)
|
941 |
+
|
942 |
+
if is_adapter and len(down_block_additional_residuals) > 0:
|
943 |
+
sample += down_block_additional_residuals.pop(0)
|
944 |
+
|
945 |
+
down_block_res_samples += res_samples
|
946 |
+
|
947 |
+
if is_controlnet:
|
948 |
+
new_down_block_res_samples = ()
|
949 |
+
|
950 |
+
for down_block_res_sample, down_block_additional_residual in zip(
|
951 |
+
down_block_res_samples, down_block_additional_residuals
|
952 |
+
):
|
953 |
+
down_block_res_sample = down_block_res_sample + down_block_additional_residual
|
954 |
+
new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
|
955 |
+
|
956 |
+
down_block_res_samples = new_down_block_res_samples
|
957 |
+
|
958 |
+
# 4. mid
|
959 |
+
if self.mid_block is not None:
|
960 |
+
sample = self.mid_block(
|
961 |
+
sample,
|
962 |
+
emb,
|
963 |
+
encoder_hidden_states=encoder_hidden_states,
|
964 |
+
attention_mask=attention_mask,
|
965 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
966 |
+
encoder_attention_mask=encoder_attention_mask,
|
967 |
+
# additional
|
968 |
+
first_frame_latents=first_frame_latents,
|
969 |
+
)
|
970 |
+
# To support T2I-Adapter-XL
|
971 |
+
if (
|
972 |
+
is_adapter
|
973 |
+
and len(down_block_additional_residuals) > 0
|
974 |
+
and sample.shape == down_block_additional_residuals[0].shape
|
975 |
+
):
|
976 |
+
sample += down_block_additional_residuals.pop(0)
|
977 |
+
|
978 |
+
if is_controlnet:
|
979 |
+
sample = sample + mid_block_additional_residual
|
980 |
+
|
981 |
+
# 5. up
|
982 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
983 |
+
is_final_block = i == len(self.up_blocks) - 1
|
984 |
+
|
985 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
986 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
987 |
+
|
988 |
+
# if we have not reached the final block and need to forward the
|
989 |
+
# upsample size, we do it here
|
990 |
+
if not is_final_block and forward_upsample_size:
|
991 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
992 |
+
|
993 |
+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
994 |
+
sample = upsample_block(
|
995 |
+
hidden_states=sample,
|
996 |
+
temb=emb,
|
997 |
+
res_hidden_states_tuple=res_samples,
|
998 |
+
encoder_hidden_states=encoder_hidden_states,
|
999 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1000 |
+
upsample_size=upsample_size,
|
1001 |
+
attention_mask=attention_mask,
|
1002 |
+
encoder_attention_mask=encoder_attention_mask,
|
1003 |
+
first_frame_latents=first_frame_latents,
|
1004 |
+
)
|
1005 |
+
else:
|
1006 |
+
sample = upsample_block(
|
1007 |
+
hidden_states=sample,
|
1008 |
+
temb=emb,
|
1009 |
+
res_hidden_states_tuple=res_samples,
|
1010 |
+
upsample_size=upsample_size,
|
1011 |
+
scale=lora_scale,
|
1012 |
+
first_frame_latents=first_frame_latents,
|
1013 |
+
)
|
1014 |
+
|
1015 |
+
# 6. post-process
|
1016 |
+
if self.conv_norm_out:
|
1017 |
+
sample = self.conv_norm_out(sample)
|
1018 |
+
sample = self.conv_act(sample)
|
1019 |
+
sample = self.conv_out(sample)
|
1020 |
+
|
1021 |
+
sample = rearrange(sample, "(b f) c h w -> b c f h w", f=video_length)
|
1022 |
+
if self.config.first_frame_condition_mode != "none":
|
1023 |
+
sample = sample[:, :, 1:, :, :]
|
1024 |
+
|
1025 |
+
if not return_dict:
|
1026 |
+
return (sample,)
|
1027 |
+
|
1028 |
+
return UNet2DConditionOutput(sample=sample)
|
1029 |
+
|
1030 |
+
@classmethod
|
1031 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
1032 |
+
|
1033 |
+
kwargs.pop("low_cpu_mem_usage", False)
|
1034 |
+
kwargs.pop("device_map", None)
|
1035 |
+
|
1036 |
+
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
1037 |
+
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
|
1038 |
+
force_download = kwargs.pop("force_download", False)
|
1039 |
+
from_flax = kwargs.pop("from_flax", False)
|
1040 |
+
resume_download = kwargs.pop("resume_download", False)
|
1041 |
+
proxies = kwargs.pop("proxies", None)
|
1042 |
+
output_loading_info = kwargs.pop("output_loading_info", False)
|
1043 |
+
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
|
1044 |
+
use_auth_token = kwargs.pop("use_auth_token", None)
|
1045 |
+
revision = kwargs.pop("revision", None)
|
1046 |
+
torch_dtype = kwargs.pop("torch_dtype", None)
|
1047 |
+
subfolder = kwargs.pop("subfolder", None)
|
1048 |
+
device_map = None
|
1049 |
+
max_memory = kwargs.pop("max_memory", None)
|
1050 |
+
offload_folder = kwargs.pop("offload_folder", None)
|
1051 |
+
offload_state_dict = kwargs.pop("offload_state_dict", False)
|
1052 |
+
low_cpu_mem_usage = False
|
1053 |
+
variant = kwargs.pop("variant", None)
|
1054 |
+
use_safetensors = kwargs.pop("use_safetensors", None)
|
1055 |
+
|
1056 |
+
allow_pickle = False
|
1057 |
+
if use_safetensors is None:
|
1058 |
+
use_safetensors = True
|
1059 |
+
allow_pickle = True
|
1060 |
+
|
1061 |
+
if low_cpu_mem_usage and not is_accelerate_available():
|
1062 |
+
low_cpu_mem_usage = False
|
1063 |
+
logger.warning(
|
1064 |
+
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
1065 |
+
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
1066 |
+
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
1067 |
+
" install accelerate\n```\n."
|
1068 |
+
)
|
1069 |
+
|
1070 |
+
if device_map is not None and not is_accelerate_available():
|
1071 |
+
raise NotImplementedError(
|
1072 |
+
"Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
|
1073 |
+
" `device_map=None`. You can install accelerate with `pip install accelerate`."
|
1074 |
+
)
|
1075 |
+
|
1076 |
+
# Check if we can handle device_map and dispatching the weights
|
1077 |
+
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
1078 |
+
raise NotImplementedError(
|
1079 |
+
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
1080 |
+
" `device_map=None`."
|
1081 |
+
)
|
1082 |
+
|
1083 |
+
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
1084 |
+
raise NotImplementedError(
|
1085 |
+
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
1086 |
+
" `low_cpu_mem_usage=False`."
|
1087 |
+
)
|
1088 |
+
|
1089 |
+
if low_cpu_mem_usage is False and device_map is not None:
|
1090 |
+
raise ValueError(
|
1091 |
+
f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
|
1092 |
+
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
|
1093 |
+
)
|
1094 |
+
|
1095 |
+
# Load config if we don't provide a configuration
|
1096 |
+
config_path = pretrained_model_name_or_path
|
1097 |
+
|
1098 |
+
user_agent = {
|
1099 |
+
"diffusers": __version__,
|
1100 |
+
"file_type": "model",
|
1101 |
+
"framework": "pytorch",
|
1102 |
+
}
|
1103 |
+
|
1104 |
+
# load config
|
1105 |
+
config, unused_kwargs, commit_hash = cls.load_config(
|
1106 |
+
config_path,
|
1107 |
+
cache_dir=cache_dir,
|
1108 |
+
return_unused_kwargs=True,
|
1109 |
+
return_commit_hash=True,
|
1110 |
+
force_download=force_download,
|
1111 |
+
resume_download=resume_download,
|
1112 |
+
proxies=proxies,
|
1113 |
+
local_files_only=local_files_only,
|
1114 |
+
use_auth_token=use_auth_token,
|
1115 |
+
revision=revision,
|
1116 |
+
subfolder=subfolder,
|
1117 |
+
device_map=device_map,
|
1118 |
+
max_memory=max_memory,
|
1119 |
+
offload_folder=offload_folder,
|
1120 |
+
offload_state_dict=offload_state_dict,
|
1121 |
+
user_agent=user_agent,
|
1122 |
+
**kwargs,
|
1123 |
+
)
|
1124 |
+
|
1125 |
+
# load model
|
1126 |
+
model_file = None
|
1127 |
+
if from_flax:
|
1128 |
+
model_file = _get_model_file(
|
1129 |
+
pretrained_model_name_or_path,
|
1130 |
+
weights_name=FLAX_WEIGHTS_NAME,
|
1131 |
+
cache_dir=cache_dir,
|
1132 |
+
force_download=force_download,
|
1133 |
+
resume_download=resume_download,
|
1134 |
+
proxies=proxies,
|
1135 |
+
local_files_only=local_files_only,
|
1136 |
+
use_auth_token=use_auth_token,
|
1137 |
+
revision=revision,
|
1138 |
+
subfolder=subfolder,
|
1139 |
+
user_agent=user_agent,
|
1140 |
+
commit_hash=commit_hash,
|
1141 |
+
)
|
1142 |
+
model = cls.from_config(config, **unused_kwargs)
|
1143 |
+
|
1144 |
+
# Convert the weights
|
1145 |
+
from diffusers.models.modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
|
1146 |
+
|
1147 |
+
model = load_flax_checkpoint_in_pytorch_model(model, model_file)
|
1148 |
+
else:
|
1149 |
+
if use_safetensors:
|
1150 |
+
try:
|
1151 |
+
model_file = _get_model_file(
|
1152 |
+
pretrained_model_name_or_path,
|
1153 |
+
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
|
1154 |
+
cache_dir=cache_dir,
|
1155 |
+
force_download=force_download,
|
1156 |
+
resume_download=resume_download,
|
1157 |
+
proxies=proxies,
|
1158 |
+
local_files_only=local_files_only,
|
1159 |
+
use_auth_token=use_auth_token,
|
1160 |
+
revision=revision,
|
1161 |
+
subfolder=subfolder,
|
1162 |
+
user_agent=user_agent,
|
1163 |
+
commit_hash=commit_hash,
|
1164 |
+
)
|
1165 |
+
except IOError as e:
|
1166 |
+
if not allow_pickle:
|
1167 |
+
raise e
|
1168 |
+
pass
|
1169 |
+
if model_file is None:
|
1170 |
+
model_file = _get_model_file(
|
1171 |
+
pretrained_model_name_or_path,
|
1172 |
+
weights_name=_add_variant(WEIGHTS_NAME, variant),
|
1173 |
+
cache_dir=cache_dir,
|
1174 |
+
force_download=force_download,
|
1175 |
+
resume_download=resume_download,
|
1176 |
+
proxies=proxies,
|
1177 |
+
local_files_only=local_files_only,
|
1178 |
+
use_auth_token=use_auth_token,
|
1179 |
+
revision=revision,
|
1180 |
+
subfolder=subfolder,
|
1181 |
+
user_agent=user_agent,
|
1182 |
+
commit_hash=commit_hash,
|
1183 |
+
)
|
1184 |
+
|
1185 |
+
if low_cpu_mem_usage:
|
1186 |
+
# Instantiate model with empty weights
|
1187 |
+
with accelerate.init_empty_weights():
|
1188 |
+
model = cls.from_config(config, **unused_kwargs)
|
1189 |
+
|
1190 |
+
# if device_map is None, load the state dict and move the params from meta device to the cpu
|
1191 |
+
if device_map is None:
|
1192 |
+
param_device = "cpu"
|
1193 |
+
state_dict = load_state_dict(model_file, variant=variant)
|
1194 |
+
model._convert_deprecated_attention_blocks(state_dict)
|
1195 |
+
# move the params from meta device to cpu
|
1196 |
+
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
|
1197 |
+
if len(missing_keys) > 0:
|
1198 |
+
raise ValueError(
|
1199 |
+
f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
|
1200 |
+
f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
|
1201 |
+
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
|
1202 |
+
" those weights or else make sure your checkpoint file is correct."
|
1203 |
+
)
|
1204 |
+
|
1205 |
+
unexpected_keys = load_model_dict_into_meta(
|
1206 |
+
model,
|
1207 |
+
state_dict,
|
1208 |
+
device=param_device,
|
1209 |
+
dtype=torch_dtype,
|
1210 |
+
model_name_or_path=pretrained_model_name_or_path,
|
1211 |
+
)
|
1212 |
+
|
1213 |
+
if cls._keys_to_ignore_on_load_unexpected is not None:
|
1214 |
+
for pat in cls._keys_to_ignore_on_load_unexpected:
|
1215 |
+
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
1216 |
+
|
1217 |
+
if len(unexpected_keys) > 0:
|
1218 |
+
logger.warn(
|
1219 |
+
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
1220 |
+
)
|
1221 |
+
|
1222 |
+
else: # else let accelerate handle loading and dispatching.
|
1223 |
+
# Load weights and dispatch according to the device_map
|
1224 |
+
# by default the device_map is None and the weights are loaded on the CPU
|
1225 |
+
try:
|
1226 |
+
accelerate.load_checkpoint_and_dispatch(
|
1227 |
+
model,
|
1228 |
+
model_file,
|
1229 |
+
device_map,
|
1230 |
+
max_memory=max_memory,
|
1231 |
+
offload_folder=offload_folder,
|
1232 |
+
offload_state_dict=offload_state_dict,
|
1233 |
+
dtype=torch_dtype,
|
1234 |
+
)
|
1235 |
+
except AttributeError as e:
|
1236 |
+
# When using accelerate loading, we do not have the ability to load the state
|
1237 |
+
# dict and rename the weight names manually. Additionally, accelerate skips
|
1238 |
+
# torch loading conventions and directly writes into `module.{_buffers, _parameters}`
|
1239 |
+
# (which look like they should be private variables?), so we can't use the standard hooks
|
1240 |
+
# to rename parameters on load. We need to mimic the original weight names so the correct
|
1241 |
+
# attributes are available. After we have loaded the weights, we convert the deprecated
|
1242 |
+
# names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
|
1243 |
+
# the weights so we don't have to do this again.
|
1244 |
+
|
1245 |
+
if "'Attention' object has no attribute" in str(e):
|
1246 |
+
logger.warn(
|
1247 |
+
f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
|
1248 |
+
" was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
|
1249 |
+
" names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
|
1250 |
+
" so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
|
1251 |
+
" please also re-upload it or open a PR on the original repository."
|
1252 |
+
)
|
1253 |
+
model._temp_convert_self_to_deprecated_attention_blocks()
|
1254 |
+
accelerate.load_checkpoint_and_dispatch(
|
1255 |
+
model,
|
1256 |
+
model_file,
|
1257 |
+
device_map,
|
1258 |
+
max_memory=max_memory,
|
1259 |
+
offload_folder=offload_folder,
|
1260 |
+
offload_state_dict=offload_state_dict,
|
1261 |
+
dtype=torch_dtype,
|
1262 |
+
)
|
1263 |
+
model._undo_temp_convert_self_to_deprecated_attention_blocks()
|
1264 |
+
else:
|
1265 |
+
raise e
|
1266 |
+
|
1267 |
+
loading_info = {
|
1268 |
+
"missing_keys": [],
|
1269 |
+
"unexpected_keys": [],
|
1270 |
+
"mismatched_keys": [],
|
1271 |
+
"error_msgs": [],
|
1272 |
+
}
|
1273 |
+
else:
|
1274 |
+
model = cls.from_config(config, **unused_kwargs)
|
1275 |
+
|
1276 |
+
state_dict = load_state_dict(model_file, variant=variant)
|
1277 |
+
model._convert_deprecated_attention_blocks(state_dict)
|
1278 |
+
|
1279 |
+
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
1280 |
+
model,
|
1281 |
+
state_dict,
|
1282 |
+
model_file,
|
1283 |
+
pretrained_model_name_or_path,
|
1284 |
+
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
1285 |
+
)
|
1286 |
+
|
1287 |
+
loading_info = {
|
1288 |
+
"missing_keys": missing_keys,
|
1289 |
+
"unexpected_keys": unexpected_keys,
|
1290 |
+
"mismatched_keys": mismatched_keys,
|
1291 |
+
"error_msgs": error_msgs,
|
1292 |
+
}
|
1293 |
+
|
1294 |
+
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
1295 |
+
raise ValueError(
|
1296 |
+
f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
|
1297 |
+
)
|
1298 |
+
elif torch_dtype is not None:
|
1299 |
+
model = model.to(torch_dtype)
|
1300 |
+
|
1301 |
+
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
|
1302 |
+
|
1303 |
+
m, u = loading_info["missing_keys"], loading_info["unexpected_keys"]
|
1304 |
+
logger.info(f"### missing keys: {len(m)}; unexpected keys: {len(u)};")
|
1305 |
+
# print(f"### missing keys:\n{m}\n### unexpected keys:\n{u}\n")
|
1306 |
+
|
1307 |
+
spatial_params = [p.numel() if "conv3ds" not in n and "tempo_attns" not in n else 0 for n, p in model.named_parameters()]
|
1308 |
+
tconv_params = [p.numel() if "conv3ds." in n else 0 for n, p in model.named_parameters()]
|
1309 |
+
tattn_params = [p.numel() if "tempo_attns." in n else 0 for n, p in model.named_parameters()]
|
1310 |
+
tffconv_params = [p.numel() if "first_frame_conv." in n else 0 for n, p in model.named_parameters()]
|
1311 |
+
logger.info(f"### First Frame Convolution Layer Parameters: {sum(tffconv_params) / 1e6} M")
|
1312 |
+
logger.info(f"### Spatial UNet Parameters: {sum(spatial_params) / 1e6} M")
|
1313 |
+
logger.info(f"### Temporal Convolution Module Parameters: {sum(tconv_params) / 1e6} M")
|
1314 |
+
logger.info(f"### Temporal Attention Module Parameters: {sum(tattn_params) / 1e6} M")
|
1315 |
+
|
1316 |
+
# Set model in evaluation mode to deactivate DropOut modules by default
|
1317 |
+
model.eval()
|
1318 |
+
if output_loading_info:
|
1319 |
+
return model, loading_info
|
1320 |
+
|
1321 |
+
return model
|
1322 |
+
|
1323 |
+
if __name__ == "__main__":
|
1324 |
+
# test
|
1325 |
+
from diffusers import AutoencoderKL, DDIMScheduler
|
1326 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
1327 |
+
from consisti2v.pipelines.pipeline_animation import AnimationPipeline
|
1328 |
+
from consisti2v.pipelines.pipeline_conditional_animation import ConditionalAnimationPipeline
|
1329 |
+
from consisti2v.utils.util import save_videos_grid
|
1330 |
+
|
1331 |
+
pretrained_model_path = "models/StableDiffusion/stable-diffusion-v1-5"
|
1332 |
+
prompt = "apply eye makeup"
|
1333 |
+
first_frame_path = "/ML-A100/home/weiming/datasets/UCF/frames/v_ApplyEyeMakeup_g01_c01_frame_90.jpg"
|
1334 |
+
|
1335 |
+
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer", use_safetensors=True)
|
1336 |
+
text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
|
1337 |
+
vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae", use_safetensors=True)
|
1338 |
+
unet = VideoLDMUNet3DConditionModel.from_pretrained(
|
1339 |
+
pretrained_model_path,
|
1340 |
+
subfolder="unet",
|
1341 |
+
use_safetensors=True
|
1342 |
+
)
|
1343 |
+
|
1344 |
+
noise_scheduler_kwargs = {
|
1345 |
+
"num_train_timesteps": 1000,
|
1346 |
+
"beta_start": 0.00085,
|
1347 |
+
"beta_end": 0.012,
|
1348 |
+
"beta_schedule": "linear",
|
1349 |
+
"steps_offset": 1,
|
1350 |
+
"clip_sample": False,
|
1351 |
+
}
|
1352 |
+
noise_scheduler = DDIMScheduler(**noise_scheduler_kwargs)
|
1353 |
+
# latent = torch.randn(1, 4, 8, 64, 64).to("cuda")
|
1354 |
+
# text_embedding = torch.randn(1, 77, 768).to("cuda")
|
1355 |
+
# timestep = torch.randint(0, 1000, (1,)).to("cuda").squeeze(0)
|
1356 |
+
# output = unet(latent, timestep, text_embedding)
|
1357 |
+
|
1358 |
+
pipeline = ConditionalAnimationPipeline(
|
1359 |
+
unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler,
|
1360 |
+
).to("cuda")
|
1361 |
+
sample = pipeline(
|
1362 |
+
prompt,
|
1363 |
+
num_inference_steps = 25,
|
1364 |
+
guidance_scale = 8.,
|
1365 |
+
video_length = 8,
|
1366 |
+
height = 256,
|
1367 |
+
width = 256,
|
1368 |
+
first_frame_paths = first_frame_path,
|
1369 |
+
).videos
|
1370 |
+
print(sample.shape)
|
1371 |
+
save_videos_grid(sample, f"samples/videoldm.gif")
|
src/videogen_hub/pipelines/consisti2v/consisti2v/models/videoldm_unet_blocks.py
ADDED
@@ -0,0 +1,1159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Dict, Tuple, Any
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from einops import rearrange, repeat
|
7 |
+
from einops.layers.torch import Rearrange
|
8 |
+
from diffusers.utils import logging
|
9 |
+
from diffusers.models.unet_2d_blocks import (
|
10 |
+
DownBlock2D,
|
11 |
+
UpBlock2D
|
12 |
+
)
|
13 |
+
from diffusers.models.resnet import (
|
14 |
+
ResnetBlock2D,
|
15 |
+
Downsample2D,
|
16 |
+
Upsample2D,
|
17 |
+
)
|
18 |
+
from diffusers.models.transformer_2d import Transformer2DModelOutput
|
19 |
+
from diffusers.models.dual_transformer_2d import DualTransformer2DModel
|
20 |
+
from diffusers.models.activations import get_activation
|
21 |
+
from diffusers.utils import logging, is_torch_version
|
22 |
+
from diffusers.utils.import_utils import is_xformers_available
|
23 |
+
from .videoldm_transformer_blocks import Transformer2DConditionModel
|
24 |
+
|
25 |
+
logger = logging.get_logger(__name__)
|
26 |
+
|
27 |
+
if is_xformers_available():
|
28 |
+
import xformers
|
29 |
+
import xformers.ops
|
30 |
+
else:
|
31 |
+
xformers = None
|
32 |
+
|
33 |
+
|
34 |
+
def get_down_block(
|
35 |
+
down_block_type,
|
36 |
+
num_layers,
|
37 |
+
in_channels,
|
38 |
+
out_channels,
|
39 |
+
temb_channels,
|
40 |
+
add_downsample,
|
41 |
+
resnet_eps,
|
42 |
+
resnet_act_fn,
|
43 |
+
transformer_layers_per_block=1,
|
44 |
+
num_attention_heads=None,
|
45 |
+
resnet_groups=None,
|
46 |
+
cross_attention_dim=None,
|
47 |
+
downsample_padding=None,
|
48 |
+
dual_cross_attention=False,
|
49 |
+
use_linear_projection=False,
|
50 |
+
only_cross_attention=False,
|
51 |
+
upcast_attention=False,
|
52 |
+
resnet_time_scale_shift="default",
|
53 |
+
attention_type="default",
|
54 |
+
resnet_skip_time_act=False,
|
55 |
+
resnet_out_scale_factor=1.0,
|
56 |
+
cross_attention_norm=None,
|
57 |
+
attention_head_dim=None,
|
58 |
+
downsample_type=None,
|
59 |
+
dropout=0.0,
|
60 |
+
# additional
|
61 |
+
use_temporal=True,
|
62 |
+
augment_temporal_attention=False,
|
63 |
+
n_frames=8,
|
64 |
+
n_temp_heads=8,
|
65 |
+
first_frame_condition_mode="none",
|
66 |
+
latent_channels=4,
|
67 |
+
rotary_emb=False,
|
68 |
+
):
|
69 |
+
# If attn head dim is not defined, we default it to the number of heads
|
70 |
+
if attention_head_dim is None:
|
71 |
+
logger.warn(
|
72 |
+
f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
|
73 |
+
)
|
74 |
+
attention_head_dim = num_attention_heads
|
75 |
+
|
76 |
+
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
|
77 |
+
if down_block_type == "DownBlock2D":
|
78 |
+
return VideoLDMDownBlock(
|
79 |
+
num_layers=num_layers,
|
80 |
+
in_channels=in_channels,
|
81 |
+
out_channels=out_channels,
|
82 |
+
temb_channels=temb_channels,
|
83 |
+
dropout=dropout,
|
84 |
+
add_downsample=add_downsample,
|
85 |
+
resnet_eps=resnet_eps,
|
86 |
+
resnet_act_fn=resnet_act_fn,
|
87 |
+
resnet_groups=resnet_groups,
|
88 |
+
downsample_padding=downsample_padding,
|
89 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
90 |
+
# additional
|
91 |
+
use_temporal=use_temporal,
|
92 |
+
n_frames=n_frames,
|
93 |
+
first_frame_condition_mode=first_frame_condition_mode,
|
94 |
+
latent_channels=latent_channels
|
95 |
+
)
|
96 |
+
elif down_block_type == "CrossAttnDownBlock2D":
|
97 |
+
return VideoLDMCrossAttnDownBlock(
|
98 |
+
num_layers=num_layers,
|
99 |
+
transformer_layers_per_block=transformer_layers_per_block,
|
100 |
+
in_channels=in_channels,
|
101 |
+
out_channels=out_channels,
|
102 |
+
temb_channels=temb_channels,
|
103 |
+
dropout=dropout,
|
104 |
+
add_downsample=add_downsample,
|
105 |
+
resnet_eps=resnet_eps,
|
106 |
+
resnet_act_fn=resnet_act_fn,
|
107 |
+
resnet_groups=resnet_groups,
|
108 |
+
downsample_padding=downsample_padding,
|
109 |
+
cross_attention_dim=cross_attention_dim,
|
110 |
+
num_attention_heads=num_attention_heads,
|
111 |
+
dual_cross_attention=dual_cross_attention,
|
112 |
+
use_linear_projection=use_linear_projection,
|
113 |
+
only_cross_attention=only_cross_attention,
|
114 |
+
upcast_attention=upcast_attention,
|
115 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
116 |
+
attention_type=attention_type,
|
117 |
+
# additional
|
118 |
+
use_temporal=use_temporal,
|
119 |
+
augment_temporal_attention=augment_temporal_attention,
|
120 |
+
n_frames=n_frames,
|
121 |
+
n_temp_heads=n_temp_heads,
|
122 |
+
first_frame_condition_mode=first_frame_condition_mode,
|
123 |
+
latent_channels=latent_channels,
|
124 |
+
rotary_emb=rotary_emb,
|
125 |
+
)
|
126 |
+
|
127 |
+
raise ValueError(f'{down_block_type} does not exist.')
|
128 |
+
|
129 |
+
|
130 |
+
def get_up_block(
|
131 |
+
up_block_type,
|
132 |
+
num_layers,
|
133 |
+
in_channels,
|
134 |
+
out_channels,
|
135 |
+
prev_output_channel,
|
136 |
+
temb_channels,
|
137 |
+
add_upsample,
|
138 |
+
resnet_eps,
|
139 |
+
resnet_act_fn,
|
140 |
+
transformer_layers_per_block=1,
|
141 |
+
num_attention_heads=None,
|
142 |
+
resnet_groups=None,
|
143 |
+
cross_attention_dim=None,
|
144 |
+
dual_cross_attention=False,
|
145 |
+
use_linear_projection=False,
|
146 |
+
only_cross_attention=False,
|
147 |
+
upcast_attention=False,
|
148 |
+
resnet_time_scale_shift="default",
|
149 |
+
attention_type="default",
|
150 |
+
resnet_skip_time_act=False,
|
151 |
+
resnet_out_scale_factor=1.0,
|
152 |
+
cross_attention_norm=None,
|
153 |
+
attention_head_dim=None,
|
154 |
+
upsample_type=None,
|
155 |
+
dropout=0.0,
|
156 |
+
# additional
|
157 |
+
use_temporal=True,
|
158 |
+
augment_temporal_attention=False,
|
159 |
+
n_frames=8,
|
160 |
+
n_temp_heads=8,
|
161 |
+
first_frame_condition_mode="none",
|
162 |
+
latent_channels=4,
|
163 |
+
rotary_emb=None,
|
164 |
+
):
|
165 |
+
if attention_head_dim is None:
|
166 |
+
logger.warn(
|
167 |
+
f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
|
168 |
+
)
|
169 |
+
attention_head_dim = num_attention_heads
|
170 |
+
|
171 |
+
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
|
172 |
+
if up_block_type == "UpBlock2D":
|
173 |
+
return VideoLDMUpBlock(
|
174 |
+
num_layers=num_layers,
|
175 |
+
in_channels=in_channels,
|
176 |
+
out_channels=out_channels,
|
177 |
+
prev_output_channel=prev_output_channel,
|
178 |
+
temb_channels=temb_channels,
|
179 |
+
dropout=dropout,
|
180 |
+
add_upsample=add_upsample,
|
181 |
+
resnet_eps=resnet_eps,
|
182 |
+
resnet_act_fn=resnet_act_fn,
|
183 |
+
resnet_groups=resnet_groups,
|
184 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
185 |
+
# additional
|
186 |
+
use_temporal=use_temporal,
|
187 |
+
n_frames=n_frames,
|
188 |
+
first_frame_condition_mode=first_frame_condition_mode,
|
189 |
+
latent_channels=latent_channels
|
190 |
+
)
|
191 |
+
elif up_block_type == 'CrossAttnUpBlock2D':
|
192 |
+
return VideoLDMCrossAttnUpBlock(
|
193 |
+
num_layers=num_layers,
|
194 |
+
transformer_layers_per_block=transformer_layers_per_block,
|
195 |
+
in_channels=in_channels,
|
196 |
+
out_channels=out_channels,
|
197 |
+
prev_output_channel=prev_output_channel,
|
198 |
+
temb_channels=temb_channels,
|
199 |
+
dropout=dropout,
|
200 |
+
add_upsample=add_upsample,
|
201 |
+
resnet_eps=resnet_eps,
|
202 |
+
resnet_act_fn=resnet_act_fn,
|
203 |
+
resnet_groups=resnet_groups,
|
204 |
+
cross_attention_dim=cross_attention_dim,
|
205 |
+
num_attention_heads=num_attention_heads,
|
206 |
+
dual_cross_attention=dual_cross_attention,
|
207 |
+
use_linear_projection=use_linear_projection,
|
208 |
+
only_cross_attention=only_cross_attention,
|
209 |
+
upcast_attention=upcast_attention,
|
210 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
211 |
+
attention_type=attention_type,
|
212 |
+
# additional
|
213 |
+
use_temporal=use_temporal,
|
214 |
+
augment_temporal_attention=augment_temporal_attention,
|
215 |
+
n_frames=n_frames,
|
216 |
+
n_temp_heads=n_temp_heads,
|
217 |
+
first_frame_condition_mode=first_frame_condition_mode,
|
218 |
+
latent_channels=latent_channels,
|
219 |
+
rotary_emb=rotary_emb,
|
220 |
+
)
|
221 |
+
|
222 |
+
raise ValueError(f'{up_block_type} does not exist.')
|
223 |
+
|
224 |
+
|
225 |
+
class TemporalResnetBlock(nn.Module):
|
226 |
+
def __init__(
|
227 |
+
self,
|
228 |
+
*,
|
229 |
+
in_channels,
|
230 |
+
out_channels=None,
|
231 |
+
dropout=0.0,
|
232 |
+
temb_channels=512,
|
233 |
+
groups=32,
|
234 |
+
groups_out=None,
|
235 |
+
pre_norm=True,
|
236 |
+
eps=1e-6,
|
237 |
+
non_linearity="swish",
|
238 |
+
time_embedding_norm="default",
|
239 |
+
output_scale_factor=1.0,
|
240 |
+
# additional
|
241 |
+
n_frames=8,
|
242 |
+
):
|
243 |
+
super().__init__()
|
244 |
+
self.pre_norm = pre_norm
|
245 |
+
self.pre_norm = True
|
246 |
+
self.in_channels = in_channels
|
247 |
+
out_channels = in_channels if out_channels is None else out_channels
|
248 |
+
self.out_channels = out_channels
|
249 |
+
self.time_embedding_norm = time_embedding_norm
|
250 |
+
self.output_scale_factor = output_scale_factor
|
251 |
+
|
252 |
+
if groups_out is None:
|
253 |
+
groups_out = groups
|
254 |
+
|
255 |
+
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
256 |
+
|
257 |
+
self.conv1 = Conv3DLayer(in_channels, out_channels, n_frames=n_frames)
|
258 |
+
|
259 |
+
if temb_channels is not None:
|
260 |
+
if self.time_embedding_norm == "default":
|
261 |
+
time_emb_proj_out_channels = out_channels
|
262 |
+
elif self.time_embedding_norm == "scale_shift":
|
263 |
+
time_emb_proj_out_channels = out_channels * 2
|
264 |
+
else:
|
265 |
+
raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
|
266 |
+
|
267 |
+
self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
|
268 |
+
else:
|
269 |
+
self.time_emb_proj = None
|
270 |
+
|
271 |
+
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
|
272 |
+
|
273 |
+
self.dropout = torch.nn.Dropout(dropout)
|
274 |
+
self.conv2 = Conv3DLayer(out_channels, out_channels, n_frames=n_frames)
|
275 |
+
|
276 |
+
self.nonlinearity = get_activation(non_linearity)
|
277 |
+
|
278 |
+
self.alpha = nn.Parameter(torch.ones(1))
|
279 |
+
|
280 |
+
def forward(self, input_tensor, temb=None):
|
281 |
+
hidden_states = input_tensor
|
282 |
+
|
283 |
+
hidden_states = self.norm1(hidden_states)
|
284 |
+
hidden_states = self.nonlinearity(hidden_states)
|
285 |
+
|
286 |
+
hidden_states = self.conv1(hidden_states)
|
287 |
+
|
288 |
+
if temb is not None:
|
289 |
+
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
|
290 |
+
|
291 |
+
if temb is not None and self.time_embedding_norm == "default":
|
292 |
+
hidden_states = hidden_states + temb
|
293 |
+
|
294 |
+
hidden_states = self.norm2(hidden_states)
|
295 |
+
|
296 |
+
if temb is not None and self.time_embedding_norm == "scale_shift":
|
297 |
+
scale, shift = torch.chunk(temb, 2, dim=1)
|
298 |
+
hidden_states = hidden_states * (1 + scale) + shift
|
299 |
+
|
300 |
+
hidden_states = self.nonlinearity(hidden_states)
|
301 |
+
|
302 |
+
hidden_states = self.dropout(hidden_states)
|
303 |
+
hidden_states = self.conv2(hidden_states)
|
304 |
+
|
305 |
+
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
306 |
+
|
307 |
+
# weighted sum between spatial and temporal features
|
308 |
+
with torch.no_grad():
|
309 |
+
self.alpha.clamp_(0, 1)
|
310 |
+
|
311 |
+
output_tensor = self.alpha * input_tensor + (1 - self.alpha) * output_tensor
|
312 |
+
|
313 |
+
return output_tensor
|
314 |
+
|
315 |
+
|
316 |
+
class Conv3DLayer(nn.Conv3d):
|
317 |
+
def __init__(self, in_dim, out_dim, n_frames):
|
318 |
+
k, p = (3, 1, 1), (1, 0, 0)
|
319 |
+
super().__init__(in_channels=in_dim, out_channels=out_dim, kernel_size=k, stride=1, padding=p)
|
320 |
+
|
321 |
+
self.to_3d = Rearrange('(b t) c h w -> b c t h w', t=n_frames)
|
322 |
+
self.to_2d = Rearrange('b c t h w -> (b t) c h w')
|
323 |
+
|
324 |
+
def forward(self, x):
|
325 |
+
h = self.to_3d(x)
|
326 |
+
h = super().forward(h)
|
327 |
+
out = self.to_2d(h)
|
328 |
+
return out
|
329 |
+
|
330 |
+
|
331 |
+
class IdentityLayer(nn.Identity):
|
332 |
+
def __init__(self, return_trans2d_output, *args, **kwargs):
|
333 |
+
super().__init__()
|
334 |
+
self.return_trans2d_output = return_trans2d_output
|
335 |
+
|
336 |
+
def forward(self, x, *args, **kwargs):
|
337 |
+
if self.return_trans2d_output:
|
338 |
+
return Transformer2DModelOutput(sample=x)
|
339 |
+
else:
|
340 |
+
return x
|
341 |
+
|
342 |
+
|
343 |
+
class VideoLDMCrossAttnDownBlock(nn.Module):
|
344 |
+
def __init__(
|
345 |
+
self,
|
346 |
+
in_channels: int,
|
347 |
+
out_channels: int,
|
348 |
+
temb_channels: int,
|
349 |
+
dropout: float = 0.0,
|
350 |
+
num_layers: int = 1,
|
351 |
+
transformer_layers_per_block: int = 1,
|
352 |
+
resnet_eps: float = 1e-6,
|
353 |
+
resnet_time_scale_shift: str = "default",
|
354 |
+
resnet_act_fn: str = "swish",
|
355 |
+
resnet_groups: int = 32,
|
356 |
+
resnet_pre_norm: bool = True,
|
357 |
+
num_attention_heads=1,
|
358 |
+
cross_attention_dim=1280,
|
359 |
+
output_scale_factor=1.0,
|
360 |
+
downsample_padding=1,
|
361 |
+
add_downsample=True,
|
362 |
+
dual_cross_attention=False,
|
363 |
+
use_linear_projection=False,
|
364 |
+
only_cross_attention=False,
|
365 |
+
upcast_attention=False,
|
366 |
+
attention_type="default",
|
367 |
+
# additional
|
368 |
+
use_temporal=True,
|
369 |
+
augment_temporal_attention=False,
|
370 |
+
n_frames=8,
|
371 |
+
n_temp_heads=8,
|
372 |
+
first_frame_condition_mode="none",
|
373 |
+
latent_channels=4,
|
374 |
+
rotary_emb=False,
|
375 |
+
):
|
376 |
+
super().__init__()
|
377 |
+
|
378 |
+
self.use_temporal = use_temporal
|
379 |
+
|
380 |
+
self.n_frames = n_frames
|
381 |
+
self.first_frame_condition_mode = first_frame_condition_mode
|
382 |
+
if self.first_frame_condition_mode == "conv2d":
|
383 |
+
self.first_frame_conv = nn.Conv2d(latent_channels, in_channels, kernel_size=1)
|
384 |
+
|
385 |
+
resnets = []
|
386 |
+
attentions = []
|
387 |
+
|
388 |
+
self.n_frames = n_frames
|
389 |
+
self.n_temp_heads = n_temp_heads
|
390 |
+
|
391 |
+
self.has_cross_attention = True
|
392 |
+
self.num_attention_heads = num_attention_heads
|
393 |
+
|
394 |
+
for i in range(num_layers):
|
395 |
+
in_channels = in_channels if i == 0 else out_channels
|
396 |
+
resnets.append(
|
397 |
+
ResnetBlock2D(
|
398 |
+
in_channels=in_channels,
|
399 |
+
out_channels=out_channels,
|
400 |
+
temb_channels=temb_channels,
|
401 |
+
eps=resnet_eps,
|
402 |
+
groups=resnet_groups,
|
403 |
+
dropout=dropout,
|
404 |
+
time_embedding_norm=resnet_time_scale_shift,
|
405 |
+
non_linearity=resnet_act_fn,
|
406 |
+
output_scale_factor=output_scale_factor,
|
407 |
+
pre_norm=resnet_pre_norm,
|
408 |
+
)
|
409 |
+
)
|
410 |
+
if not dual_cross_attention:
|
411 |
+
attentions.append(
|
412 |
+
Transformer2DConditionModel(
|
413 |
+
num_attention_heads,
|
414 |
+
out_channels // num_attention_heads,
|
415 |
+
in_channels=out_channels,
|
416 |
+
num_layers=transformer_layers_per_block,
|
417 |
+
cross_attention_dim=cross_attention_dim,
|
418 |
+
norm_num_groups=resnet_groups,
|
419 |
+
use_linear_projection=use_linear_projection,
|
420 |
+
only_cross_attention=only_cross_attention,
|
421 |
+
upcast_attention=upcast_attention,
|
422 |
+
attention_type=attention_type,
|
423 |
+
# additional
|
424 |
+
n_frames=n_frames,
|
425 |
+
)
|
426 |
+
)
|
427 |
+
else:
|
428 |
+
attentions.append(
|
429 |
+
DualTransformer2DModel(
|
430 |
+
num_attention_heads,
|
431 |
+
out_channels // num_attention_heads,
|
432 |
+
in_channels=out_channels,
|
433 |
+
num_layers=1,
|
434 |
+
cross_attention_dim=cross_attention_dim,
|
435 |
+
norm_num_groups=resnet_groups,
|
436 |
+
)
|
437 |
+
)
|
438 |
+
self.attentions = nn.ModuleList(attentions)
|
439 |
+
self.resnets = nn.ModuleList(resnets)
|
440 |
+
|
441 |
+
if add_downsample:
|
442 |
+
self.downsamplers = nn.ModuleList(
|
443 |
+
[
|
444 |
+
Downsample2D(
|
445 |
+
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
446 |
+
)
|
447 |
+
]
|
448 |
+
)
|
449 |
+
else:
|
450 |
+
self.downsamplers = None
|
451 |
+
|
452 |
+
self.gradient_checkpointing = False
|
453 |
+
|
454 |
+
# >>> Temporal Layers >>>
|
455 |
+
conv3ds = []
|
456 |
+
tempo_attns = []
|
457 |
+
|
458 |
+
for i in range(num_layers):
|
459 |
+
if self.use_temporal:
|
460 |
+
conv3ds.append(
|
461 |
+
TemporalResnetBlock(
|
462 |
+
in_channels=out_channels,
|
463 |
+
out_channels=out_channels,
|
464 |
+
n_frames=n_frames,
|
465 |
+
)
|
466 |
+
)
|
467 |
+
|
468 |
+
tempo_attns.append(
|
469 |
+
Transformer2DConditionModel(
|
470 |
+
n_temp_heads,
|
471 |
+
out_channels // n_temp_heads,
|
472 |
+
in_channels=out_channels,
|
473 |
+
num_layers=transformer_layers_per_block,
|
474 |
+
cross_attention_dim=cross_attention_dim,
|
475 |
+
norm_num_groups=resnet_groups,
|
476 |
+
use_linear_projection=use_linear_projection,
|
477 |
+
only_cross_attention=only_cross_attention,
|
478 |
+
upcast_attention=upcast_attention,
|
479 |
+
attention_type=attention_type,
|
480 |
+
# additional
|
481 |
+
n_frames=n_frames,
|
482 |
+
is_temporal=True,
|
483 |
+
augment_temporal_attention=augment_temporal_attention,
|
484 |
+
rotary_emb=rotary_emb
|
485 |
+
)
|
486 |
+
)
|
487 |
+
else:
|
488 |
+
conv3ds.append(IdentityLayer(return_trans2d_output=False))
|
489 |
+
tempo_attns.append(IdentityLayer(return_trans2d_output=True))
|
490 |
+
|
491 |
+
self.conv3ds = nn.ModuleList(conv3ds)
|
492 |
+
self.tempo_attns = nn.ModuleList(tempo_attns)
|
493 |
+
# <<< Temporal Layers <<<
|
494 |
+
|
495 |
+
def forward(
|
496 |
+
self,
|
497 |
+
hidden_states: torch.FloatTensor,
|
498 |
+
temb: Optional[torch.FloatTensor] = None,
|
499 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
500 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
501 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
502 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
503 |
+
# additional
|
504 |
+
first_frame_latents=None,
|
505 |
+
):
|
506 |
+
condition_on_first_frame = (self.first_frame_condition_mode != "none" and self.first_frame_condition_mode != "input_only")
|
507 |
+
# input shape: hidden_states = (b f) c h w, first_frame_latents = b c 1 h w
|
508 |
+
if self.first_frame_condition_mode == "conv2d":
|
509 |
+
hidden_states = rearrange(hidden_states, '(b t) c h w -> b c t h w', t=self.n_frames)
|
510 |
+
hidden_height = hidden_states.shape[3]
|
511 |
+
first_frame_height = first_frame_latents.shape[3]
|
512 |
+
downsample_ratio = hidden_height / first_frame_height
|
513 |
+
first_frame_latents = F.interpolate(first_frame_latents.squeeze(2), scale_factor=downsample_ratio, mode="nearest")
|
514 |
+
first_frame_latents = self.first_frame_conv(first_frame_latents).unsqueeze(2)
|
515 |
+
hidden_states[:, :, 0:1, :, :] = first_frame_latents
|
516 |
+
hidden_states = rearrange(hidden_states, 'b c t h w -> (b t) c h w', t=self.n_frames)
|
517 |
+
|
518 |
+
output_states = ()
|
519 |
+
|
520 |
+
for resnet, conv3d, attn, tempo_attn in zip(self.resnets, self.conv3ds, self.attentions, self.tempo_attns):
|
521 |
+
|
522 |
+
hidden_states = resnet(hidden_states, temb)
|
523 |
+
hidden_states = conv3d(hidden_states)
|
524 |
+
hidden_states = attn(
|
525 |
+
hidden_states,
|
526 |
+
encoder_hidden_states=encoder_hidden_states,
|
527 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
528 |
+
condition_on_first_frame=condition_on_first_frame,
|
529 |
+
).sample
|
530 |
+
hidden_states = tempo_attn(
|
531 |
+
hidden_states,
|
532 |
+
encoder_hidden_states=encoder_hidden_states,
|
533 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
534 |
+
condition_on_first_frame=False,
|
535 |
+
).sample
|
536 |
+
|
537 |
+
output_states += (hidden_states,)
|
538 |
+
|
539 |
+
if self.downsamplers is not None:
|
540 |
+
for downsampler in self.downsamplers:
|
541 |
+
hidden_states = downsampler(hidden_states)
|
542 |
+
|
543 |
+
output_states += (hidden_states,)
|
544 |
+
|
545 |
+
return hidden_states, output_states
|
546 |
+
|
547 |
+
|
548 |
+
class VideoLDMCrossAttnUpBlock(nn.Module):
|
549 |
+
def __init__(
|
550 |
+
self,
|
551 |
+
in_channels: int,
|
552 |
+
out_channels: int,
|
553 |
+
prev_output_channel: int,
|
554 |
+
temb_channels: int,
|
555 |
+
dropout: float = 0.0,
|
556 |
+
num_layers: int = 1,
|
557 |
+
transformer_layers_per_block: int = 1,
|
558 |
+
resnet_eps: float = 1e-6,
|
559 |
+
resnet_time_scale_shift: str = "default",
|
560 |
+
resnet_act_fn: str = "swish",
|
561 |
+
resnet_groups: int = 32,
|
562 |
+
resnet_pre_norm: bool = True,
|
563 |
+
num_attention_heads=1,
|
564 |
+
cross_attention_dim=1280,
|
565 |
+
output_scale_factor=1.0,
|
566 |
+
add_upsample=True,
|
567 |
+
dual_cross_attention=False,
|
568 |
+
use_linear_projection=False,
|
569 |
+
only_cross_attention=False,
|
570 |
+
upcast_attention=False,
|
571 |
+
attention_type="default",
|
572 |
+
# additional
|
573 |
+
use_temporal=True,
|
574 |
+
augment_temporal_attention=False,
|
575 |
+
n_frames=8,
|
576 |
+
n_temp_heads=8,
|
577 |
+
first_frame_condition_mode="none",
|
578 |
+
latent_channels=4,
|
579 |
+
rotary_emb=False,
|
580 |
+
):
|
581 |
+
super().__init__()
|
582 |
+
|
583 |
+
self.use_temporal = use_temporal
|
584 |
+
|
585 |
+
self.n_frames = n_frames
|
586 |
+
self.first_frame_condition_mode = first_frame_condition_mode
|
587 |
+
if self.first_frame_condition_mode == "conv2d":
|
588 |
+
self.first_frame_conv = nn.Conv2d(latent_channels, prev_output_channel, kernel_size=1)
|
589 |
+
|
590 |
+
resnets = []
|
591 |
+
attentions = []
|
592 |
+
|
593 |
+
self.n_frames = n_frames
|
594 |
+
self.n_temp_heads = n_temp_heads
|
595 |
+
|
596 |
+
self.has_cross_attention = True
|
597 |
+
self.num_attention_heads = num_attention_heads
|
598 |
+
|
599 |
+
for i in range(num_layers):
|
600 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
601 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
602 |
+
|
603 |
+
resnets.append(
|
604 |
+
ResnetBlock2D(
|
605 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
606 |
+
out_channels=out_channels,
|
607 |
+
temb_channels=temb_channels,
|
608 |
+
eps=resnet_eps,
|
609 |
+
groups=resnet_groups,
|
610 |
+
dropout=dropout,
|
611 |
+
time_embedding_norm=resnet_time_scale_shift,
|
612 |
+
non_linearity=resnet_act_fn,
|
613 |
+
output_scale_factor=output_scale_factor,
|
614 |
+
pre_norm=resnet_pre_norm,
|
615 |
+
)
|
616 |
+
)
|
617 |
+
if not dual_cross_attention:
|
618 |
+
attentions.append(
|
619 |
+
Transformer2DConditionModel(
|
620 |
+
num_attention_heads,
|
621 |
+
out_channels // num_attention_heads,
|
622 |
+
in_channels=out_channels,
|
623 |
+
num_layers=transformer_layers_per_block,
|
624 |
+
cross_attention_dim=cross_attention_dim,
|
625 |
+
norm_num_groups=resnet_groups,
|
626 |
+
use_linear_projection=use_linear_projection,
|
627 |
+
only_cross_attention=only_cross_attention,
|
628 |
+
upcast_attention=upcast_attention,
|
629 |
+
attention_type=attention_type,
|
630 |
+
# additional
|
631 |
+
n_frames=n_frames,
|
632 |
+
)
|
633 |
+
)
|
634 |
+
else:
|
635 |
+
attentions.append(
|
636 |
+
DualTransformer2DModel(
|
637 |
+
num_attention_heads,
|
638 |
+
out_channels // num_attention_heads,
|
639 |
+
in_channels=out_channels,
|
640 |
+
num_layers=1,
|
641 |
+
cross_attention_dim=cross_attention_dim,
|
642 |
+
norm_num_groups=resnet_groups,
|
643 |
+
)
|
644 |
+
)
|
645 |
+
self.attentions = nn.ModuleList(attentions)
|
646 |
+
self.resnets = nn.ModuleList(resnets)
|
647 |
+
|
648 |
+
if add_upsample:
|
649 |
+
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
650 |
+
else:
|
651 |
+
self.upsamplers = None
|
652 |
+
|
653 |
+
self.gradient_checkpointing = False
|
654 |
+
|
655 |
+
# >>> Temporal Layers >>>
|
656 |
+
conv3ds = []
|
657 |
+
tempo_attns = []
|
658 |
+
|
659 |
+
for i in range(num_layers):
|
660 |
+
if self.use_temporal:
|
661 |
+
conv3ds.append(
|
662 |
+
TemporalResnetBlock(
|
663 |
+
in_channels=out_channels,
|
664 |
+
out_channels=out_channels,
|
665 |
+
n_frames=n_frames,
|
666 |
+
)
|
667 |
+
)
|
668 |
+
|
669 |
+
tempo_attns.append(
|
670 |
+
Transformer2DConditionModel(
|
671 |
+
n_temp_heads,
|
672 |
+
out_channels // n_temp_heads,
|
673 |
+
in_channels=out_channels,
|
674 |
+
num_layers=transformer_layers_per_block,
|
675 |
+
cross_attention_dim=cross_attention_dim,
|
676 |
+
norm_num_groups=resnet_groups,
|
677 |
+
use_linear_projection=use_linear_projection,
|
678 |
+
only_cross_attention=only_cross_attention,
|
679 |
+
upcast_attention=upcast_attention,
|
680 |
+
attention_type=attention_type,
|
681 |
+
# additional
|
682 |
+
n_frames=n_frames,
|
683 |
+
augment_temporal_attention=augment_temporal_attention,
|
684 |
+
is_temporal=True,
|
685 |
+
rotary_emb=rotary_emb,
|
686 |
+
)
|
687 |
+
)
|
688 |
+
else:
|
689 |
+
conv3ds.append(IdentityLayer(return_trans2d_output=False))
|
690 |
+
tempo_attns.append(IdentityLayer(return_trans2d_output=True))
|
691 |
+
|
692 |
+
self.conv3ds = nn.ModuleList(conv3ds)
|
693 |
+
self.tempo_attns = nn.ModuleList(tempo_attns)
|
694 |
+
# <<< Temporal Layers <<<
|
695 |
+
|
696 |
+
def forward(
|
697 |
+
self,
|
698 |
+
hidden_states: torch.FloatTensor,
|
699 |
+
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
700 |
+
temb: Optional[torch.FloatTensor] = None,
|
701 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
702 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
703 |
+
upsample_size: Optional[int] = None,
|
704 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
705 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
706 |
+
# additional
|
707 |
+
first_frame_latents=None,
|
708 |
+
):
|
709 |
+
condition_on_first_frame = (self.first_frame_condition_mode != "none" and self.first_frame_condition_mode != "input_only")
|
710 |
+
# input shape: hidden_states = (b f) c h w, first_frame_latents = b c 1 h w
|
711 |
+
if self.first_frame_condition_mode == "conv2d":
|
712 |
+
hidden_states = rearrange(hidden_states, '(b t) c h w -> b c t h w', t=self.n_frames)
|
713 |
+
hidden_height = hidden_states.shape[3]
|
714 |
+
first_frame_height = first_frame_latents.shape[3]
|
715 |
+
downsample_ratio = hidden_height / first_frame_height
|
716 |
+
first_frame_latents = F.interpolate(first_frame_latents.squeeze(2), scale_factor=downsample_ratio, mode="nearest")
|
717 |
+
first_frame_latents = self.first_frame_conv(first_frame_latents).unsqueeze(2)
|
718 |
+
hidden_states[:, :, 0:1, :, :] = first_frame_latents
|
719 |
+
hidden_states = rearrange(hidden_states, 'b c t h w -> (b t) c h w', t=self.n_frames)
|
720 |
+
|
721 |
+
for resnet, conv3d, attn, tempo_attn in zip(self.resnets, self.conv3ds, self.attentions, self.tempo_attns):
|
722 |
+
|
723 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
724 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
725 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
726 |
+
|
727 |
+
hidden_states = resnet(hidden_states, temb)
|
728 |
+
hidden_states = conv3d(hidden_states)
|
729 |
+
hidden_states = attn(
|
730 |
+
hidden_states,
|
731 |
+
encoder_hidden_states=encoder_hidden_states,
|
732 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
733 |
+
condition_on_first_frame=condition_on_first_frame,
|
734 |
+
).sample
|
735 |
+
hidden_states = tempo_attn(
|
736 |
+
hidden_states,
|
737 |
+
encoder_hidden_states=encoder_hidden_states,
|
738 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
739 |
+
condition_on_first_frame=False,
|
740 |
+
).sample
|
741 |
+
|
742 |
+
if self.upsamplers is not None:
|
743 |
+
for upsampler in self.upsamplers:
|
744 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
745 |
+
return hidden_states
|
746 |
+
|
747 |
+
|
748 |
+
class VideoLDMUNetMidBlock2DCrossAttn(nn.Module):
|
749 |
+
def __init__(
|
750 |
+
self,
|
751 |
+
in_channels: int,
|
752 |
+
temb_channels: int,
|
753 |
+
dropout: float = 0.0,
|
754 |
+
num_layers: int = 1,
|
755 |
+
transformer_layers_per_block: int = 1,
|
756 |
+
resnet_eps: float = 1e-6,
|
757 |
+
resnet_time_scale_shift: str = "default",
|
758 |
+
resnet_act_fn: str = "swish",
|
759 |
+
resnet_groups: int = 32,
|
760 |
+
resnet_pre_norm: bool = True,
|
761 |
+
num_attention_heads=1,
|
762 |
+
output_scale_factor=1.0,
|
763 |
+
cross_attention_dim=1280,
|
764 |
+
dual_cross_attention=False,
|
765 |
+
use_linear_projection=False,
|
766 |
+
upcast_attention=False,
|
767 |
+
attention_type="default",
|
768 |
+
# additional
|
769 |
+
use_temporal=True,
|
770 |
+
n_frames: int = 8,
|
771 |
+
first_frame_condition_mode="none",
|
772 |
+
latent_channels=4,
|
773 |
+
):
|
774 |
+
super().__init__()
|
775 |
+
|
776 |
+
self.use_temporal = use_temporal
|
777 |
+
|
778 |
+
self.n_frames = n_frames
|
779 |
+
self.first_frame_condition_mode = first_frame_condition_mode
|
780 |
+
if self.first_frame_condition_mode == "conv2d":
|
781 |
+
self.first_frame_conv = nn.Conv2d(latent_channels, in_channels, kernel_size=1)
|
782 |
+
|
783 |
+
self.has_cross_attention = True
|
784 |
+
self.num_attention_heads = num_attention_heads
|
785 |
+
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
786 |
+
|
787 |
+
# there is always at least one resnet
|
788 |
+
resnets = [
|
789 |
+
ResnetBlock2D(
|
790 |
+
in_channels=in_channels,
|
791 |
+
out_channels=in_channels,
|
792 |
+
temb_channels=temb_channels,
|
793 |
+
eps=resnet_eps,
|
794 |
+
groups=resnet_groups,
|
795 |
+
dropout=dropout,
|
796 |
+
time_embedding_norm=resnet_time_scale_shift,
|
797 |
+
non_linearity=resnet_act_fn,
|
798 |
+
output_scale_factor=output_scale_factor,
|
799 |
+
pre_norm=resnet_pre_norm,
|
800 |
+
)
|
801 |
+
]
|
802 |
+
if self.use_temporal:
|
803 |
+
conv3ds = [
|
804 |
+
TemporalResnetBlock(
|
805 |
+
in_channels=in_channels,
|
806 |
+
out_channels=in_channels,
|
807 |
+
n_frames=n_frames,
|
808 |
+
)
|
809 |
+
]
|
810 |
+
else:
|
811 |
+
conv3ds = [IdentityLayer(return_trans2d_output=False)]
|
812 |
+
|
813 |
+
attentions = []
|
814 |
+
|
815 |
+
for _ in range(num_layers):
|
816 |
+
if not dual_cross_attention:
|
817 |
+
attentions.append(
|
818 |
+
Transformer2DConditionModel(
|
819 |
+
num_attention_heads,
|
820 |
+
in_channels // num_attention_heads,
|
821 |
+
in_channels=in_channels,
|
822 |
+
num_layers=transformer_layers_per_block,
|
823 |
+
cross_attention_dim=cross_attention_dim,
|
824 |
+
norm_num_groups=resnet_groups,
|
825 |
+
use_linear_projection=use_linear_projection,
|
826 |
+
upcast_attention=upcast_attention,
|
827 |
+
attention_type=attention_type,
|
828 |
+
# additional
|
829 |
+
n_frames=n_frames,
|
830 |
+
)
|
831 |
+
)
|
832 |
+
else:
|
833 |
+
attentions.append(
|
834 |
+
DualTransformer2DModel(
|
835 |
+
num_attention_heads,
|
836 |
+
in_channels // num_attention_heads,
|
837 |
+
in_channels=in_channels,
|
838 |
+
num_layers=1,
|
839 |
+
cross_attention_dim=cross_attention_dim,
|
840 |
+
norm_num_groups=resnet_groups,
|
841 |
+
)
|
842 |
+
)
|
843 |
+
resnets.append(
|
844 |
+
ResnetBlock2D(
|
845 |
+
in_channels=in_channels,
|
846 |
+
out_channels=in_channels,
|
847 |
+
temb_channels=temb_channels,
|
848 |
+
eps=resnet_eps,
|
849 |
+
groups=resnet_groups,
|
850 |
+
dropout=dropout,
|
851 |
+
time_embedding_norm=resnet_time_scale_shift,
|
852 |
+
non_linearity=resnet_act_fn,
|
853 |
+
output_scale_factor=output_scale_factor,
|
854 |
+
pre_norm=resnet_pre_norm,
|
855 |
+
)
|
856 |
+
)
|
857 |
+
if self.use_temporal:
|
858 |
+
conv3ds.append(
|
859 |
+
TemporalResnetBlock(
|
860 |
+
in_channels=in_channels,
|
861 |
+
out_channels=in_channels,
|
862 |
+
n_frames=n_frames,
|
863 |
+
)
|
864 |
+
)
|
865 |
+
else:
|
866 |
+
conv3ds.append(IdentityLayer(return_trans2d_output=False))
|
867 |
+
|
868 |
+
self.attentions = nn.ModuleList(attentions)
|
869 |
+
self.resnets = nn.ModuleList(resnets)
|
870 |
+
self.conv3ds = nn.ModuleList(conv3ds)
|
871 |
+
|
872 |
+
self.gradient_checkpointing = False
|
873 |
+
|
874 |
+
def forward(
|
875 |
+
self,
|
876 |
+
hidden_states: torch.FloatTensor,
|
877 |
+
temb: Optional[torch.FloatTensor] = None,
|
878 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
879 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
880 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
881 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
882 |
+
# additional
|
883 |
+
first_frame_latents=None,
|
884 |
+
) -> torch.FloatTensor:
|
885 |
+
condition_on_first_frame = (self.first_frame_condition_mode != "none" and self.first_frame_condition_mode != "input_only")
|
886 |
+
# input shape: hidden_states = (b f) c h w, first_frame_latents = b c 1 h w
|
887 |
+
if self.first_frame_condition_mode == "conv2d":
|
888 |
+
hidden_states = rearrange(hidden_states, '(b t) c h w -> b c t h w', t=self.n_frames)
|
889 |
+
hidden_height = hidden_states.shape[3]
|
890 |
+
first_frame_height = first_frame_latents.shape[3]
|
891 |
+
downsample_ratio = hidden_height / first_frame_height
|
892 |
+
first_frame_latents = F.interpolate(first_frame_latents.squeeze(2), scale_factor=downsample_ratio, mode="nearest")
|
893 |
+
first_frame_latents = self.first_frame_conv(first_frame_latents).unsqueeze(2)
|
894 |
+
hidden_states[:, :, 0:1, :, :] = first_frame_latents
|
895 |
+
hidden_states = rearrange(hidden_states, 'b c t h w -> (b t) c h w', t=self.n_frames)
|
896 |
+
|
897 |
+
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
898 |
+
hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
|
899 |
+
hidden_states = self.conv3ds[0](hidden_states)
|
900 |
+
for attn, resnet, conv3d in zip(self.attentions, self.resnets[1:], self.conv3ds[1:]):
|
901 |
+
if self.training and self.gradient_checkpointing:
|
902 |
+
|
903 |
+
def create_custom_forward(module, return_dict=None):
|
904 |
+
def custom_forward(*inputs):
|
905 |
+
if return_dict is not None:
|
906 |
+
return module(*inputs, return_dict=return_dict)
|
907 |
+
else:
|
908 |
+
return module(*inputs)
|
909 |
+
|
910 |
+
return custom_forward
|
911 |
+
|
912 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
913 |
+
hidden_states = attn(
|
914 |
+
hidden_states,
|
915 |
+
encoder_hidden_states=encoder_hidden_states,
|
916 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
917 |
+
attention_mask=attention_mask,
|
918 |
+
encoder_attention_mask=encoder_attention_mask,
|
919 |
+
return_dict=False,
|
920 |
+
# additional
|
921 |
+
condition_on_first_frame=condition_on_first_frame,
|
922 |
+
)[0]
|
923 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
924 |
+
create_custom_forward(resnet),
|
925 |
+
hidden_states,
|
926 |
+
temb,
|
927 |
+
**ckpt_kwargs,
|
928 |
+
)
|
929 |
+
hidden_states = conv3d(hidden_states)
|
930 |
+
else:
|
931 |
+
hidden_states = attn(
|
932 |
+
hidden_states,
|
933 |
+
encoder_hidden_states=encoder_hidden_states,
|
934 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
935 |
+
attention_mask=attention_mask,
|
936 |
+
encoder_attention_mask=encoder_attention_mask,
|
937 |
+
return_dict=False,
|
938 |
+
# additional
|
939 |
+
condition_on_first_frame=condition_on_first_frame,
|
940 |
+
)[0]
|
941 |
+
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
942 |
+
hidden_states = conv3d(hidden_states)
|
943 |
+
|
944 |
+
return hidden_states
|
945 |
+
|
946 |
+
|
947 |
+
class VideoLDMDownBlock(DownBlock2D):
|
948 |
+
def __init__(
|
949 |
+
self,
|
950 |
+
in_channels: int,
|
951 |
+
out_channels: int,
|
952 |
+
temb_channels: int,
|
953 |
+
dropout: float = 0.0,
|
954 |
+
num_layers: int = 1,
|
955 |
+
resnet_eps: float = 1e-6,
|
956 |
+
resnet_time_scale_shift: str = "default",
|
957 |
+
resnet_act_fn: str = "swish",
|
958 |
+
resnet_groups: int = 32,
|
959 |
+
resnet_pre_norm: bool = True,
|
960 |
+
output_scale_factor=1.0,
|
961 |
+
add_downsample=True,
|
962 |
+
downsample_padding=1,
|
963 |
+
# additional
|
964 |
+
use_temporal=True,
|
965 |
+
n_frames: int = 8,
|
966 |
+
first_frame_condition_mode="none",
|
967 |
+
latent_channels=4,
|
968 |
+
):
|
969 |
+
super().__init__(
|
970 |
+
in_channels,
|
971 |
+
out_channels,
|
972 |
+
temb_channels,
|
973 |
+
dropout,
|
974 |
+
num_layers,
|
975 |
+
resnet_eps,
|
976 |
+
resnet_time_scale_shift,
|
977 |
+
resnet_act_fn,
|
978 |
+
resnet_groups,
|
979 |
+
resnet_pre_norm,
|
980 |
+
output_scale_factor,
|
981 |
+
add_downsample,
|
982 |
+
downsample_padding,)
|
983 |
+
|
984 |
+
self.use_temporal = use_temporal
|
985 |
+
|
986 |
+
self.n_frames = n_frames
|
987 |
+
self.first_frame_condition_mode = first_frame_condition_mode
|
988 |
+
if self.first_frame_condition_mode == "conv2d":
|
989 |
+
self.first_frame_conv = nn.Conv2d(latent_channels, in_channels, kernel_size=1)
|
990 |
+
|
991 |
+
# >>> Temporal Layers >>>
|
992 |
+
conv3ds = []
|
993 |
+
for i in range(num_layers):
|
994 |
+
if self.use_temporal:
|
995 |
+
conv3ds.append(
|
996 |
+
TemporalResnetBlock(
|
997 |
+
in_channels=out_channels,
|
998 |
+
out_channels=out_channels,
|
999 |
+
n_frames=n_frames,
|
1000 |
+
)
|
1001 |
+
)
|
1002 |
+
else:
|
1003 |
+
conv3ds.append(IdentityLayer(return_trans2d_output=False))
|
1004 |
+
self.conv3ds = nn.ModuleList(conv3ds)
|
1005 |
+
# <<< Temporal Layers <<<
|
1006 |
+
|
1007 |
+
def forward(self, hidden_states, temb=None, scale: float = 1, first_frame_latents=None):
|
1008 |
+
# input shape: hidden_states = (b f) c h w, first_frame_latents = b c 1 h w
|
1009 |
+
if self.first_frame_condition_mode == "conv2d":
|
1010 |
+
hidden_states = rearrange(hidden_states, '(b t) c h w -> b c t h w', t=self.n_frames)
|
1011 |
+
hidden_height = hidden_states.shape[3]
|
1012 |
+
first_frame_height = first_frame_latents.shape[3]
|
1013 |
+
downsample_ratio = hidden_height / first_frame_height
|
1014 |
+
first_frame_latents = F.interpolate(first_frame_latents.squeeze(2), scale_factor=downsample_ratio, mode="nearest")
|
1015 |
+
first_frame_latents = self.first_frame_conv(first_frame_latents).unsqueeze(2)
|
1016 |
+
hidden_states[:, :, 0:1, :, :] = first_frame_latents
|
1017 |
+
hidden_states = rearrange(hidden_states, 'b c t h w -> (b t) c h w', t=self.n_frames)
|
1018 |
+
|
1019 |
+
output_states = ()
|
1020 |
+
|
1021 |
+
for resnet, conv3d in zip(self.resnets, self.conv3ds):
|
1022 |
+
if self.training and self.gradient_checkpointing:
|
1023 |
+
|
1024 |
+
def create_custom_forward(module):
|
1025 |
+
def custom_forward(*inputs):
|
1026 |
+
return module(*inputs)
|
1027 |
+
|
1028 |
+
return custom_forward
|
1029 |
+
|
1030 |
+
if is_torch_version(">=", "1.11.0"):
|
1031 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1032 |
+
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
1033 |
+
)
|
1034 |
+
else:
|
1035 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1036 |
+
create_custom_forward(resnet), hidden_states, temb
|
1037 |
+
)
|
1038 |
+
else:
|
1039 |
+
hidden_states = resnet(hidden_states, temb, scale=scale)
|
1040 |
+
|
1041 |
+
hidden_states = conv3d(hidden_states)
|
1042 |
+
|
1043 |
+
output_states = output_states + (hidden_states,)
|
1044 |
+
|
1045 |
+
if self.downsamplers is not None:
|
1046 |
+
for downsampler in self.downsamplers:
|
1047 |
+
hidden_states = downsampler(hidden_states, scale=scale)
|
1048 |
+
|
1049 |
+
output_states = output_states + (hidden_states,)
|
1050 |
+
|
1051 |
+
return hidden_states, output_states
|
1052 |
+
|
1053 |
+
|
1054 |
+
class VideoLDMUpBlock(UpBlock2D):
|
1055 |
+
def __init__(
|
1056 |
+
self,
|
1057 |
+
in_channels: int,
|
1058 |
+
prev_output_channel: int,
|
1059 |
+
out_channels: int,
|
1060 |
+
temb_channels: int,
|
1061 |
+
dropout: float = 0.0,
|
1062 |
+
num_layers: int = 1,
|
1063 |
+
resnet_eps: float = 1e-6,
|
1064 |
+
resnet_time_scale_shift: str = "default",
|
1065 |
+
resnet_act_fn: str = "swish",
|
1066 |
+
resnet_groups: int = 32,
|
1067 |
+
resnet_pre_norm: bool = True,
|
1068 |
+
output_scale_factor=1.0,
|
1069 |
+
add_upsample=True,
|
1070 |
+
# additional
|
1071 |
+
use_temporal=True,
|
1072 |
+
n_frames: int = 8,
|
1073 |
+
first_frame_condition_mode="none",
|
1074 |
+
latent_channels=4,
|
1075 |
+
):
|
1076 |
+
super().__init__(
|
1077 |
+
in_channels,
|
1078 |
+
prev_output_channel,
|
1079 |
+
out_channels,
|
1080 |
+
temb_channels,
|
1081 |
+
dropout,
|
1082 |
+
num_layers,
|
1083 |
+
resnet_eps,
|
1084 |
+
resnet_time_scale_shift,
|
1085 |
+
resnet_act_fn,
|
1086 |
+
resnet_groups,
|
1087 |
+
resnet_pre_norm,
|
1088 |
+
output_scale_factor,
|
1089 |
+
add_upsample,
|
1090 |
+
)
|
1091 |
+
|
1092 |
+
self.use_temporal = use_temporal
|
1093 |
+
|
1094 |
+
self.n_frames = n_frames
|
1095 |
+
self.first_frame_condition_mode = first_frame_condition_mode
|
1096 |
+
if self.first_frame_condition_mode == "conv2d":
|
1097 |
+
self.first_frame_conv = nn.Conv2d(latent_channels, prev_output_channel, kernel_size=1)
|
1098 |
+
|
1099 |
+
# >>> Temporal Layers >>>
|
1100 |
+
conv3ds = []
|
1101 |
+
for i in range(num_layers):
|
1102 |
+
if self.use_temporal:
|
1103 |
+
conv3ds.append(
|
1104 |
+
TemporalResnetBlock(
|
1105 |
+
in_channels=out_channels,
|
1106 |
+
out_channels=out_channels,
|
1107 |
+
n_frames=n_frames,
|
1108 |
+
)
|
1109 |
+
)
|
1110 |
+
else:
|
1111 |
+
conv3ds.append(IdentityLayer(return_trans2d_output=False))
|
1112 |
+
|
1113 |
+
self.conv3ds = nn.ModuleList(conv3ds)
|
1114 |
+
# <<< Temporal Layers <<<
|
1115 |
+
|
1116 |
+
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1, first_frame_latents=None):
|
1117 |
+
# input shape: hidden_states = (b f) c h w, first_frame_latents = b c 1 h w
|
1118 |
+
if self.first_frame_condition_mode == "conv2d":
|
1119 |
+
hidden_states = rearrange(hidden_states, '(b t) c h w -> b c t h w', t=self.n_frames)
|
1120 |
+
hidden_height = hidden_states.shape[3]
|
1121 |
+
first_frame_height = first_frame_latents.shape[3]
|
1122 |
+
downsample_ratio = hidden_height / first_frame_height
|
1123 |
+
first_frame_latents = F.interpolate(first_frame_latents.squeeze(2), scale_factor=downsample_ratio, mode="nearest")
|
1124 |
+
first_frame_latents = self.first_frame_conv(first_frame_latents).unsqueeze(2)
|
1125 |
+
hidden_states[:, :, 0:1, :, :] = first_frame_latents
|
1126 |
+
hidden_states = rearrange(hidden_states, 'b c t h w -> (b t) c h w', t=self.n_frames)
|
1127 |
+
|
1128 |
+
for resnet, conv3d in zip(self.resnets, self.conv3ds):
|
1129 |
+
# pop res hidden states
|
1130 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1131 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1132 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1133 |
+
|
1134 |
+
if self.training and self.gradient_checkpointing:
|
1135 |
+
|
1136 |
+
def create_custom_forward(module):
|
1137 |
+
def custom_forward(*inputs):
|
1138 |
+
return module(*inputs)
|
1139 |
+
|
1140 |
+
return custom_forward
|
1141 |
+
|
1142 |
+
if is_torch_version(">=", "1.11.0"):
|
1143 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1144 |
+
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
1145 |
+
)
|
1146 |
+
else:
|
1147 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1148 |
+
create_custom_forward(resnet), hidden_states, temb
|
1149 |
+
)
|
1150 |
+
else:
|
1151 |
+
hidden_states = resnet(hidden_states, temb, scale=scale)
|
1152 |
+
|
1153 |
+
hidden_states = conv3d(hidden_states)
|
1154 |
+
|
1155 |
+
if self.upsamplers is not None:
|
1156 |
+
for upsampler in self.upsamplers:
|
1157 |
+
hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
|
1158 |
+
|
1159 |
+
return hidden_states
|
src/videogen_hub/pipelines/consisti2v/consisti2v/pipelines/__init__.py
ADDED
File without changes
|
src/videogen_hub/pipelines/consisti2v/consisti2v/pipelines/pipeline_autoregress_animation.py
ADDED
@@ -0,0 +1,615 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py
|
2 |
+
|
3 |
+
import inspect
|
4 |
+
from typing import Callable, List, Optional, Union
|
5 |
+
from dataclasses import dataclass
|
6 |
+
|
7 |
+
import math
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
from torchvision import transforms as T
|
13 |
+
from PIL import Image
|
14 |
+
|
15 |
+
from diffusers.utils import is_accelerate_available
|
16 |
+
from packaging import version
|
17 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
18 |
+
|
19 |
+
from diffusers.configuration_utils import FrozenDict
|
20 |
+
from diffusers.models import AutoencoderKL
|
21 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
22 |
+
from diffusers.schedulers import (
|
23 |
+
DDIMScheduler,
|
24 |
+
DPMSolverMultistepScheduler,
|
25 |
+
EulerAncestralDiscreteScheduler,
|
26 |
+
EulerDiscreteScheduler,
|
27 |
+
LMSDiscreteScheduler,
|
28 |
+
PNDMScheduler,
|
29 |
+
)
|
30 |
+
from diffusers.utils import deprecate, logging, BaseOutput
|
31 |
+
|
32 |
+
from einops import rearrange, repeat
|
33 |
+
|
34 |
+
from ..models.unet import UNet3DConditionModel
|
35 |
+
from ..utils.frameinit_utils import freq_mix_3d, get_freq_filter
|
36 |
+
|
37 |
+
|
38 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
39 |
+
|
40 |
+
# copied from https://github.com/huggingface/diffusers/blob/v0.23.0/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L59C1-L70C21
|
41 |
+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
42 |
+
"""
|
43 |
+
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
44 |
+
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
|
45 |
+
"""
|
46 |
+
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
47 |
+
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
48 |
+
# rescale the results from guidance (fixes overexposure)
|
49 |
+
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
50 |
+
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
51 |
+
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
52 |
+
return noise_cfg
|
53 |
+
|
54 |
+
|
55 |
+
@dataclass
|
56 |
+
class AnimationPipelineOutput(BaseOutput):
|
57 |
+
videos: Union[torch.Tensor, np.ndarray]
|
58 |
+
|
59 |
+
|
60 |
+
class AutoregressiveAnimationPipeline(DiffusionPipeline):
|
61 |
+
_optional_components = []
|
62 |
+
|
63 |
+
def __init__(
|
64 |
+
self,
|
65 |
+
vae: AutoencoderKL,
|
66 |
+
text_encoder: CLIPTextModel,
|
67 |
+
tokenizer: CLIPTokenizer,
|
68 |
+
unet: UNet3DConditionModel,
|
69 |
+
scheduler: Union[
|
70 |
+
DDIMScheduler,
|
71 |
+
PNDMScheduler,
|
72 |
+
LMSDiscreteScheduler,
|
73 |
+
EulerDiscreteScheduler,
|
74 |
+
EulerAncestralDiscreteScheduler,
|
75 |
+
DPMSolverMultistepScheduler,
|
76 |
+
],
|
77 |
+
):
|
78 |
+
super().__init__()
|
79 |
+
|
80 |
+
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
81 |
+
deprecation_message = (
|
82 |
+
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
83 |
+
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
84 |
+
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
85 |
+
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
86 |
+
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
87 |
+
" file"
|
88 |
+
)
|
89 |
+
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
|
90 |
+
new_config = dict(scheduler.config)
|
91 |
+
new_config["steps_offset"] = 1
|
92 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
93 |
+
|
94 |
+
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
95 |
+
deprecation_message = (
|
96 |
+
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
97 |
+
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
98 |
+
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
99 |
+
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
100 |
+
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
101 |
+
)
|
102 |
+
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
|
103 |
+
new_config = dict(scheduler.config)
|
104 |
+
new_config["clip_sample"] = False
|
105 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
106 |
+
|
107 |
+
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
108 |
+
version.parse(unet.config._diffusers_version).base_version
|
109 |
+
) < version.parse("0.9.0.dev0")
|
110 |
+
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
111 |
+
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
112 |
+
deprecation_message = (
|
113 |
+
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
114 |
+
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
115 |
+
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
116 |
+
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
117 |
+
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
118 |
+
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
119 |
+
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
120 |
+
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
121 |
+
" the `unet/config.json` file"
|
122 |
+
)
|
123 |
+
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
124 |
+
new_config = dict(unet.config)
|
125 |
+
new_config["sample_size"] = 64
|
126 |
+
unet._internal_dict = FrozenDict(new_config)
|
127 |
+
|
128 |
+
self.register_modules(
|
129 |
+
vae=vae,
|
130 |
+
text_encoder=text_encoder,
|
131 |
+
tokenizer=tokenizer,
|
132 |
+
unet=unet,
|
133 |
+
scheduler=scheduler,
|
134 |
+
)
|
135 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
136 |
+
|
137 |
+
self.freq_filter = None
|
138 |
+
|
139 |
+
@torch.no_grad()
|
140 |
+
def init_filter(self, video_length, height, width, filter_params):
|
141 |
+
# initialize frequency filter for noise reinitialization
|
142 |
+
batch_size = 1
|
143 |
+
num_channels_latents = self.unet.config.in_channels
|
144 |
+
filter_shape = [
|
145 |
+
batch_size,
|
146 |
+
num_channels_latents,
|
147 |
+
video_length,
|
148 |
+
height // self.vae_scale_factor,
|
149 |
+
width // self.vae_scale_factor
|
150 |
+
]
|
151 |
+
# self.freq_filter = get_freq_filter(filter_shape, device=self._execution_device, params=filter_params)
|
152 |
+
self.freq_filter = get_freq_filter(
|
153 |
+
filter_shape,
|
154 |
+
device=self._execution_device,
|
155 |
+
filter_type=filter_params.method,
|
156 |
+
n=filter_params.n if filter_params.method=="butterworth" else None,
|
157 |
+
d_s=filter_params.d_s,
|
158 |
+
d_t=filter_params.d_t
|
159 |
+
)
|
160 |
+
|
161 |
+
def enable_vae_slicing(self):
|
162 |
+
self.vae.enable_slicing()
|
163 |
+
|
164 |
+
def disable_vae_slicing(self):
|
165 |
+
self.vae.disable_slicing()
|
166 |
+
|
167 |
+
def enable_sequential_cpu_offload(self, gpu_id=0):
|
168 |
+
if is_accelerate_available():
|
169 |
+
from accelerate import cpu_offload
|
170 |
+
else:
|
171 |
+
raise ImportError("Please install accelerate via `pip install accelerate`")
|
172 |
+
|
173 |
+
device = torch.device(f"cuda:{gpu_id}")
|
174 |
+
|
175 |
+
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
176 |
+
if cpu_offloaded_model is not None:
|
177 |
+
cpu_offload(cpu_offloaded_model, device)
|
178 |
+
|
179 |
+
|
180 |
+
@property
|
181 |
+
def _execution_device(self):
|
182 |
+
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
|
183 |
+
return self.device
|
184 |
+
for module in self.unet.modules():
|
185 |
+
if (
|
186 |
+
hasattr(module, "_hf_hook")
|
187 |
+
and hasattr(module._hf_hook, "execution_device")
|
188 |
+
and module._hf_hook.execution_device is not None
|
189 |
+
):
|
190 |
+
return torch.device(module._hf_hook.execution_device)
|
191 |
+
return self.device
|
192 |
+
|
193 |
+
def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
|
194 |
+
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
195 |
+
|
196 |
+
text_inputs = self.tokenizer(
|
197 |
+
prompt,
|
198 |
+
padding="max_length",
|
199 |
+
max_length=self.tokenizer.model_max_length,
|
200 |
+
truncation=True,
|
201 |
+
return_tensors="pt",
|
202 |
+
)
|
203 |
+
text_input_ids = text_inputs.input_ids
|
204 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
205 |
+
|
206 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
207 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
|
208 |
+
logger.warning(
|
209 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
210 |
+
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
211 |
+
)
|
212 |
+
|
213 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
214 |
+
attention_mask = text_inputs.attention_mask.to(device)
|
215 |
+
else:
|
216 |
+
attention_mask = None
|
217 |
+
|
218 |
+
text_embeddings = self.text_encoder(
|
219 |
+
text_input_ids.to(device),
|
220 |
+
attention_mask=attention_mask,
|
221 |
+
)
|
222 |
+
text_embeddings = text_embeddings[0]
|
223 |
+
|
224 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
225 |
+
bs_embed, seq_len, _ = text_embeddings.shape
|
226 |
+
text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
|
227 |
+
text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
|
228 |
+
|
229 |
+
# get unconditional embeddings for classifier free guidance
|
230 |
+
if do_classifier_free_guidance is not None:
|
231 |
+
uncond_tokens: List[str]
|
232 |
+
if negative_prompt is None:
|
233 |
+
uncond_tokens = [""] * batch_size
|
234 |
+
elif type(prompt) is not type(negative_prompt):
|
235 |
+
raise TypeError(
|
236 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
237 |
+
f" {type(prompt)}."
|
238 |
+
)
|
239 |
+
elif isinstance(negative_prompt, str):
|
240 |
+
uncond_tokens = [negative_prompt]
|
241 |
+
elif batch_size != len(negative_prompt):
|
242 |
+
raise ValueError(
|
243 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
244 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
245 |
+
" the batch size of `prompt`."
|
246 |
+
)
|
247 |
+
else:
|
248 |
+
uncond_tokens = negative_prompt
|
249 |
+
|
250 |
+
max_length = text_input_ids.shape[-1]
|
251 |
+
uncond_input = self.tokenizer(
|
252 |
+
uncond_tokens,
|
253 |
+
padding="max_length",
|
254 |
+
max_length=max_length,
|
255 |
+
truncation=True,
|
256 |
+
return_tensors="pt",
|
257 |
+
)
|
258 |
+
|
259 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
260 |
+
attention_mask = uncond_input.attention_mask.to(device)
|
261 |
+
else:
|
262 |
+
attention_mask = None
|
263 |
+
|
264 |
+
uncond_embeddings = self.text_encoder(
|
265 |
+
uncond_input.input_ids.to(device),
|
266 |
+
attention_mask=attention_mask,
|
267 |
+
)
|
268 |
+
uncond_embeddings = uncond_embeddings[0]
|
269 |
+
|
270 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
271 |
+
seq_len = uncond_embeddings.shape[1]
|
272 |
+
uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
|
273 |
+
uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
274 |
+
|
275 |
+
# For classifier free guidance, we need to do two forward passes.
|
276 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
277 |
+
# to avoid doing two forward passes
|
278 |
+
if do_classifier_free_guidance == "text":
|
279 |
+
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
280 |
+
elif do_classifier_free_guidance == "both":
|
281 |
+
text_embeddings = torch.cat([uncond_embeddings, uncond_embeddings, text_embeddings])
|
282 |
+
|
283 |
+
return text_embeddings
|
284 |
+
|
285 |
+
def decode_latents(self, latents, first_frames=None):
|
286 |
+
video_length = latents.shape[2]
|
287 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
288 |
+
latents = rearrange(latents, "b c f h w -> (b f) c h w")
|
289 |
+
# video = self.vae.decode(latents).sample
|
290 |
+
video = []
|
291 |
+
for frame_idx in tqdm(range(latents.shape[0]), **self._progress_bar_config):
|
292 |
+
video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
|
293 |
+
video = torch.cat(video)
|
294 |
+
video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
|
295 |
+
|
296 |
+
if first_frames is not None:
|
297 |
+
first_frames = first_frames.unsqueeze(2)
|
298 |
+
video = torch.cat([first_frames, video], dim=2)
|
299 |
+
|
300 |
+
video = (video / 2 + 0.5).clamp(0, 1)
|
301 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
302 |
+
video = video.cpu().float().numpy()
|
303 |
+
return video
|
304 |
+
|
305 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
306 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
307 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
308 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
309 |
+
# and should be between [0, 1]
|
310 |
+
|
311 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
312 |
+
extra_step_kwargs = {}
|
313 |
+
if accepts_eta:
|
314 |
+
extra_step_kwargs["eta"] = eta
|
315 |
+
|
316 |
+
# check if the scheduler accepts generator
|
317 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
318 |
+
if accepts_generator:
|
319 |
+
extra_step_kwargs["generator"] = generator
|
320 |
+
return extra_step_kwargs
|
321 |
+
|
322 |
+
def check_inputs(self, prompt, height, width, callback_steps, first_frame_paths=None):
|
323 |
+
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
324 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
325 |
+
|
326 |
+
if first_frame_paths is not None and (not isinstance(prompt, str) and not isinstance(first_frame_paths, list)):
|
327 |
+
raise ValueError(f"`first_frame_paths` has to be of type `str` or `list` but is {type(first_frame_paths)}")
|
328 |
+
|
329 |
+
if height % 8 != 0 or width % 8 != 0:
|
330 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
331 |
+
|
332 |
+
if (callback_steps is None) or (
|
333 |
+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
334 |
+
):
|
335 |
+
raise ValueError(
|
336 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
337 |
+
f" {type(callback_steps)}."
|
338 |
+
)
|
339 |
+
|
340 |
+
def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None, noise_sampling_method="vanilla", noise_alpha=1.0):
|
341 |
+
shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
342 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
343 |
+
raise ValueError(
|
344 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
345 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
346 |
+
)
|
347 |
+
if latents is None:
|
348 |
+
rand_device = "cpu" if device.type == "mps" else device
|
349 |
+
|
350 |
+
if isinstance(generator, list):
|
351 |
+
# shape = shape
|
352 |
+
shape = (1,) + shape[1:]
|
353 |
+
if noise_sampling_method == "vanilla":
|
354 |
+
latents = [
|
355 |
+
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
|
356 |
+
for i in range(batch_size)
|
357 |
+
]
|
358 |
+
elif noise_sampling_method == "pyoco_mixed":
|
359 |
+
base_shape = (batch_size, num_channels_latents, 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
360 |
+
latents = []
|
361 |
+
noise_alpha_squared = noise_alpha ** 2
|
362 |
+
for i in range(batch_size):
|
363 |
+
base_latent = torch.randn(base_shape, generator=generator[i], device=rand_device, dtype=dtype) * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared))
|
364 |
+
ind_latent = torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared))
|
365 |
+
latents.append(base_latent + ind_latent)
|
366 |
+
elif noise_sampling_method == "pyoco_progressive":
|
367 |
+
latents = []
|
368 |
+
noise_alpha_squared = noise_alpha ** 2
|
369 |
+
for i in range(batch_size):
|
370 |
+
latent = torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
|
371 |
+
ind_latent = torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared))
|
372 |
+
for j in range(1, video_length):
|
373 |
+
latent[:, :, j, :, :] = latent[:, :, j - 1, :, :] * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared)) + ind_latent[:, :, j, :, :]
|
374 |
+
latents.append(latent)
|
375 |
+
latents = torch.cat(latents, dim=0).to(device)
|
376 |
+
else:
|
377 |
+
if noise_sampling_method == "vanilla":
|
378 |
+
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
379 |
+
elif noise_sampling_method == "pyoco_mixed":
|
380 |
+
noise_alpha_squared = noise_alpha ** 2
|
381 |
+
base_shape = (batch_size, num_channels_latents, 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
382 |
+
base_latents = torch.randn(base_shape, generator=generator, device=rand_device, dtype=dtype) * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared))
|
383 |
+
ind_latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared))
|
384 |
+
latents = base_latents + ind_latents
|
385 |
+
elif noise_sampling_method == "pyoco_progressive":
|
386 |
+
noise_alpha_squared = noise_alpha ** 2
|
387 |
+
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype)
|
388 |
+
ind_latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared))
|
389 |
+
for j in range(1, video_length):
|
390 |
+
latents[:, :, j, :, :] = latents[:, :, j - 1, :, :] * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared)) + ind_latents[:, :, j, :, :]
|
391 |
+
else:
|
392 |
+
if latents.shape != shape:
|
393 |
+
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
394 |
+
latents = latents.to(device)
|
395 |
+
|
396 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
397 |
+
latents = latents * self.scheduler.init_noise_sigma
|
398 |
+
return latents
|
399 |
+
|
400 |
+
@torch.no_grad()
|
401 |
+
def __call__(
|
402 |
+
self,
|
403 |
+
prompt: Union[str, List[str]],
|
404 |
+
video_length: Optional[int],
|
405 |
+
height: Optional[int] = None,
|
406 |
+
width: Optional[int] = None,
|
407 |
+
num_inference_steps: int = 50,
|
408 |
+
guidance_scale_txt: float = 7.5,
|
409 |
+
guidance_scale_img: float = 2.0,
|
410 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
411 |
+
num_videos_per_prompt: Optional[int] = 1,
|
412 |
+
eta: float = 0.0,
|
413 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
414 |
+
latents: Optional[torch.FloatTensor] = None,
|
415 |
+
output_type: Optional[str] = "tensor",
|
416 |
+
return_dict: bool = True,
|
417 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
418 |
+
callback_steps: Optional[int] = 1,
|
419 |
+
# additional
|
420 |
+
first_frame_paths: Optional[Union[str, List[str]]] = None,
|
421 |
+
first_frames: Optional[torch.FloatTensor] = None,
|
422 |
+
noise_sampling_method: str = "vanilla",
|
423 |
+
noise_alpha: float = 1.0,
|
424 |
+
guidance_rescale: float = 0.0,
|
425 |
+
frame_stride: Optional[int] = None,
|
426 |
+
autoregress_steps: int = 3,
|
427 |
+
use_frameinit: bool = False,
|
428 |
+
frameinit_noise_level: int = 999,
|
429 |
+
**kwargs,
|
430 |
+
):
|
431 |
+
if first_frame_paths is not None and first_frames is not None:
|
432 |
+
raise ValueError("Only one of `first_frame_paths` and `first_frames` can be passed.")
|
433 |
+
# Default height and width to unet
|
434 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
435 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
436 |
+
|
437 |
+
# Check inputs. Raise error if not correct
|
438 |
+
self.check_inputs(prompt, height, width, callback_steps, first_frame_paths)
|
439 |
+
|
440 |
+
# Define call parameters
|
441 |
+
# batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
442 |
+
batch_size = 1
|
443 |
+
if latents is not None:
|
444 |
+
batch_size = latents.shape[0]
|
445 |
+
if isinstance(prompt, list):
|
446 |
+
batch_size = len(prompt)
|
447 |
+
first_frame_input = first_frame_paths if first_frame_paths is not None else first_frames
|
448 |
+
if first_frame_input is not None:
|
449 |
+
assert len(prompt) == len(first_frame_input), "prompt and first_frame_paths should have the same length"
|
450 |
+
|
451 |
+
device = self._execution_device
|
452 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
453 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
454 |
+
# corresponds to doing no classifier free guidance.
|
455 |
+
do_classifier_free_guidance = None
|
456 |
+
# two guidance mode: text and text+image
|
457 |
+
if guidance_scale_txt > 1.0:
|
458 |
+
do_classifier_free_guidance = "text"
|
459 |
+
if guidance_scale_img > 1.0:
|
460 |
+
do_classifier_free_guidance = "both"
|
461 |
+
|
462 |
+
# Encode input prompt
|
463 |
+
prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size
|
464 |
+
if negative_prompt is not None:
|
465 |
+
negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size
|
466 |
+
text_embeddings = self._encode_prompt(
|
467 |
+
prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
|
468 |
+
)
|
469 |
+
|
470 |
+
# Encode input first frame
|
471 |
+
first_frame_latents = None
|
472 |
+
if first_frame_paths is not None:
|
473 |
+
first_frame_paths = first_frame_paths if isinstance(first_frame_paths, list) else [first_frame_paths] * batch_size
|
474 |
+
img_transform = T.Compose([
|
475 |
+
T.ToTensor(),
|
476 |
+
T.Resize(height, antialias=None),
|
477 |
+
T.CenterCrop((height, width)),
|
478 |
+
T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
479 |
+
])
|
480 |
+
first_frames = []
|
481 |
+
for first_frame_path in first_frame_paths:
|
482 |
+
first_frame = Image.open(first_frame_path).convert('RGB')
|
483 |
+
first_frame = img_transform(first_frame).unsqueeze(0)
|
484 |
+
first_frames.append(first_frame)
|
485 |
+
first_frames = torch.cat(first_frames, dim=0)
|
486 |
+
if first_frames is not None:
|
487 |
+
first_frames = first_frames.to(device, dtype=self.vae.dtype)
|
488 |
+
first_frame_latents = self.vae.encode(first_frames).latent_dist
|
489 |
+
first_frame_latents = first_frame_latents.sample()
|
490 |
+
first_frame_latents = first_frame_latents * self.vae.config.scaling_factor # b, c, h, w
|
491 |
+
first_frame_latents = repeat(first_frame_latents, "b c h w -> (b n) c h w", n=num_videos_per_prompt)
|
492 |
+
first_frames = repeat(first_frames, "b c h w -> (b n) c h w", n=num_videos_per_prompt)
|
493 |
+
|
494 |
+
full_video_latent = torch.zeros(batch_size * num_videos_per_prompt, self.unet.config.in_channels, video_length * autoregress_steps - autoregress_steps + 1, height // self.vae_scale_factor, width // self.vae_scale_factor, device=device, dtype=self.vae.dtype)
|
495 |
+
|
496 |
+
start_idx = 0
|
497 |
+
for ar_step in range(autoregress_steps):
|
498 |
+
# Prepare timesteps
|
499 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
500 |
+
timesteps = self.scheduler.timesteps
|
501 |
+
|
502 |
+
# Prepare latent variables
|
503 |
+
num_channels_latents = self.unet.config.in_channels
|
504 |
+
latents = self.prepare_latents(
|
505 |
+
batch_size * num_videos_per_prompt,
|
506 |
+
num_channels_latents,
|
507 |
+
video_length,
|
508 |
+
height,
|
509 |
+
width,
|
510 |
+
text_embeddings.dtype,
|
511 |
+
device,
|
512 |
+
generator,
|
513 |
+
latents,
|
514 |
+
noise_sampling_method,
|
515 |
+
noise_alpha,
|
516 |
+
)
|
517 |
+
latents_dtype = latents.dtype
|
518 |
+
|
519 |
+
if use_frameinit:
|
520 |
+
current_diffuse_timestep = frameinit_noise_level # diffuse to noise level
|
521 |
+
diffuse_timesteps = torch.full((batch_size,),int(current_diffuse_timestep))
|
522 |
+
diffuse_timesteps = diffuse_timesteps.long()
|
523 |
+
first_frames_static_vid = repeat(first_frame_latents, "b c h w -> b c t h w", t=video_length)
|
524 |
+
z_T = self.scheduler.add_noise(
|
525 |
+
original_samples=first_frames_static_vid.to(device),
|
526 |
+
noise=latents.to(device),
|
527 |
+
timesteps=diffuse_timesteps.to(device)
|
528 |
+
)
|
529 |
+
latents = freq_mix_3d(z_T.to(dtype=torch.float32), latents, LPF=self.freq_filter)
|
530 |
+
latents = latents.to(dtype=latents_dtype)
|
531 |
+
|
532 |
+
if first_frame_latents is not None:
|
533 |
+
first_frame_noisy_latent = latents[:, :, 0, :, :]
|
534 |
+
latents = latents[:, :, 1:, :, :]
|
535 |
+
|
536 |
+
# Prepare extra step kwargs.
|
537 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
538 |
+
|
539 |
+
# Denoising loop
|
540 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
541 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
542 |
+
for i, t in enumerate(timesteps):
|
543 |
+
# expand the latents if we are doing classifier free guidance
|
544 |
+
if do_classifier_free_guidance is None:
|
545 |
+
latent_model_input = latents
|
546 |
+
elif do_classifier_free_guidance == "text":
|
547 |
+
latent_model_input = torch.cat([latents] * 2)
|
548 |
+
elif do_classifier_free_guidance == "both":
|
549 |
+
latent_model_input = torch.cat([latents] * 3)
|
550 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
551 |
+
if first_frame_latents is not None:
|
552 |
+
if do_classifier_free_guidance is None:
|
553 |
+
first_frame_latents_input = first_frame_latents
|
554 |
+
elif do_classifier_free_guidance == "text":
|
555 |
+
first_frame_latents_input = torch.cat([first_frame_latents] * 2)
|
556 |
+
elif do_classifier_free_guidance == "both":
|
557 |
+
first_frame_latents_input = torch.cat([first_frame_noisy_latent, first_frame_latents, first_frame_latents])
|
558 |
+
|
559 |
+
first_frame_latents_input = first_frame_latents_input.unsqueeze(2)
|
560 |
+
|
561 |
+
# predict the noise residual
|
562 |
+
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, first_frame_latents=first_frame_latents_input, frame_stride=frame_stride).sample.to(dtype=latents_dtype)
|
563 |
+
else:
|
564 |
+
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample.to(dtype=latents_dtype)
|
565 |
+
# noise_pred = []
|
566 |
+
# import pdb
|
567 |
+
# pdb.set_trace()
|
568 |
+
# for batch_idx in range(latent_model_input.shape[0]):
|
569 |
+
# noise_pred_single = self.unet(latent_model_input[batch_idx:batch_idx+1], t, encoder_hidden_states=text_embeddings[batch_idx:batch_idx+1]).sample.to(dtype=latents_dtype)
|
570 |
+
# noise_pred.append(noise_pred_single)
|
571 |
+
# noise_pred = torch.cat(noise_pred)
|
572 |
+
|
573 |
+
# perform guidance
|
574 |
+
if do_classifier_free_guidance:
|
575 |
+
if do_classifier_free_guidance == "text":
|
576 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
577 |
+
noise_pred = noise_pred_uncond + guidance_scale_txt * (noise_pred_text - noise_pred_uncond)
|
578 |
+
elif do_classifier_free_guidance == "both":
|
579 |
+
noise_pred_uncond, noise_pred_img, noise_pred_both = noise_pred.chunk(3)
|
580 |
+
noise_pred = noise_pred_uncond + guidance_scale_img * (noise_pred_img - noise_pred_uncond) + guidance_scale_txt * (noise_pred_both - noise_pred_img)
|
581 |
+
|
582 |
+
if do_classifier_free_guidance and guidance_rescale > 0.0:
|
583 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
584 |
+
# currently only support text guidance
|
585 |
+
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
586 |
+
|
587 |
+
# compute the previous noisy sample x_t -> x_t-1
|
588 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
589 |
+
|
590 |
+
# call the callback, if provided
|
591 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
592 |
+
progress_bar.update()
|
593 |
+
if callback is not None and i % callback_steps == 0:
|
594 |
+
callback(i, t, latents)
|
595 |
+
|
596 |
+
# Post-processing
|
597 |
+
|
598 |
+
latents = torch.cat([first_frame_latents.unsqueeze(2), latents], dim=2)
|
599 |
+
first_frame_latents = latents[:, :, -1, :, :]
|
600 |
+
full_video_latent[:, :, start_idx:start_idx + video_length, :, :] = latents
|
601 |
+
|
602 |
+
latents = None
|
603 |
+
start_idx += (video_length - 1)
|
604 |
+
|
605 |
+
# video = self.decode_latents(latents, first_frames)
|
606 |
+
video = self.decode_latents(full_video_latent)
|
607 |
+
|
608 |
+
# Convert to tensor
|
609 |
+
if output_type == "tensor":
|
610 |
+
video = torch.from_numpy(video)
|
611 |
+
|
612 |
+
if not return_dict:
|
613 |
+
return video
|
614 |
+
|
615 |
+
return AnimationPipelineOutput(videos=video)
|
src/videogen_hub/pipelines/consisti2v/consisti2v/pipelines/pipeline_conditional_animation.py
ADDED
@@ -0,0 +1,695 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py
|
2 |
+
|
3 |
+
import inspect
|
4 |
+
from typing import Callable, List, Optional, Union
|
5 |
+
from dataclasses import dataclass
|
6 |
+
|
7 |
+
import math
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
from torchvision import transforms as T
|
13 |
+
from torchvision.transforms import functional as F
|
14 |
+
from PIL import Image
|
15 |
+
|
16 |
+
from diffusers.utils import is_accelerate_available
|
17 |
+
from packaging import version
|
18 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
19 |
+
|
20 |
+
from diffusers.configuration_utils import FrozenDict
|
21 |
+
from diffusers.models import AutoencoderKL
|
22 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
23 |
+
from diffusers.schedulers import (
|
24 |
+
DDIMScheduler,
|
25 |
+
DPMSolverMultistepScheduler,
|
26 |
+
EulerAncestralDiscreteScheduler,
|
27 |
+
EulerDiscreteScheduler,
|
28 |
+
LMSDiscreteScheduler,
|
29 |
+
PNDMScheduler,
|
30 |
+
)
|
31 |
+
from diffusers.utils import deprecate, logging, BaseOutput
|
32 |
+
|
33 |
+
from einops import rearrange, repeat
|
34 |
+
|
35 |
+
from ..models.videoldm_unet import VideoLDMUNet3DConditionModel
|
36 |
+
|
37 |
+
from ..utils.frameinit_utils import get_freq_filter, freq_mix_3d
|
38 |
+
|
39 |
+
|
40 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
41 |
+
|
42 |
+
# copied from https://github.com/huggingface/diffusers/blob/v0.23.0/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L59C1-L70C21
|
43 |
+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
44 |
+
"""
|
45 |
+
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
46 |
+
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
|
47 |
+
"""
|
48 |
+
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
49 |
+
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
50 |
+
# rescale the results from guidance (fixes overexposure)
|
51 |
+
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
52 |
+
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
53 |
+
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
54 |
+
return noise_cfg
|
55 |
+
|
56 |
+
def pan_right(image, num_frames=16, crop_width=256):
|
57 |
+
frames = []
|
58 |
+
height, width = image.shape[-2:]
|
59 |
+
|
60 |
+
for i in range(num_frames):
|
61 |
+
# Calculate the start position of the crop
|
62 |
+
start_x = int((width - crop_width) * (i / num_frames))
|
63 |
+
crop = F.crop(image, 0, start_x, height, crop_width)
|
64 |
+
frames.append(crop.unsqueeze(0))
|
65 |
+
|
66 |
+
return torch.cat(frames, dim=0)
|
67 |
+
|
68 |
+
|
69 |
+
def pan_left(image, num_frames=16, crop_width=256):
|
70 |
+
frames = []
|
71 |
+
height, width = image.shape[-2:]
|
72 |
+
|
73 |
+
for i in range(num_frames):
|
74 |
+
# Start position moves from right to left
|
75 |
+
start_x = int((width - crop_width) * (1 - (i / num_frames)))
|
76 |
+
crop = F.crop(image, 0, start_x, height, crop_width)
|
77 |
+
frames.append(crop.unsqueeze(0))
|
78 |
+
|
79 |
+
return torch.cat(frames, dim=0)
|
80 |
+
|
81 |
+
|
82 |
+
def zoom_in(image, num_frames=16, crop_width=256, ratio=1.5):
|
83 |
+
frames = []
|
84 |
+
height, width = image.shape[-2:]
|
85 |
+
max_crop_size = min(width, height)
|
86 |
+
|
87 |
+
for i in range(num_frames):
|
88 |
+
# Calculate the size of the crop
|
89 |
+
crop_size = max_crop_size - int((max_crop_size - max_crop_size // ratio) * (i / num_frames))
|
90 |
+
start_x = (width - crop_size) // 2
|
91 |
+
start_y = (height - crop_size) // 2
|
92 |
+
crop = F.crop(image, start_y, start_x, crop_size, crop_size)
|
93 |
+
resized_crop = F.resize(crop, (crop_width, crop_width), antialias=None) # Resize back to original size
|
94 |
+
frames.append(resized_crop.unsqueeze(0))
|
95 |
+
|
96 |
+
return torch.cat(frames, dim=0)
|
97 |
+
|
98 |
+
|
99 |
+
def zoom_out(image, num_frames=16, crop_width=256, ratio=1.5):
|
100 |
+
frames = []
|
101 |
+
height, width = image.shape[-2:]
|
102 |
+
min_crop_size = min(width, height) // ratio # Starting from a quarter of the size
|
103 |
+
|
104 |
+
for i in range(num_frames):
|
105 |
+
# Calculate the size of the crop
|
106 |
+
crop_size = min_crop_size + int((min(width, height) - min_crop_size) * (i / num_frames))
|
107 |
+
start_x = (width - crop_size) // 2
|
108 |
+
start_y = (height - crop_size) // 2
|
109 |
+
crop = F.crop(image, start_y, start_x, crop_size, crop_size)
|
110 |
+
resized_crop = F.resize(crop, (crop_width, crop_width), antialias=None) # Resize back to original size
|
111 |
+
frames.append(resized_crop.unsqueeze(0))
|
112 |
+
|
113 |
+
return torch.cat(frames, dim=0)
|
114 |
+
|
115 |
+
|
116 |
+
@dataclass
|
117 |
+
class AnimationPipelineOutput(BaseOutput):
|
118 |
+
videos: Union[torch.Tensor, np.ndarray]
|
119 |
+
|
120 |
+
|
121 |
+
class ConditionalAnimationPipeline(DiffusionPipeline):
|
122 |
+
_optional_components = []
|
123 |
+
|
124 |
+
def __init__(
|
125 |
+
self,
|
126 |
+
vae: AutoencoderKL,
|
127 |
+
text_encoder: CLIPTextModel,
|
128 |
+
tokenizer: CLIPTokenizer,
|
129 |
+
unet: VideoLDMUNet3DConditionModel,
|
130 |
+
scheduler: Union[
|
131 |
+
DDIMScheduler,
|
132 |
+
PNDMScheduler,
|
133 |
+
LMSDiscreteScheduler,
|
134 |
+
EulerDiscreteScheduler,
|
135 |
+
EulerAncestralDiscreteScheduler,
|
136 |
+
DPMSolverMultistepScheduler,
|
137 |
+
],
|
138 |
+
):
|
139 |
+
super().__init__()
|
140 |
+
|
141 |
+
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
142 |
+
deprecation_message = (
|
143 |
+
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
144 |
+
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
145 |
+
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
146 |
+
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
147 |
+
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
148 |
+
" file"
|
149 |
+
)
|
150 |
+
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
|
151 |
+
new_config = dict(scheduler.config)
|
152 |
+
new_config["steps_offset"] = 1
|
153 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
154 |
+
|
155 |
+
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
156 |
+
deprecation_message = (
|
157 |
+
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
158 |
+
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
159 |
+
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
160 |
+
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
161 |
+
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
162 |
+
)
|
163 |
+
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
|
164 |
+
new_config = dict(scheduler.config)
|
165 |
+
new_config["clip_sample"] = False
|
166 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
167 |
+
|
168 |
+
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
169 |
+
version.parse(unet.config._diffusers_version).base_version
|
170 |
+
) < version.parse("0.9.0.dev0")
|
171 |
+
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
172 |
+
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
173 |
+
deprecation_message = (
|
174 |
+
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
175 |
+
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
176 |
+
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
177 |
+
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
178 |
+
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
179 |
+
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
180 |
+
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
181 |
+
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
182 |
+
" the `unet/config.json` file"
|
183 |
+
)
|
184 |
+
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
185 |
+
new_config = dict(unet.config)
|
186 |
+
new_config["sample_size"] = 64
|
187 |
+
unet._internal_dict = FrozenDict(new_config)
|
188 |
+
|
189 |
+
self.register_modules(
|
190 |
+
vae=vae,
|
191 |
+
text_encoder=text_encoder,
|
192 |
+
tokenizer=tokenizer,
|
193 |
+
unet=unet,
|
194 |
+
scheduler=scheduler,
|
195 |
+
)
|
196 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
197 |
+
|
198 |
+
self.freq_filter = None
|
199 |
+
|
200 |
+
@torch.no_grad()
|
201 |
+
def init_filter(self, video_length, height, width, filter_params):
|
202 |
+
# initialize frequency filter for noise reinitialization
|
203 |
+
batch_size = 1
|
204 |
+
num_channels_latents = self.unet.config.in_channels
|
205 |
+
filter_shape = [
|
206 |
+
batch_size,
|
207 |
+
num_channels_latents,
|
208 |
+
video_length,
|
209 |
+
height // self.vae_scale_factor,
|
210 |
+
width // self.vae_scale_factor
|
211 |
+
]
|
212 |
+
# self.freq_filter = get_freq_filter(filter_shape, device=self._execution_device, params=filter_params)
|
213 |
+
self.freq_filter = get_freq_filter(
|
214 |
+
filter_shape,
|
215 |
+
device=self._execution_device,
|
216 |
+
filter_type=filter_params.method,
|
217 |
+
n=filter_params.n if filter_params.method=="butterworth" else None,
|
218 |
+
d_s=filter_params.d_s,
|
219 |
+
d_t=filter_params.d_t
|
220 |
+
)
|
221 |
+
|
222 |
+
def enable_vae_slicing(self):
|
223 |
+
self.vae.enable_slicing()
|
224 |
+
|
225 |
+
def disable_vae_slicing(self):
|
226 |
+
self.vae.disable_slicing()
|
227 |
+
|
228 |
+
def enable_sequential_cpu_offload(self, gpu_id=0):
|
229 |
+
if is_accelerate_available():
|
230 |
+
from accelerate import cpu_offload
|
231 |
+
else:
|
232 |
+
raise ImportError("Please install accelerate via `pip install accelerate`")
|
233 |
+
|
234 |
+
device = torch.device(f"cuda:{gpu_id}")
|
235 |
+
|
236 |
+
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
237 |
+
if cpu_offloaded_model is not None:
|
238 |
+
cpu_offload(cpu_offloaded_model, device)
|
239 |
+
|
240 |
+
|
241 |
+
@property
|
242 |
+
def _execution_device(self):
|
243 |
+
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
|
244 |
+
return self.device
|
245 |
+
for module in self.unet.modules():
|
246 |
+
if (
|
247 |
+
hasattr(module, "_hf_hook")
|
248 |
+
and hasattr(module._hf_hook, "execution_device")
|
249 |
+
and module._hf_hook.execution_device is not None
|
250 |
+
):
|
251 |
+
return torch.device(module._hf_hook.execution_device)
|
252 |
+
return self.device
|
253 |
+
|
254 |
+
def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
|
255 |
+
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
256 |
+
|
257 |
+
text_inputs = self.tokenizer(
|
258 |
+
prompt,
|
259 |
+
padding="max_length",
|
260 |
+
max_length=self.tokenizer.model_max_length,
|
261 |
+
truncation=True,
|
262 |
+
return_tensors="pt",
|
263 |
+
)
|
264 |
+
text_input_ids = text_inputs.input_ids
|
265 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
266 |
+
|
267 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
268 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
|
269 |
+
logger.warning(
|
270 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
271 |
+
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
272 |
+
)
|
273 |
+
|
274 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
275 |
+
attention_mask = text_inputs.attention_mask.to(device)
|
276 |
+
else:
|
277 |
+
attention_mask = None
|
278 |
+
|
279 |
+
text_embeddings = self.text_encoder(
|
280 |
+
text_input_ids.to(device),
|
281 |
+
attention_mask=attention_mask,
|
282 |
+
)
|
283 |
+
text_embeddings = text_embeddings[0]
|
284 |
+
|
285 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
286 |
+
bs_embed, seq_len, _ = text_embeddings.shape
|
287 |
+
text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
|
288 |
+
text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
|
289 |
+
|
290 |
+
# get unconditional embeddings for classifier free guidance
|
291 |
+
if do_classifier_free_guidance is not None:
|
292 |
+
uncond_tokens: List[str]
|
293 |
+
if negative_prompt is None:
|
294 |
+
uncond_tokens = [""] * batch_size
|
295 |
+
elif type(prompt) is not type(negative_prompt):
|
296 |
+
raise TypeError(
|
297 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
298 |
+
f" {type(prompt)}."
|
299 |
+
)
|
300 |
+
elif isinstance(negative_prompt, str):
|
301 |
+
uncond_tokens = [negative_prompt]
|
302 |
+
elif batch_size != len(negative_prompt):
|
303 |
+
raise ValueError(
|
304 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
305 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
306 |
+
" the batch size of `prompt`."
|
307 |
+
)
|
308 |
+
else:
|
309 |
+
uncond_tokens = negative_prompt
|
310 |
+
|
311 |
+
max_length = text_input_ids.shape[-1]
|
312 |
+
uncond_input = self.tokenizer(
|
313 |
+
uncond_tokens,
|
314 |
+
padding="max_length",
|
315 |
+
max_length=max_length,
|
316 |
+
truncation=True,
|
317 |
+
return_tensors="pt",
|
318 |
+
)
|
319 |
+
|
320 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
321 |
+
attention_mask = uncond_input.attention_mask.to(device)
|
322 |
+
else:
|
323 |
+
attention_mask = None
|
324 |
+
|
325 |
+
uncond_embeddings = self.text_encoder(
|
326 |
+
uncond_input.input_ids.to(device),
|
327 |
+
attention_mask=attention_mask,
|
328 |
+
)
|
329 |
+
uncond_embeddings = uncond_embeddings[0]
|
330 |
+
|
331 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
332 |
+
seq_len = uncond_embeddings.shape[1]
|
333 |
+
uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
|
334 |
+
uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
335 |
+
|
336 |
+
# For classifier free guidance, we need to do two forward passes.
|
337 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
338 |
+
# to avoid doing two forward passes
|
339 |
+
if do_classifier_free_guidance == "text":
|
340 |
+
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
341 |
+
elif do_classifier_free_guidance == "both":
|
342 |
+
text_embeddings = torch.cat([uncond_embeddings, uncond_embeddings, text_embeddings])
|
343 |
+
|
344 |
+
return text_embeddings
|
345 |
+
|
346 |
+
def decode_latents(self, latents, first_frames=None):
|
347 |
+
video_length = latents.shape[2]
|
348 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
349 |
+
latents = rearrange(latents, "b c f h w -> (b f) c h w")
|
350 |
+
# video = self.vae.decode(latents).sample
|
351 |
+
video = []
|
352 |
+
for frame_idx in tqdm(range(latents.shape[0]), **self._progress_bar_config):
|
353 |
+
video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
|
354 |
+
video = torch.cat(video)
|
355 |
+
video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
|
356 |
+
|
357 |
+
if first_frames is not None:
|
358 |
+
first_frames = first_frames.unsqueeze(2)
|
359 |
+
video = torch.cat([first_frames, video], dim=2)
|
360 |
+
|
361 |
+
video = (video / 2 + 0.5).clamp(0, 1)
|
362 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
363 |
+
video = video.cpu().float().numpy()
|
364 |
+
return video
|
365 |
+
|
366 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
367 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
368 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
369 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
370 |
+
# and should be between [0, 1]
|
371 |
+
|
372 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
373 |
+
extra_step_kwargs = {}
|
374 |
+
if accepts_eta:
|
375 |
+
extra_step_kwargs["eta"] = eta
|
376 |
+
|
377 |
+
# check if the scheduler accepts generator
|
378 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
379 |
+
if accepts_generator:
|
380 |
+
extra_step_kwargs["generator"] = generator
|
381 |
+
return extra_step_kwargs
|
382 |
+
|
383 |
+
def check_inputs(self, prompt, height, width, callback_steps, first_frame_paths=None):
|
384 |
+
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
385 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
386 |
+
|
387 |
+
if first_frame_paths is not None and (not isinstance(prompt, str) and not isinstance(first_frame_paths, list)):
|
388 |
+
raise ValueError(f"`first_frame_paths` has to be of type `str` or `list` but is {type(first_frame_paths)}")
|
389 |
+
|
390 |
+
if height % 8 != 0 or width % 8 != 0:
|
391 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
392 |
+
|
393 |
+
if (callback_steps is None) or (
|
394 |
+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
395 |
+
):
|
396 |
+
raise ValueError(
|
397 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
398 |
+
f" {type(callback_steps)}."
|
399 |
+
)
|
400 |
+
|
401 |
+
def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None, noise_sampling_method="vanilla", noise_alpha=1.0):
|
402 |
+
shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
403 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
404 |
+
raise ValueError(
|
405 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
406 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
407 |
+
)
|
408 |
+
if latents is None:
|
409 |
+
rand_device = "cpu" if device.type == "mps" else device
|
410 |
+
|
411 |
+
if isinstance(generator, list):
|
412 |
+
# shape = shape
|
413 |
+
shape = (1,) + shape[1:]
|
414 |
+
if noise_sampling_method == "vanilla":
|
415 |
+
latents = [
|
416 |
+
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
|
417 |
+
for i in range(batch_size)
|
418 |
+
]
|
419 |
+
elif noise_sampling_method == "pyoco_mixed":
|
420 |
+
base_shape = (batch_size, num_channels_latents, 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
421 |
+
latents = []
|
422 |
+
noise_alpha_squared = noise_alpha ** 2
|
423 |
+
for i in range(batch_size):
|
424 |
+
base_latent = torch.randn(base_shape, generator=generator[i], device=rand_device, dtype=dtype) * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared))
|
425 |
+
ind_latent = torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared))
|
426 |
+
latents.append(base_latent + ind_latent)
|
427 |
+
elif noise_sampling_method == "pyoco_progressive":
|
428 |
+
latents = []
|
429 |
+
noise_alpha_squared = noise_alpha ** 2
|
430 |
+
for i in range(batch_size):
|
431 |
+
latent = torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
|
432 |
+
ind_latent = torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared))
|
433 |
+
for j in range(1, video_length):
|
434 |
+
latent[:, :, j, :, :] = latent[:, :, j - 1, :, :] * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared)) + ind_latent[:, :, j, :, :]
|
435 |
+
latents.append(latent)
|
436 |
+
latents = torch.cat(latents, dim=0).to(device)
|
437 |
+
else:
|
438 |
+
if noise_sampling_method == "vanilla":
|
439 |
+
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
440 |
+
elif noise_sampling_method == "pyoco_mixed":
|
441 |
+
noise_alpha_squared = noise_alpha ** 2
|
442 |
+
base_shape = (batch_size, num_channels_latents, 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
443 |
+
base_latents = torch.randn(base_shape, generator=generator, device=rand_device, dtype=dtype) * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared))
|
444 |
+
ind_latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared))
|
445 |
+
latents = base_latents + ind_latents
|
446 |
+
elif noise_sampling_method == "pyoco_progressive":
|
447 |
+
noise_alpha_squared = noise_alpha ** 2
|
448 |
+
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype)
|
449 |
+
ind_latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared))
|
450 |
+
for j in range(1, video_length):
|
451 |
+
latents[:, :, j, :, :] = latents[:, :, j - 1, :, :] * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared)) + ind_latents[:, :, j, :, :]
|
452 |
+
else:
|
453 |
+
if latents.shape != shape:
|
454 |
+
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
455 |
+
latents = latents.to(device)
|
456 |
+
|
457 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
458 |
+
latents = latents * self.scheduler.init_noise_sigma
|
459 |
+
return latents
|
460 |
+
|
461 |
+
@torch.no_grad()
|
462 |
+
def __call__(
|
463 |
+
self,
|
464 |
+
prompt: Union[str, List[str]],
|
465 |
+
video_length: Optional[int],
|
466 |
+
height: Optional[int] = None,
|
467 |
+
width: Optional[int] = None,
|
468 |
+
num_inference_steps: int = 50,
|
469 |
+
guidance_scale_txt: float = 7.5,
|
470 |
+
guidance_scale_img: float = 2.0,
|
471 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
472 |
+
num_videos_per_prompt: Optional[int] = 1,
|
473 |
+
eta: float = 0.0,
|
474 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
475 |
+
latents: Optional[torch.FloatTensor] = None,
|
476 |
+
output_type: Optional[str] = "tensor",
|
477 |
+
return_dict: bool = True,
|
478 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
479 |
+
callback_steps: Optional[int] = 1,
|
480 |
+
# additional
|
481 |
+
first_frame_paths: Optional[Union[str, List[str]]] = None,
|
482 |
+
first_frames: Optional[torch.FloatTensor] = None,
|
483 |
+
noise_sampling_method: str = "vanilla",
|
484 |
+
noise_alpha: float = 1.0,
|
485 |
+
guidance_rescale: float = 0.0,
|
486 |
+
frame_stride: Optional[int] = None,
|
487 |
+
use_frameinit: bool = False,
|
488 |
+
frameinit_noise_level: int = 999,
|
489 |
+
camera_motion: str = None,
|
490 |
+
**kwargs,
|
491 |
+
):
|
492 |
+
if first_frame_paths is not None and first_frames is not None:
|
493 |
+
raise ValueError("Only one of `first_frame_paths` and `first_frames` can be passed.")
|
494 |
+
# Default height and width to unet
|
495 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
496 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
497 |
+
|
498 |
+
# Check inputs. Raise error if not correct
|
499 |
+
self.check_inputs(prompt, height, width, callback_steps, first_frame_paths)
|
500 |
+
|
501 |
+
# Define call parameters
|
502 |
+
# batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
503 |
+
batch_size = 1
|
504 |
+
if latents is not None:
|
505 |
+
batch_size = latents.shape[0]
|
506 |
+
if isinstance(prompt, list):
|
507 |
+
batch_size = len(prompt)
|
508 |
+
first_frame_input = first_frame_paths if first_frame_paths is not None else first_frames
|
509 |
+
if first_frame_input is not None:
|
510 |
+
assert len(prompt) == len(first_frame_input), "prompt and first_frame_paths should have the same length"
|
511 |
+
|
512 |
+
device = self._execution_device
|
513 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
514 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
515 |
+
# corresponds to doing no classifier free guidance.
|
516 |
+
do_classifier_free_guidance = None
|
517 |
+
# two guidance mode: text and text+image
|
518 |
+
if guidance_scale_txt > 1.0:
|
519 |
+
do_classifier_free_guidance = "text"
|
520 |
+
if guidance_scale_img > 1.0:
|
521 |
+
do_classifier_free_guidance = "both"
|
522 |
+
|
523 |
+
# Encode input prompt
|
524 |
+
prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size
|
525 |
+
if negative_prompt is not None:
|
526 |
+
negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size
|
527 |
+
text_embeddings = self._encode_prompt(
|
528 |
+
prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
|
529 |
+
)
|
530 |
+
|
531 |
+
# Encode input first frame
|
532 |
+
first_frame_latents = None
|
533 |
+
if first_frame_paths is not None:
|
534 |
+
first_frame_paths = first_frame_paths if isinstance(first_frame_paths, list) else [first_frame_paths] * batch_size
|
535 |
+
if camera_motion is None:
|
536 |
+
img_transform = T.Compose([
|
537 |
+
T.ToTensor(),
|
538 |
+
T.Resize(height, antialias=None),
|
539 |
+
T.CenterCrop((height, width)),
|
540 |
+
T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
541 |
+
])
|
542 |
+
elif camera_motion == "pan_left" or camera_motion == "pan_right":
|
543 |
+
img_transform = T.Compose([
|
544 |
+
T.ToTensor(),
|
545 |
+
T.Resize(height, antialias=None),
|
546 |
+
T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
547 |
+
])
|
548 |
+
elif camera_motion == "zoom_out" or camera_motion == "zoom_in":
|
549 |
+
img_transform = T.Compose([
|
550 |
+
T.ToTensor(),
|
551 |
+
T.Resize(height * 2, antialias=None),
|
552 |
+
T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
553 |
+
])
|
554 |
+
|
555 |
+
first_frames = []
|
556 |
+
for first_frame_path in first_frame_paths:
|
557 |
+
first_frame = Image.open(first_frame_path).convert('RGB')
|
558 |
+
first_frame = img_transform(first_frame)
|
559 |
+
if camera_motion is not None:
|
560 |
+
if camera_motion == "pan_left":
|
561 |
+
first_frame = pan_left(first_frame, num_frames=video_length, crop_width=width)
|
562 |
+
elif camera_motion == "pan_right":
|
563 |
+
first_frame = pan_right(first_frame, num_frames=video_length, crop_width=width)
|
564 |
+
elif camera_motion == "zoom_in":
|
565 |
+
first_frame = zoom_in(first_frame, num_frames=video_length, crop_width=width)
|
566 |
+
elif camera_motion == "zoom_out":
|
567 |
+
first_frame = zoom_out(first_frame, num_frames=video_length, crop_width=width)
|
568 |
+
else:
|
569 |
+
raise NotImplementedError(f"camera_motion: {camera_motion} is not implemented.")
|
570 |
+
first_frames.append(first_frame.unsqueeze(0))
|
571 |
+
first_frames = torch.cat(first_frames, dim=0)
|
572 |
+
if first_frames is not None:
|
573 |
+
first_frames = first_frames.to(device, dtype=self.vae.dtype)
|
574 |
+
if camera_motion is not None:
|
575 |
+
first_frames = rearrange(first_frames, "b f c h w -> (b f) c h w")
|
576 |
+
first_frame_latents = self.vae.encode(first_frames).latent_dist
|
577 |
+
first_frame_latents = first_frame_latents.sample()
|
578 |
+
first_frame_latents = first_frame_latents * self.vae.config.scaling_factor # b, c, h, w
|
579 |
+
first_frame_static_vid = rearrange(first_frame_latents, "(b f) c h w -> b c f h w", f=video_length if camera_motion is not None else 1)
|
580 |
+
first_frame_latents = first_frame_static_vid[:, :, 0, :, :]
|
581 |
+
first_frame_latents = repeat(first_frame_latents, "b c h w -> (b n) c h w", n=num_videos_per_prompt)
|
582 |
+
first_frames = repeat(first_frames, "b c h w -> (b n) c h w", n=num_videos_per_prompt)
|
583 |
+
|
584 |
+
if use_frameinit and camera_motion is None:
|
585 |
+
first_frame_static_vid = repeat(first_frame_static_vid, "b c 1 h w -> b c t h w", t=video_length)
|
586 |
+
|
587 |
+
# self._progress_bar_config = {}
|
588 |
+
# vid = self.decode_latents(first_frame_static_vid)
|
589 |
+
# vid = torch.from_numpy(vid)
|
590 |
+
# from ..utils.util import save_videos_grid
|
591 |
+
# save_videos_grid(vid, "samples/debug/camera_motion/first_frame_static_vid.mp4", fps=8)
|
592 |
+
|
593 |
+
# Prepare timesteps
|
594 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
595 |
+
timesteps = self.scheduler.timesteps
|
596 |
+
|
597 |
+
# Prepare latent variables
|
598 |
+
num_channels_latents = self.unet.config.in_channels
|
599 |
+
latents = self.prepare_latents(
|
600 |
+
batch_size * num_videos_per_prompt,
|
601 |
+
num_channels_latents,
|
602 |
+
video_length,
|
603 |
+
height,
|
604 |
+
width,
|
605 |
+
text_embeddings.dtype,
|
606 |
+
device,
|
607 |
+
generator,
|
608 |
+
latents,
|
609 |
+
noise_sampling_method,
|
610 |
+
noise_alpha,
|
611 |
+
)
|
612 |
+
latents_dtype = latents.dtype
|
613 |
+
|
614 |
+
if use_frameinit:
|
615 |
+
current_diffuse_timestep = frameinit_noise_level # diffuse to t noise level
|
616 |
+
diffuse_timesteps = torch.full((batch_size,),int(current_diffuse_timestep))
|
617 |
+
diffuse_timesteps = diffuse_timesteps.long()
|
618 |
+
z_T = self.scheduler.add_noise(
|
619 |
+
original_samples=first_frame_static_vid.to(device),
|
620 |
+
noise=latents.to(device),
|
621 |
+
timesteps=diffuse_timesteps.to(device)
|
622 |
+
)
|
623 |
+
latents = freq_mix_3d(z_T.to(dtype=torch.float32), latents.to(dtype=torch.float32), LPF=self.freq_filter)
|
624 |
+
latents = latents.to(dtype=latents_dtype)
|
625 |
+
|
626 |
+
if first_frame_latents is not None:
|
627 |
+
first_frame_noisy_latent = latents[:, :, 0, :, :]
|
628 |
+
latents = latents[:, :, 1:, :, :]
|
629 |
+
|
630 |
+
# Prepare extra step kwargs.
|
631 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
632 |
+
|
633 |
+
# Denoising loop
|
634 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
635 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
636 |
+
for i, t in enumerate(timesteps):
|
637 |
+
# expand the latents if we are doing classifier free guidance
|
638 |
+
if do_classifier_free_guidance is None:
|
639 |
+
latent_model_input = latents
|
640 |
+
elif do_classifier_free_guidance == "text":
|
641 |
+
latent_model_input = torch.cat([latents] * 2)
|
642 |
+
elif do_classifier_free_guidance == "both":
|
643 |
+
latent_model_input = torch.cat([latents] * 3)
|
644 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
645 |
+
if first_frame_latents is not None:
|
646 |
+
if do_classifier_free_guidance is None:
|
647 |
+
first_frame_latents_input = first_frame_latents
|
648 |
+
elif do_classifier_free_guidance == "text":
|
649 |
+
first_frame_latents_input = torch.cat([first_frame_latents] * 2)
|
650 |
+
elif do_classifier_free_guidance == "both":
|
651 |
+
first_frame_latents_input = torch.cat([first_frame_noisy_latent, first_frame_latents, first_frame_latents])
|
652 |
+
|
653 |
+
first_frame_latents_input = first_frame_latents_input.unsqueeze(2)
|
654 |
+
|
655 |
+
# predict the noise residual
|
656 |
+
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, first_frame_latents=first_frame_latents_input, frame_stride=frame_stride).sample.to(dtype=latents_dtype)
|
657 |
+
else:
|
658 |
+
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample.to(dtype=latents_dtype)
|
659 |
+
|
660 |
+
# perform guidance
|
661 |
+
if do_classifier_free_guidance:
|
662 |
+
if do_classifier_free_guidance == "text":
|
663 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
664 |
+
noise_pred = noise_pred_uncond + guidance_scale_txt * (noise_pred_text - noise_pred_uncond)
|
665 |
+
elif do_classifier_free_guidance == "both":
|
666 |
+
noise_pred_uncond, noise_pred_img, noise_pred_both = noise_pred.chunk(3)
|
667 |
+
noise_pred = noise_pred_uncond + guidance_scale_img * (noise_pred_img - noise_pred_uncond) + guidance_scale_txt * (noise_pred_both - noise_pred_img)
|
668 |
+
|
669 |
+
if do_classifier_free_guidance and guidance_rescale > 0.0:
|
670 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
671 |
+
# currently only support text guidance
|
672 |
+
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
673 |
+
|
674 |
+
# compute the previous noisy sample x_t -> x_t-1
|
675 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
676 |
+
|
677 |
+
# call the callback, if provided
|
678 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
679 |
+
progress_bar.update()
|
680 |
+
if callback is not None and i % callback_steps == 0:
|
681 |
+
callback(i, t, latents)
|
682 |
+
|
683 |
+
# Post-processing
|
684 |
+
latents = torch.cat([first_frame_latents.unsqueeze(2), latents], dim=2)
|
685 |
+
# video = self.decode_latents(latents, first_frames)
|
686 |
+
video = self.decode_latents(latents)
|
687 |
+
|
688 |
+
# Convert to tensor
|
689 |
+
if output_type == "tensor":
|
690 |
+
video = torch.from_numpy(video)
|
691 |
+
|
692 |
+
if not return_dict:
|
693 |
+
return video
|
694 |
+
|
695 |
+
return AnimationPipelineOutput(videos=video)
|
src/videogen_hub/pipelines/consisti2v/consisti2v/utils/__init__.py
ADDED
File without changes
|
src/videogen_hub/pipelines/consisti2v/consisti2v/utils/frameinit_utils.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/TianxingWu/FreeInit/blob/master/freeinit_utils.py
|
2 |
+
import torch
|
3 |
+
import torch.fft as fft
|
4 |
+
import math
|
5 |
+
|
6 |
+
|
7 |
+
def freq_mix_3d(x, noise, LPF):
|
8 |
+
"""
|
9 |
+
Noise reinitialization.
|
10 |
+
|
11 |
+
Args:
|
12 |
+
x: diffused latent
|
13 |
+
noise: randomly sampled noise
|
14 |
+
LPF: low pass filter
|
15 |
+
"""
|
16 |
+
# FFT
|
17 |
+
x_freq = fft.fftn(x, dim=(-3, -2, -1))
|
18 |
+
x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1))
|
19 |
+
noise_freq = fft.fftn(noise, dim=(-3, -2, -1))
|
20 |
+
noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1))
|
21 |
+
|
22 |
+
# frequency mix
|
23 |
+
HPF = 1 - LPF
|
24 |
+
x_freq_low = x_freq * LPF
|
25 |
+
noise_freq_high = noise_freq * HPF
|
26 |
+
x_freq_mixed = x_freq_low + noise_freq_high # mix in freq domain
|
27 |
+
|
28 |
+
# IFFT
|
29 |
+
x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1))
|
30 |
+
x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real
|
31 |
+
|
32 |
+
return x_mixed
|
33 |
+
|
34 |
+
|
35 |
+
def get_freq_filter(shape, device, filter_type, n, d_s, d_t):
|
36 |
+
"""
|
37 |
+
Form the frequency filter for noise reinitialization.
|
38 |
+
|
39 |
+
Args:
|
40 |
+
shape: shape of latent (B, C, T, H, W)
|
41 |
+
filter_type: type of the freq filter
|
42 |
+
n: (only for butterworth) order of the filter, larger n ~ ideal, smaller n ~ gaussian
|
43 |
+
d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
|
44 |
+
d_t: normalized stop frequency for temporal dimension (0.0-1.0)
|
45 |
+
"""
|
46 |
+
if filter_type == "gaussian":
|
47 |
+
return gaussian_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device)
|
48 |
+
elif filter_type == "ideal":
|
49 |
+
return ideal_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device)
|
50 |
+
elif filter_type == "box":
|
51 |
+
return box_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device)
|
52 |
+
elif filter_type == "butterworth":
|
53 |
+
return butterworth_low_pass_filter(shape=shape, n=n, d_s=d_s, d_t=d_t).to(device)
|
54 |
+
else:
|
55 |
+
raise NotImplementedError
|
56 |
+
|
57 |
+
|
58 |
+
def gaussian_low_pass_filter(shape, d_s=0.25, d_t=0.25):
|
59 |
+
"""
|
60 |
+
Compute the gaussian low pass filter mask.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
shape: shape of the filter (volume)
|
64 |
+
d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
|
65 |
+
d_t: normalized stop frequency for temporal dimension (0.0-1.0)
|
66 |
+
"""
|
67 |
+
T, H, W = shape[-3], shape[-2], shape[-1]
|
68 |
+
mask = torch.zeros(shape)
|
69 |
+
if d_s==0 or d_t==0:
|
70 |
+
return mask
|
71 |
+
for t in range(T):
|
72 |
+
for h in range(H):
|
73 |
+
for w in range(W):
|
74 |
+
d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2)
|
75 |
+
mask[..., t,h,w] = math.exp(-1/(2*d_s**2) * d_square)
|
76 |
+
return mask
|
77 |
+
|
78 |
+
|
79 |
+
def butterworth_low_pass_filter(shape, n=4, d_s=0.25, d_t=0.25):
|
80 |
+
"""
|
81 |
+
Compute the butterworth low pass filter mask.
|
82 |
+
|
83 |
+
Args:
|
84 |
+
shape: shape of the filter (volume)
|
85 |
+
n: order of the filter, larger n ~ ideal, smaller n ~ gaussian
|
86 |
+
d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
|
87 |
+
d_t: normalized stop frequency for temporal dimension (0.0-1.0)
|
88 |
+
"""
|
89 |
+
T, H, W = shape[-3], shape[-2], shape[-1]
|
90 |
+
mask = torch.zeros(shape)
|
91 |
+
if d_s==0 or d_t==0:
|
92 |
+
return mask
|
93 |
+
for t in range(T):
|
94 |
+
for h in range(H):
|
95 |
+
for w in range(W):
|
96 |
+
d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2)
|
97 |
+
mask[..., t,h,w] = 1 / (1 + (d_square / d_s**2)**n)
|
98 |
+
return mask
|
99 |
+
|
100 |
+
|
101 |
+
def ideal_low_pass_filter(shape, d_s=0.25, d_t=0.25):
|
102 |
+
"""
|
103 |
+
Compute the ideal low pass filter mask.
|
104 |
+
|
105 |
+
Args:
|
106 |
+
shape: shape of the filter (volume)
|
107 |
+
d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
|
108 |
+
d_t: normalized stop frequency for temporal dimension (0.0-1.0)
|
109 |
+
"""
|
110 |
+
T, H, W = shape[-3], shape[-2], shape[-1]
|
111 |
+
mask = torch.zeros(shape)
|
112 |
+
if d_s==0 or d_t==0:
|
113 |
+
return mask
|
114 |
+
for t in range(T):
|
115 |
+
for h in range(H):
|
116 |
+
for w in range(W):
|
117 |
+
d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2)
|
118 |
+
mask[..., t,h,w] = 1 if d_square <= d_s*2 else 0
|
119 |
+
return mask
|
120 |
+
|
121 |
+
|
122 |
+
def box_low_pass_filter(shape, d_s=0.25, d_t=0.25):
|
123 |
+
"""
|
124 |
+
Compute the ideal low pass filter mask (approximated version).
|
125 |
+
|
126 |
+
Args:
|
127 |
+
shape: shape of the filter (volume)
|
128 |
+
d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
|
129 |
+
d_t: normalized stop frequency for temporal dimension (0.0-1.0)
|
130 |
+
"""
|
131 |
+
T, H, W = shape[-3], shape[-2], shape[-1]
|
132 |
+
mask = torch.zeros(shape)
|
133 |
+
if d_s==0 or d_t==0:
|
134 |
+
return mask
|
135 |
+
|
136 |
+
threshold_s = round(int(H // 2) * d_s)
|
137 |
+
threshold_t = round(T // 2 * d_t)
|
138 |
+
|
139 |
+
cframe, crow, ccol = T // 2, H // 2, W //2
|
140 |
+
mask[..., cframe - threshold_t:cframe + threshold_t, crow - threshold_s:crow + threshold_s, ccol - threshold_s:ccol + threshold_s] = 1.0
|
141 |
+
|
142 |
+
return mask
|
src/videogen_hub/pipelines/consisti2v/consisti2v/utils/util.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import imageio
|
3 |
+
import numpy as np
|
4 |
+
from typing import Union
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torchvision
|
8 |
+
import torch.distributed as dist
|
9 |
+
import wandb
|
10 |
+
|
11 |
+
from tqdm import tqdm
|
12 |
+
from einops import rearrange
|
13 |
+
|
14 |
+
from torchmetrics.image.fid import _compute_fid
|
15 |
+
|
16 |
+
|
17 |
+
def zero_rank_print(s):
|
18 |
+
if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0): print("### " + s)
|
19 |
+
|
20 |
+
|
21 |
+
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8, wandb=False, global_step=0, format="gif"):
|
22 |
+
videos = rearrange(videos, "b c t h w -> t b c h w")
|
23 |
+
outputs = []
|
24 |
+
for x in videos:
|
25 |
+
x = torchvision.utils.make_grid(x, nrow=n_rows)
|
26 |
+
x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
|
27 |
+
if rescale:
|
28 |
+
x = (x + 1.0) / 2.0 # -1,1 -> 0,1
|
29 |
+
x = (x * 255).numpy().astype(np.uint8)
|
30 |
+
outputs.append(x)
|
31 |
+
|
32 |
+
if wandb:
|
33 |
+
wandb_video = wandb.Video(outputs, fps=fps)
|
34 |
+
wandb.log({"val_videos": wandb_video}, step=global_step)
|
35 |
+
|
36 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
37 |
+
if format == "gif":
|
38 |
+
imageio.mimsave(path, outputs, fps=fps)
|
39 |
+
elif format == "mp4":
|
40 |
+
torchvision.io.write_video(path, np.array(outputs), fps=fps, video_codec='h264', options={'crf': '10'})
|
41 |
+
|
42 |
+
# DDIM Inversion
|
43 |
+
@torch.no_grad()
|
44 |
+
def init_prompt(prompt, pipeline):
|
45 |
+
uncond_input = pipeline.tokenizer(
|
46 |
+
[""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
|
47 |
+
return_tensors="pt"
|
48 |
+
)
|
49 |
+
uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
|
50 |
+
text_input = pipeline.tokenizer(
|
51 |
+
[prompt],
|
52 |
+
padding="max_length",
|
53 |
+
max_length=pipeline.tokenizer.model_max_length,
|
54 |
+
truncation=True,
|
55 |
+
return_tensors="pt",
|
56 |
+
)
|
57 |
+
text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
|
58 |
+
context = torch.cat([uncond_embeddings, text_embeddings])
|
59 |
+
|
60 |
+
return context
|
61 |
+
|
62 |
+
|
63 |
+
def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
|
64 |
+
sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
|
65 |
+
timestep, next_timestep = min(
|
66 |
+
timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
|
67 |
+
alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
|
68 |
+
alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
|
69 |
+
beta_prod_t = 1 - alpha_prod_t
|
70 |
+
next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
|
71 |
+
next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
|
72 |
+
next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
|
73 |
+
return next_sample
|
74 |
+
|
75 |
+
|
76 |
+
def get_noise_pred_single(latents, t, context, first_frame_latents, frame_stride, unet):
|
77 |
+
noise_pred = unet(latents, t, encoder_hidden_states=context, first_frame_latents=first_frame_latents, frame_stride=frame_stride).sample
|
78 |
+
return noise_pred
|
79 |
+
|
80 |
+
|
81 |
+
@torch.no_grad()
|
82 |
+
def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt, first_frame_latents, frame_stride):
|
83 |
+
context = init_prompt(prompt, pipeline)
|
84 |
+
uncond_embeddings, cond_embeddings = context.chunk(2)
|
85 |
+
all_latent = [latent]
|
86 |
+
latent = latent.clone().detach()
|
87 |
+
for i in tqdm(range(num_inv_steps)):
|
88 |
+
t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
|
89 |
+
noise_pred = get_noise_pred_single(latent, t, cond_embeddings, first_frame_latents, frame_stride, pipeline.unet)
|
90 |
+
latent = next_step(noise_pred, t, latent, ddim_scheduler)
|
91 |
+
all_latent.append(latent)
|
92 |
+
return all_latent
|
93 |
+
|
94 |
+
|
95 |
+
@torch.no_grad()
|
96 |
+
def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt="", first_frame_latents=None, frame_stride=3):
|
97 |
+
ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt, first_frame_latents, frame_stride)
|
98 |
+
return ddim_latents
|
99 |
+
|
100 |
+
|
101 |
+
def compute_fid(real_features, fake_features, num_features, device):
|
102 |
+
orig_dtype = real_features.dtype
|
103 |
+
|
104 |
+
mx_num_feats = (num_features, num_features)
|
105 |
+
real_features_sum = torch.zeros(num_features).double().to(device)
|
106 |
+
real_features_cov_sum = torch.zeros(mx_num_feats).double().to(device)
|
107 |
+
real_features_num_samples = torch.tensor(0).long().to(device)
|
108 |
+
|
109 |
+
fake_features_sum = torch.zeros(num_features).double().to(device)
|
110 |
+
fake_features_cov_sum = torch.zeros(mx_num_feats).double().to(device)
|
111 |
+
fake_features_num_samples = torch.tensor(0).long().to(device)
|
112 |
+
|
113 |
+
real_features = real_features.double()
|
114 |
+
fake_features = fake_features.double()
|
115 |
+
|
116 |
+
real_features_sum += real_features.sum(dim=0)
|
117 |
+
real_features_cov_sum += real_features.t().mm(real_features)
|
118 |
+
real_features_num_samples += real_features.shape[0]
|
119 |
+
|
120 |
+
fake_features_sum += fake_features.sum(dim=0)
|
121 |
+
fake_features_cov_sum += fake_features.t().mm(fake_features)
|
122 |
+
fake_features_num_samples += fake_features.shape[0]
|
123 |
+
|
124 |
+
"""Calculate FID score based on accumulated extracted features from the two distributions."""
|
125 |
+
if real_features_num_samples < 2 or fake_features_num_samples < 2:
|
126 |
+
raise RuntimeError("More than one sample is required for both the real and fake distributed to compute FID")
|
127 |
+
mean_real = (real_features_sum / real_features_num_samples).unsqueeze(0)
|
128 |
+
mean_fake = (fake_features_sum / fake_features_num_samples).unsqueeze(0)
|
129 |
+
|
130 |
+
cov_real_num = real_features_cov_sum - real_features_num_samples * mean_real.t().mm(mean_real)
|
131 |
+
cov_real = cov_real_num / (real_features_num_samples - 1)
|
132 |
+
cov_fake_num = fake_features_cov_sum - fake_features_num_samples * mean_fake.t().mm(mean_fake)
|
133 |
+
cov_fake = cov_fake_num / (fake_features_num_samples - 1)
|
134 |
+
return _compute_fid(mean_real.squeeze(0), cov_real, mean_fake.squeeze(0), cov_fake).to(orig_dtype)
|
135 |
+
|
136 |
+
|
137 |
+
def compute_inception_score(gen_probs, num_splits=10):
|
138 |
+
num_gen = gen_probs.shape[0]
|
139 |
+
gen_probs = gen_probs.detach().cpu().numpy()
|
140 |
+
scores = []
|
141 |
+
np.random.RandomState(42).shuffle(gen_probs)
|
142 |
+
for i in range(num_splits):
|
143 |
+
part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits]
|
144 |
+
kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
|
145 |
+
kl = np.mean(np.sum(kl, axis=1))
|
146 |
+
scores.append(np.exp(kl))
|
147 |
+
return float(np.mean(scores)), float(np.std(scores))
|
148 |
+
# idx = torch.randperm(features.shape[0])
|
149 |
+
# features = features[idx]
|
150 |
+
# # calculate probs and logits
|
151 |
+
# prob = features.softmax(dim=1)
|
152 |
+
# log_prob = features.log_softmax(dim=1)
|
153 |
+
|
154 |
+
# # split into groups
|
155 |
+
# prob = prob.chunk(splits, dim=0)
|
156 |
+
# log_prob = log_prob.chunk(splits, dim=0)
|
157 |
+
|
158 |
+
# # calculate score per split
|
159 |
+
# mean_prob = [p.mean(dim=0, keepdim=True) for p in prob]
|
160 |
+
# kl_ = [p * (log_p - m_p.log()) for p, log_p, m_p in zip(prob, log_prob, mean_prob)]
|
161 |
+
# kl_ = [k.sum(dim=1).mean().exp() for k in kl_]
|
162 |
+
# kl = torch.stack(kl_)
|
163 |
+
|
164 |
+
# return mean and std
|
165 |
+
# return kl.mean(), kl.std()
|
src/videogen_hub/pipelines/consisti2v/scripts/__init__.py
ADDED
File without changes
|
src/videogen_hub/pipelines/consisti2v/scripts/animate.py
ADDED
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import datetime
|
3 |
+
import random
|
4 |
+
import os
|
5 |
+
import logging
|
6 |
+
from omegaconf import OmegaConf
|
7 |
+
|
8 |
+
import torch
|
9 |
+
|
10 |
+
import diffusers
|
11 |
+
from diffusers import AutoencoderKL, DDIMScheduler
|
12 |
+
|
13 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
14 |
+
|
15 |
+
from consisti2v.models.videoldm_unet import VideoLDMUNet3DConditionModel
|
16 |
+
from consisti2v.pipelines.pipeline_conditional_animation import (
|
17 |
+
ConditionalAnimationPipeline,
|
18 |
+
)
|
19 |
+
from consisti2v.utils.util import save_videos_grid
|
20 |
+
from diffusers.utils.import_utils import is_xformers_available
|
21 |
+
|
22 |
+
|
23 |
+
def main(args, config):
|
24 |
+
logging.basicConfig(
|
25 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
26 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
27 |
+
level=logging.INFO,
|
28 |
+
)
|
29 |
+
diffusers.utils.logging.set_verbosity_info()
|
30 |
+
|
31 |
+
time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
|
32 |
+
savedir = f"{config.output_dir}/{config.output_name}-{time_str}"
|
33 |
+
os.makedirs(savedir)
|
34 |
+
|
35 |
+
samples = []
|
36 |
+
sample_idx = 0
|
37 |
+
|
38 |
+
### >>> create validation pipeline >>> ###
|
39 |
+
if config.pipeline_pretrained_path is None:
|
40 |
+
noise_scheduler = DDIMScheduler(
|
41 |
+
**OmegaConf.to_container(config.noise_scheduler_kwargs)
|
42 |
+
)
|
43 |
+
tokenizer = CLIPTokenizer.from_pretrained(
|
44 |
+
config.pretrained_model_path, subfolder="tokenizer", use_safetensors=True
|
45 |
+
)
|
46 |
+
text_encoder = CLIPTextModel.from_pretrained(
|
47 |
+
config.pretrained_model_path, subfolder="text_encoder"
|
48 |
+
)
|
49 |
+
vae = AutoencoderKL.from_pretrained(
|
50 |
+
config.pretrained_model_path, subfolder="vae", use_safetensors=True
|
51 |
+
)
|
52 |
+
unet = VideoLDMUNet3DConditionModel.from_pretrained(
|
53 |
+
config.pretrained_model_path,
|
54 |
+
subfolder="unet",
|
55 |
+
variant=config.unet_additional_kwargs["variant"],
|
56 |
+
temp_pos_embedding=config.unet_additional_kwargs["temp_pos_embedding"],
|
57 |
+
augment_temporal_attention=config.unet_additional_kwargs[
|
58 |
+
"augment_temporal_attention"
|
59 |
+
],
|
60 |
+
use_temporal=True,
|
61 |
+
n_frames=config.sampling_kwargs["n_frames"],
|
62 |
+
n_temp_heads=config.unet_additional_kwargs["n_temp_heads"],
|
63 |
+
first_frame_condition_mode=config.unet_additional_kwargs[
|
64 |
+
"first_frame_condition_mode"
|
65 |
+
],
|
66 |
+
use_frame_stride_condition=config.unet_additional_kwargs[
|
67 |
+
"use_frame_stride_condition"
|
68 |
+
],
|
69 |
+
use_safetensors=True,
|
70 |
+
)
|
71 |
+
|
72 |
+
# 1. unet ckpt
|
73 |
+
if config.unet_path is not None:
|
74 |
+
if os.path.isdir(config.unet_path):
|
75 |
+
unet_dict = VideoLDMUNet3DConditionModel.from_pretrained(
|
76 |
+
config.unet_path
|
77 |
+
)
|
78 |
+
m, u = unet.load_state_dict(unet_dict.state_dict(), strict=False)
|
79 |
+
assert len(u) == 0
|
80 |
+
del unet_dict
|
81 |
+
else:
|
82 |
+
checkpoint_dict = torch.load(config.unet_path, map_location="cpu")
|
83 |
+
state_dict = (
|
84 |
+
checkpoint_dict["state_dict"]
|
85 |
+
if "state_dict" in checkpoint_dict
|
86 |
+
else checkpoint_dict
|
87 |
+
)
|
88 |
+
if config.unet_ckpt_prefix is not None:
|
89 |
+
state_dict = {
|
90 |
+
k.replace(config.unet_ckpt_prefix, ""): v
|
91 |
+
for k, v in state_dict.items()
|
92 |
+
}
|
93 |
+
m, u = unet.load_state_dict(state_dict, strict=False)
|
94 |
+
assert len(u) == 0
|
95 |
+
|
96 |
+
if is_xformers_available() and int(torch.__version__.split(".")[0]) < 2:
|
97 |
+
unet.enable_xformers_memory_efficient_attention()
|
98 |
+
|
99 |
+
pipeline = ConditionalAnimationPipeline(
|
100 |
+
vae=vae,
|
101 |
+
text_encoder=text_encoder,
|
102 |
+
tokenizer=tokenizer,
|
103 |
+
unet=unet,
|
104 |
+
scheduler=noise_scheduler,
|
105 |
+
)
|
106 |
+
|
107 |
+
else:
|
108 |
+
pipeline = ConditionalAnimationPipeline.from_pretrained(
|
109 |
+
config.pipeline_pretrained_path
|
110 |
+
)
|
111 |
+
|
112 |
+
pipeline.to("cuda")
|
113 |
+
|
114 |
+
# (frameinit) initialize frequency filter for noise reinitialization -------------
|
115 |
+
if config.frameinit_kwargs.enable:
|
116 |
+
pipeline.init_filter(
|
117 |
+
width=config.sampling_kwargs.width,
|
118 |
+
height=config.sampling_kwargs.height,
|
119 |
+
video_length=config.sampling_kwargs.n_frames,
|
120 |
+
filter_params=config.frameinit_kwargs.filter_params,
|
121 |
+
)
|
122 |
+
# -------------------------------------------------------------------------------
|
123 |
+
### <<< create validation pipeline <<< ###
|
124 |
+
|
125 |
+
if args.prompt is not None:
|
126 |
+
prompts = [args.prompt]
|
127 |
+
n_prompts = [args.n_prompt]
|
128 |
+
first_frame_paths = [args.path_to_first_frame]
|
129 |
+
random_seeds = [int(args.seed)] if args.seed != "random" else "random"
|
130 |
+
else:
|
131 |
+
prompt_config = OmegaConf.load(args.prompt_config)
|
132 |
+
prompts = prompt_config.prompts
|
133 |
+
n_prompts = (
|
134 |
+
list(prompt_config.n_prompts) * len(prompts)
|
135 |
+
if len(prompt_config.n_prompts) == 1
|
136 |
+
else prompt_config.n_prompts
|
137 |
+
)
|
138 |
+
first_frame_paths = prompt_config.path_to_first_frames
|
139 |
+
random_seeds = prompt_config.seeds
|
140 |
+
|
141 |
+
if random_seeds == "random":
|
142 |
+
random_seeds = [random.randint(0, 1e5) for _ in range(len(prompts))]
|
143 |
+
else:
|
144 |
+
random_seeds = (
|
145 |
+
[random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
|
146 |
+
)
|
147 |
+
random_seeds = (
|
148 |
+
random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds
|
149 |
+
)
|
150 |
+
|
151 |
+
config.prompt_kwargs = OmegaConf.create(
|
152 |
+
{
|
153 |
+
"random_seeds": [],
|
154 |
+
"prompts": prompts,
|
155 |
+
"n_prompts": n_prompts,
|
156 |
+
"first_frame_paths": first_frame_paths,
|
157 |
+
}
|
158 |
+
)
|
159 |
+
for prompt_idx, (prompt, n_prompt, first_frame_path, random_seed) in enumerate(
|
160 |
+
zip(prompts, n_prompts, first_frame_paths, random_seeds)
|
161 |
+
):
|
162 |
+
# manually set random seed for reproduction
|
163 |
+
if random_seed != -1:
|
164 |
+
torch.manual_seed(random_seed)
|
165 |
+
else:
|
166 |
+
torch.seed()
|
167 |
+
config.prompt_kwargs.random_seeds.append(torch.initial_seed())
|
168 |
+
|
169 |
+
print(f"current seed: {torch.initial_seed()}")
|
170 |
+
print(f"sampling {prompt} ...")
|
171 |
+
sample = pipeline(
|
172 |
+
prompt,
|
173 |
+
negative_prompt=n_prompt,
|
174 |
+
first_frame_paths=first_frame_path,
|
175 |
+
num_inference_steps=config.sampling_kwargs.steps,
|
176 |
+
guidance_scale_txt=config.sampling_kwargs.guidance_scale_txt,
|
177 |
+
guidance_scale_img=config.sampling_kwargs.guidance_scale_img,
|
178 |
+
width=config.sampling_kwargs.width,
|
179 |
+
height=config.sampling_kwargs.height,
|
180 |
+
video_length=config.sampling_kwargs.n_frames,
|
181 |
+
noise_sampling_method=config.unet_additional_kwargs[
|
182 |
+
"noise_sampling_method"
|
183 |
+
],
|
184 |
+
noise_alpha=float(config.unet_additional_kwargs["noise_alpha"]),
|
185 |
+
eta=config.sampling_kwargs.ddim_eta,
|
186 |
+
frame_stride=config.sampling_kwargs.frame_stride,
|
187 |
+
guidance_rescale=config.sampling_kwargs.guidance_rescale,
|
188 |
+
num_videos_per_prompt=config.sampling_kwargs.num_videos_per_prompt,
|
189 |
+
use_frameinit=config.frameinit_kwargs.enable,
|
190 |
+
frameinit_noise_level=config.frameinit_kwargs.noise_level,
|
191 |
+
camera_motion=config.frameinit_kwargs.camera_motion,
|
192 |
+
).videos
|
193 |
+
samples.append(sample)
|
194 |
+
|
195 |
+
prompt = "-".join((prompt.replace("/", "").split(" ")[:10])).replace(":", "")
|
196 |
+
if sample.shape[0] > 1:
|
197 |
+
for cnt, samp in enumerate(sample):
|
198 |
+
save_videos_grid(
|
199 |
+
samp.unsqueeze(0),
|
200 |
+
f"{savedir}/sample/{sample_idx}-{cnt + 1}-{prompt}.{args.format}",
|
201 |
+
format=args.format,
|
202 |
+
)
|
203 |
+
else:
|
204 |
+
save_videos_grid(
|
205 |
+
sample,
|
206 |
+
f"{savedir}/sample/{sample_idx}-{prompt}.{args.format}",
|
207 |
+
format=args.format,
|
208 |
+
)
|
209 |
+
print(f"save to {savedir}/sample/{prompt}.{args.format}")
|
210 |
+
|
211 |
+
sample_idx += 1
|
212 |
+
|
213 |
+
samples = torch.concat(samples)
|
214 |
+
# save_videos_grid(samples, f"{savedir}/sample.{args.format}", n_rows=4, format=args.format)
|
215 |
+
|
216 |
+
# OmegaConf.save(config, f"{savedir}/config.yaml")
|
217 |
+
|
218 |
+
# if args.save_model:
|
219 |
+
# pipeline.save_pretrained(f"{savedir}/model")
|
220 |
+
|
221 |
+
return samples
|
222 |
+
|
223 |
+
|
224 |
+
if __name__ == "__main__":
|
225 |
+
parser = argparse.ArgumentParser()
|
226 |
+
parser.add_argument(
|
227 |
+
"--inference_config", type=str, default="configs/inference/inference.yaml"
|
228 |
+
)
|
229 |
+
parser.add_argument("--prompt", "-p", type=str, default=None)
|
230 |
+
parser.add_argument("--n_prompt", "-n", type=str, default="")
|
231 |
+
parser.add_argument("--seed", type=str, default="random")
|
232 |
+
parser.add_argument("--path_to_first_frame", "-f", type=str, default=None)
|
233 |
+
parser.add_argument(
|
234 |
+
"--prompt_config", type=str, default="configs/prompts/default.yaml"
|
235 |
+
)
|
236 |
+
parser.add_argument("--format", type=str, default="mp4", choices=["gif", "mp4"])
|
237 |
+
parser.add_argument("--save_model", action="store_true")
|
238 |
+
parser.add_argument("optional_args", nargs="*", default=[])
|
239 |
+
args = parser.parse_args()
|
240 |
+
|
241 |
+
config = OmegaConf.load(args.inference_config)
|
242 |
+
|
243 |
+
if args.optional_args:
|
244 |
+
modified_config = OmegaConf.from_dotlist(args.optional_args)
|
245 |
+
config = OmegaConf.merge(config, modified_config)
|
246 |
+
|
247 |
+
main(args, config)
|