diff --git a/src/videogen_hub/pipelines/__init__.py b/src/videogen_hub/pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/cogvideo/__init__.py b/src/videogen_hub/pipelines/cogvideo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d02090a45f74c04b55eef1cd0502e642b7ec7efc --- /dev/null +++ b/src/videogen_hub/pipelines/cogvideo/__init__.py @@ -0,0 +1,4 @@ +import sys + +sys.path.insert(0, "./src/videogen_hub/pipelines/cogvideo/") +sys.path.insert(0, "./src/videogen_hub/pipelines/cogvideo/cogvideo_src") diff --git a/src/videogen_hub/pipelines/cogvideo/cogvideo_pipeline.py b/src/videogen_hub/pipelines/cogvideo/cogvideo_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..fc2e7c42197c0824a602a11b3e664280bc1a8d40 --- /dev/null +++ b/src/videogen_hub/pipelines/cogvideo/cogvideo_pipeline.py @@ -0,0 +1,612 @@ +from videogen_hub.pipelines.cogvideo.cogvideo_src.cogvideo_pipeline import ( + InferenceModel_Interpolate, + InferenceModel_Sequential, + my_filling_sequence, + get_masks_and_position_ids_stage1, + get_masks_and_position_ids_stage2, + my_save_multiple_images, +) +from videogen_hub.depend.icetk import icetk as tokenizer +from videogen_hub.pipelines.cogvideo.cogvideo_src.coglm_strategy import ( + CoglmStrategy, +) +from videogen_hub.pipelines.cogvideo.cogvideo_src.sr_pipeline import ( + DirectSuperResolution, +) +from SwissArmyTransformer.resources import auto_create +import time, logging, sys, os, torch +import torch.distributed as dist + +# path = os.path.join(args.output_path, f"{now_qi}_{raw_text}") + + +def pipeline(args, raw_text, height, width, duration): + # model_stage1, args = InferenceModel_Sequential.from_pretrained(args, 'cogvideo-stage1') + # model_stage1.eval() + # parent_givan_tokens = process_stage1(model_stage1, raw_text, duration=4.0, video_raw_text=raw_text, video_guidance_text="视频", + # image_text_suffix=" 高清摄影", + # outputdir=None, batch_size=args.batch_size) + + # process_stage2(model_stage2, raw_text, duration=2.0, video_raw_text=raw_text+" 视频", + # video_guidance_text="视频", parent_given_tokens=parent_given_tokens, + # outputdir=path, + # gpu_rank=0, gpu_parallel_size=1) # TODO: 修改 + + assert int(args.stage_1) + int(args.stage_2) + int(args.both_stages) == 1 + rank_id = args.device % args.parallel_size + generate_frame_num = args.generate_frame_num + + if args.stage_1 or args.both_stages: + model_stage1, args = InferenceModel_Sequential.from_pretrained( + args, "cogvideo-stage1" + ) + model_stage1.eval() + if args.both_stages: + model_stage1 = model_stage1.cpu() + + if args.stage_2 or args.both_stages: + model_stage2, args = InferenceModel_Interpolate.from_pretrained( + args, "cogvideo-stage2" + ) + model_stage2.eval() + if args.both_stages: + model_stage2 = model_stage2.cpu() + + invalid_slices = [slice(tokenizer.num_image_tokens, None)] + strategy_cogview2 = CoglmStrategy(invalid_slices, temperature=1.0, top_k=16) + strategy_cogvideo = CoglmStrategy( + invalid_slices, + temperature=args.temperature, + top_k=args.top_k, + temperature2=args.coglm_temperature2, + ) + if not args.stage_1: + # from sr_pipeline import DirectSuperResolution + dsr_path = auto_create( + "cogview2-dsr", path=None + ) # path=os.getenv('SAT_HOME', '~/.sat_models') + dsr = DirectSuperResolution(args, dsr_path, max_bz=12, onCUDA=False) + + def process_stage2( + model, + seq_text, + duration, + video_raw_text=None, + video_guidance_text="视频", + parent_given_tokens=None, + conddir=None, + outputdir=None, + gpu_rank=0, + gpu_parallel_size=1, + ): + stage2_starttime = time.time() + use_guidance = args.use_guidance_stage2 + if args.both_stages: + move_start_time = time.time() + logging.debug("moving stage-2 model to cuda") + model = model.cuda() + logging.debug( + "moving in stage-2 model takes time: {:.2f}".format( + time.time() - move_start_time + ) + ) + + try: + if parent_given_tokens is None: + assert conddir is not None + parent_given_tokens = torch.load( + os.path.join(conddir, "frame_tokens.pt"), map_location="cpu" + ) + sample_num_allgpu = parent_given_tokens.shape[0] + sample_num = sample_num_allgpu // gpu_parallel_size + assert sample_num * gpu_parallel_size == sample_num_allgpu + parent_given_tokens = parent_given_tokens[ + gpu_rank * sample_num : (gpu_rank + 1) * sample_num + ] + except: + logging.critical("No frame_tokens found in interpolation, skip") + return False + + # CogVideo Stage2 Generation + while ( + duration >= 0.5 + ): # TODO: You can change the boundary to change the frame rate + parent_given_tokens_num = parent_given_tokens.shape[1] + generate_batchsize_persample = (parent_given_tokens_num - 1) // 2 + generate_batchsize_total = generate_batchsize_persample * sample_num + total_frames = generate_frame_num + frame_len = 400 + enc_text = tokenizer.encode(seq_text) + enc_duration = tokenizer.encode(str(float(duration)) + "秒") + seq = ( + enc_duration + + [tokenizer[""]] + + enc_text + + [tokenizer[""]] + + [-1] * 400 * generate_frame_num + ) + text_len = len(seq) - frame_len * generate_frame_num - 1 + + logging.info( + "[Stage2: Generating Frames, Frame Rate {:d}]\nraw text: {:s}".format( + int(4 / duration), tokenizer.decode(enc_text) + ) + ) + + # generation + seq = ( + torch.cuda.LongTensor(seq, device=args.device) + .unsqueeze(0) + .repeat(generate_batchsize_total, 1) + ) + for sample_i in range(sample_num): + for i in range(generate_batchsize_persample): + seq[sample_i * generate_batchsize_persample + i][ + text_len + 1 : text_len + 1 + 400 + ] = parent_given_tokens[sample_i][2 * i] + seq[sample_i * generate_batchsize_persample + i][ + text_len + 1 + 400 : text_len + 1 + 800 + ] = parent_given_tokens[sample_i][2 * i + 1] + seq[sample_i * generate_batchsize_persample + i][ + text_len + 1 + 800 : text_len + 1 + 1200 + ] = parent_given_tokens[sample_i][2 * i + 2] + + if use_guidance: + guider_seq = ( + enc_duration + + [tokenizer[""]] + + tokenizer.encode(video_guidance_text) + + [tokenizer[""]] + + [-1] * 400 * generate_frame_num + ) + guider_text_len = len(guider_seq) - frame_len * generate_frame_num - 1 + guider_seq = ( + torch.cuda.LongTensor(guider_seq, device=args.device) + .unsqueeze(0) + .repeat(generate_batchsize_total, 1) + ) + for sample_i in range(sample_num): + for i in range(generate_batchsize_persample): + guider_seq[sample_i * generate_batchsize_persample + i][ + text_len + 1 : text_len + 1 + 400 + ] = parent_given_tokens[sample_i][2 * i] + guider_seq[sample_i * generate_batchsize_persample + i][ + text_len + 1 + 400 : text_len + 1 + 800 + ] = parent_given_tokens[sample_i][2 * i + 1] + guider_seq[sample_i * generate_batchsize_persample + i][ + text_len + 1 + 800 : text_len + 1 + 1200 + ] = parent_given_tokens[sample_i][2 * i + 2] + video_log_text_attention_weights = 0 + else: + guider_seq = None + guider_text_len = 0 + video_log_text_attention_weights = 1.4 + + mbz = args.max_inference_batch_size + + assert generate_batchsize_total < mbz or generate_batchsize_total % mbz == 0 + output_list = [] + start_time = time.time() + for tim in range(max(generate_batchsize_total // mbz, 1)): + input_seq = ( + seq[: min(generate_batchsize_total, mbz)].clone() + if tim == 0 + else seq[mbz * tim : mbz * (tim + 1)].clone() + ) + guider_seq2 = ( + ( + guider_seq[: min(generate_batchsize_total, mbz)].clone() + if tim == 0 + else guider_seq[mbz * tim : mbz * (tim + 1)].clone() + ) + if guider_seq is not None + else None + ) + output_list.append( + my_filling_sequence( + model, + args, + input_seq, + batch_size=min(generate_batchsize_total, mbz), + get_masks_and_position_ids=get_masks_and_position_ids_stage2, + text_len=text_len, + frame_len=frame_len, + strategy=strategy_cogview2, + strategy2=strategy_cogvideo, + log_text_attention_weights=video_log_text_attention_weights, + mode_stage1=False, + guider_seq=guider_seq2, + guider_text_len=guider_text_len, + guidance_alpha=args.guidance_alpha, + limited_spatial_channel_mem=True, + )[0] + ) + logging.info( + "Duration {:.2f}, Taken time {:.2f}\n".format( + duration, time.time() - start_time + ) + ) + + output_tokens = torch.cat(output_list, dim=0) + output_tokens = output_tokens[ + :, text_len + 1 : text_len + 1 + (total_frames) * 400 + ].reshape(sample_num, -1, 400 * total_frames) + output_tokens_merge = torch.cat( + ( + output_tokens[:, :, : 1 * 400], + output_tokens[:, :, 400 * 3 : 4 * 400], + output_tokens[:, :, 400 * 1 : 2 * 400], + output_tokens[:, :, 400 * 4 : (total_frames) * 400], + ), + dim=2, + ).reshape(sample_num, -1, 400) + + output_tokens_merge = torch.cat( + (output_tokens_merge, output_tokens[:, -1:, 400 * 2 : 3 * 400]), dim=1 + ) + duration /= 2 + parent_given_tokens = output_tokens_merge + + if args.both_stages: + move_start_time = time.time() + logging.debug("moving stage 2 model to cpu") + model = model.cpu() + torch.cuda.empty_cache() + logging.debug( + "moving out model2 takes time: {:.2f}".format( + time.time() - move_start_time + ) + ) + + logging.info( + "CogVideo Stage2 completed. Taken time {:.2f}\n".format( + time.time() - stage2_starttime + ) + ) + + # decoding + # imgs = [torch.nn.functional.interpolate(tokenizer.decode(image_ids=seq.tolist()), size=(480, 480)) for seq in output_tokens_merge] + # os.makedirs(output_dir_full_path, exist_ok=True) + # my_save_multiple_images(imgs, output_dir_full_path,subdir="frames", debug=False) + # torch.save(output_tokens_merge.cpu(), os.path.join(output_dir_full_path, 'frame_token.pt')) + # os.system(f"gifmaker -i '{output_dir_full_path}'/frames/0*.jpg -o '{output_dir_full_path}/{str(float(duration))}_concat.gif' -d 0.2") + + # direct super-resolution by CogView2 + logging.info("[Direct super-resolution]") + dsr_starttime = time.time() + enc_text = tokenizer.encode(seq_text) + frame_num_per_sample = parent_given_tokens.shape[1] + parent_given_tokens_2d = parent_given_tokens.reshape(-1, 400) + text_seq = ( + torch.cuda.LongTensor(enc_text, device=args.device) + .unsqueeze(0) + .repeat(parent_given_tokens_2d.shape[0], 1) + ) + sred_tokens = dsr(text_seq, parent_given_tokens_2d) + decoded_sr_videos = [] + + for sample_i in range(sample_num): + decoded_sr_imgs = [] + for frame_i in range(frame_num_per_sample): + decoded_sr_img = tokenizer.decode( + image_ids=sred_tokens[frame_i + sample_i * frame_num_per_sample][ + -3600: + ] + ) + decoded_sr_imgs.append( + torch.nn.functional.interpolate( + decoded_sr_img, size=(height, width) + ) + ) + decoded_sr_videos.append(decoded_sr_imgs) + + return decoded_sr_videos + # for sample_i in range(sample_num): + # my_save_multiple_images(decoded_sr_videos[sample_i], outputdir,subdir=f"frames/{sample_i+sample_num*gpu_rank}", debug=False) + # os.system(f"gifmaker -i '{outputdir}'/frames/'{sample_i+sample_num*gpu_rank}'/0*.jpg -o '{outputdir}/{sample_i+sample_num*gpu_rank}.gif' -d 0.125") + + # logging.info("Direct super-resolution completed. Taken time {:.2f}\n".format(time.time() - dsr_starttime)) + + # return True + + def process_stage1( + model, + seq_text, + duration, + video_raw_text=None, + video_guidance_text="视频", + image_text_suffix="", + outputdir=None, + batch_size=1, + ): + process_start_time = time.time() + use_guide = args.use_guidance_stage1 + if args.both_stages: + move_start_time = time.time() + logging.debug("moving stage 1 model to cuda") + model = model.cuda() + logging.debug( + "moving in model1 takes time: {:.2f}".format( + time.time() - move_start_time + ) + ) + + if video_raw_text is None: + video_raw_text = seq_text + mbz = ( + args.stage1_max_inference_batch_size + if args.stage1_max_inference_batch_size > 0 + else args.max_inference_batch_size + ) + assert batch_size < mbz or batch_size % mbz == 0 + frame_len = 400 + + # generate the first frame: + enc_text = tokenizer.encode(seq_text + image_text_suffix) + seq_1st = ( + enc_text + [tokenizer[""]] + [-1] * 400 + ) # IV!! # test local!!! # test randboi!!! + logging.info( + "[Generating First Frame with CogView2]Raw text: {:s}".format( + tokenizer.decode(enc_text) + ) + ) + text_len_1st = len(seq_1st) - frame_len * 1 - 1 + + seq_1st = torch.cuda.LongTensor(seq_1st, device=args.device).unsqueeze(0) + output_list_1st = [] + for tim in range(max(batch_size // mbz, 1)): + start_time = time.time() + output_list_1st.append( + my_filling_sequence( + model, + args, + seq_1st.clone(), + batch_size=min(batch_size, mbz), + get_masks_and_position_ids=get_masks_and_position_ids_stage1, + text_len=text_len_1st, + frame_len=frame_len, + strategy=strategy_cogview2, + strategy2=strategy_cogvideo, + log_text_attention_weights=1.4, + enforce_no_swin=True, + mode_stage1=True, + )[0] + ) + logging.info( + "[First Frame]Taken time {:.2f}\n".format(time.time() - start_time) + ) + output_tokens_1st = torch.cat(output_list_1st, dim=0) + given_tokens = output_tokens_1st[ + :, text_len_1st + 1 : text_len_1st + 401 + ].unsqueeze( + 1 + ) # given_tokens.shape: [bs, frame_num, 400] + + # generate subsequent frames: + total_frames = generate_frame_num + enc_duration = tokenizer.encode(str(float(duration)) + "秒") + if use_guide: + video_raw_text = video_raw_text + " 视频" + enc_text_video = tokenizer.encode(video_raw_text) + seq = ( + enc_duration + + [tokenizer[""]] + + enc_text_video + + [tokenizer[""]] + + [-1] * 400 * generate_frame_num + ) + guider_seq = ( + enc_duration + + [tokenizer[""]] + + tokenizer.encode(video_guidance_text) + + [tokenizer[""]] + + [-1] * 400 * generate_frame_num + ) + logging.info( + "[Stage1: Generating Subsequent Frames, Frame Rate {:.1f}]\nraw text: {:s}".format( + 4 / duration, tokenizer.decode(enc_text_video) + ) + ) + + text_len = len(seq) - frame_len * generate_frame_num - 1 + guider_text_len = len(guider_seq) - frame_len * generate_frame_num - 1 + seq = ( + torch.cuda.LongTensor(seq, device=args.device) + .unsqueeze(0) + .repeat(batch_size, 1) + ) + guider_seq = ( + torch.cuda.LongTensor(guider_seq, device=args.device) + .unsqueeze(0) + .repeat(batch_size, 1) + ) + + for given_frame_id in range(given_tokens.shape[1]): + seq[ + :, + text_len + + 1 + + given_frame_id * 400 : text_len + + 1 + + (given_frame_id + 1) * 400, + ] = given_tokens[:, given_frame_id] + guider_seq[ + :, + guider_text_len + + 1 + + given_frame_id * 400 : guider_text_len + + 1 + + (given_frame_id + 1) * 400, + ] = given_tokens[:, given_frame_id] + output_list = [] + + if use_guide: + video_log_text_attention_weights = 0 + else: + guider_seq = None + video_log_text_attention_weights = 1.4 + + for tim in range(max(batch_size // mbz, 1)): + start_time = time.time() + input_seq = ( + seq[: min(batch_size, mbz)].clone() + if tim == 0 + else seq[mbz * tim : mbz * (tim + 1)].clone() + ) + guider_seq2 = ( + ( + guider_seq[: min(batch_size, mbz)].clone() + if tim == 0 + else guider_seq[mbz * tim : mbz * (tim + 1)].clone() + ) + if guider_seq is not None + else None + ) + output_list.append( + my_filling_sequence( + model, + args, + input_seq, + batch_size=min(batch_size, mbz), + get_masks_and_position_ids=get_masks_and_position_ids_stage1, + text_len=text_len, + frame_len=frame_len, + strategy=strategy_cogview2, + strategy2=strategy_cogvideo, + log_text_attention_weights=video_log_text_attention_weights, + guider_seq=guider_seq2, + guider_text_len=guider_text_len, + guidance_alpha=args.guidance_alpha, + limited_spatial_channel_mem=True, + mode_stage1=True, + )[0] + ) + + output_tokens = torch.cat(output_list, dim=0)[:, 1 + text_len :] + + if args.both_stages: + move_start_time = time.time() + logging.debug("moving stage 1 model to cpu") + model = model.cpu() + torch.cuda.empty_cache() + logging.debug( + "moving in model1 takes time: {:.2f}".format( + time.time() - move_start_time + ) + ) + + # decoding + imgs, sred_imgs, txts = [], [], [] + for seq in output_tokens: + decoded_imgs = [ + torch.nn.functional.interpolate( + tokenizer.decode(image_ids=seq.tolist()[i * 400 : (i + 1) * 400]), + size=(height, width), + ) + for i in range(total_frames) + ] + imgs.append(decoded_imgs) # only the last image (target) + + assert len(imgs) == batch_size + return imgs + # save_tokens = output_tokens[:, :+total_frames*400].reshape(-1, total_frames, 400).cpu() + # if outputdir is not None: + # for clip_i in range(len(imgs)): + # # os.makedirs(output_dir_full_paths[clip_i], exist_ok=True) + # my_save_multiple_images(imgs[clip_i], outputdir, subdir=f"frames/{clip_i}", debug=False) + # os.system(f"gifmaker -i '{outputdir}'/frames/'{clip_i}'/0*.jpg -o '{outputdir}/{clip_i}.gif' -d 0.25") + # torch.save(save_tokens, os.path.join(outputdir, 'frame_tokens.pt')) + + # logging.info("CogVideo Stage1 completed. Taken time {:.2f}\n".format(time.time() - process_start_time)) + + # return save_tokens + + # ====================================================================================================== + + if args.stage_1 or args.both_stages: + if args.input_source != "interactive": + with open(args.input_source, "r") as fin: + promptlist = fin.readlines() + promptlist = [p.strip() for p in promptlist] + else: + promptlist = None + + now_qi = -1 + while True: + now_qi += 1 + + if promptlist is not None: # with input-source + if args.multi_gpu: + if now_qi % dist.get_world_size() != dist.get_rank(): + continue + rk = dist.get_rank() + else: + rk = 0 + raw_text = promptlist[now_qi] + raw_text = raw_text.strip() + print(f"Working on Line No. {now_qi} on {rk}... [{raw_text}]") + else: # interactive + raw_text = input("\nPlease Input Query (stop to exit) >>> ") + raw_text = raw_text.strip() + if not raw_text: + print("Query should not be empty!") + continue + if raw_text == "stop": + return + + try: + path = os.path.join(args.output_path, f"{now_qi}_{raw_text}") + parent_given_tokens, imgs = process_stage1( + model_stage1, + raw_text, + duration=4.0, + video_raw_text=raw_text, + video_guidance_text="视频", + image_text_suffix=" 高清摄影", + outputdir=path if args.stage_1 else None, + batch_size=args.batch_size, + ) + if args.stage_1 and not args.both_stages: + print("only stage 1") + return imgs + + if args.both_stages: + videos = process_stage2( + model_stage2, + raw_text, + duration=duration, + video_raw_text=raw_text + " 视频", + video_guidance_text="视频", + parent_given_tokens=parent_given_tokens, + outputdir=path, + gpu_rank=0, + gpu_parallel_size=1, + ) # TODO: 修改 + return videos + except (ValueError, FileNotFoundError) as e: + print(e) + continue + + elif args.stage_2: + sample_dirs = os.listdir(args.output_path) + for sample in sample_dirs: + raw_text = sample.split("_")[-1] + path = os.path.join(args.output_path, sample, "Interp") + parent_given_tokens = torch.load( + os.path.join(args.output_path, sample, "frame_tokens.pt") + ) + + process_stage2( + raw_text, + duration=2.0, + video_raw_text=raw_text + " 视频", + video_guidance_text="视频", + parent_given_tokens=parent_given_tokens, + outputdir=path, + gpu_rank=0, + gpu_parallel_size=1, + ) # TODO: 修改 + + else: + assert False diff --git a/src/videogen_hub/pipelines/cogvideo/cogvideo_src/LICENSE b/src/videogen_hub/pipelines/cogvideo/cogvideo_src/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/src/videogen_hub/pipelines/cogvideo/cogvideo_src/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/src/videogen_hub/pipelines/cogvideo/cogvideo_src/Model_License b/src/videogen_hub/pipelines/cogvideo/cogvideo_src/Model_License new file mode 100644 index 0000000000000000000000000000000000000000..2a21b75798e44fa936f4fb31c5765489c7ae1193 --- /dev/null +++ b/src/videogen_hub/pipelines/cogvideo/cogvideo_src/Model_License @@ -0,0 +1,79 @@ +The CogVideo License + +Section I: PREAMBLE + +Multimodal generative models are being widely adopted and used, and have the potential to transform the way artists, among other individuals, conceive and benefit from AI or ML technologies as a tool for content creation. + +Notwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations. + +In short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the Model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for art and content generation. + +Even though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this License aims to strike a balance between both in order to enable responsible open-science in the field of AI. + +This License governs the use of the model (and its derivatives) and is informed by the model card associated with the model. + +NOW THEREFORE, You and Licensor agree as follows: + +1. Definitions + +- "License" means the terms and conditions for use, reproduction, and Distribution as defined in this document. +- "Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License. +- "Output" means the results of operating a Model as embodied in informational content resulting therefrom. +- "Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material. +- "Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model. +- "Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any. +- "Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access. +- "Licensor" means the copyright owner or entity authorized by the copyright owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model. +- "You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, image generator. +- "Third Parties" means individuals or legal entities that are not under common control with Licensor or You. +- "Contribution" means any work of authorship, including the original version of the Model and any modifications or additions to that Model or Derivatives of the Model thereof, that is intentionally submitted to Licensor for inclusion in the Model by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Model, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." +- "Contributor" means Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Model. + +Section II: INTELLECTUAL PROPERTY RIGHTS + +Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III. + +2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model. +3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Model to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution incorporated within the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or Work shall terminate as of the date such litigation is asserted or filed. + +Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION + +4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions: +Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material. +You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License; +You must cause any modified files to carry prominent notices stating that You changed the files; +You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model. +You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. - for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License. +5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5). +6. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License. + +Section IV: OTHER PROVISIONS + +7. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License, update the Model through electronic means, or modify the Output of the Model based on updates. You shall undertake reasonable efforts to use the latest version of the Model. +8. Trademarks and related. Nothing in this License permits You to make use of Licensors’ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensors. +9. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Model and the Complementary Material (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License. +10. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. +11. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. +12. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein. + +END OF TERMS AND CONDITIONS + + + + +Attachment A + +Use Restrictions + +You agree not to use the Model or Derivatives of the Model: +- In any way that violates any applicable national, federal, state, local or international law or regulation; +- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way; +- To generate or disseminate verifiably false information and/or content with the purpose of harming others; +- To generate or disseminate personal identifiable information that can be used to harm an individual; +- To defame, disparage or otherwise harass others; +- For fully automated decision making that adversely impacts an individual’s legal rights or otherwise creates or modifies a binding, enforceable obligation; +- For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics; +- To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm; +- For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories; +- To provide medical advice and medical results interpretation; +- To generate or disseminate information for the purpose to be used for administration of justice, law enforcement, immigration or asylum processes, such as predicting an individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal relationships between assertions made in documents, indiscriminate and arbitrarily-targeted use). diff --git a/src/videogen_hub/pipelines/cogvideo/cogvideo_src/__init__.py b/src/videogen_hub/pipelines/cogvideo/cogvideo_src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/cogvideo/cogvideo_src/cluster_label2.npy b/src/videogen_hub/pipelines/cogvideo/cogvideo_src/cluster_label2.npy new file mode 100644 index 0000000000000000000000000000000000000000..1c27ec8f73830ac8611789750dbfd73a2a494920 --- /dev/null +++ b/src/videogen_hub/pipelines/cogvideo/cogvideo_src/cluster_label2.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b87880fdbe89670f12844377b9cf97a9733b1f54e3a9b73cbb9835084c4e02ec +size 160128 diff --git a/src/videogen_hub/pipelines/cogvideo/cogvideo_src/coglm_strategy.py b/src/videogen_hub/pipelines/cogvideo/cogvideo_src/coglm_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..d4857156d1a6c0bf4791bab033b4dd0f468e5889 --- /dev/null +++ b/src/videogen_hub/pipelines/cogvideo/cogvideo_src/coglm_strategy.py @@ -0,0 +1,101 @@ +# -*- encoding: utf-8 -*- +''' +@File : coglm_strategy.py +@Time : 2021/10/08 22:22:42 +@Author : Ming Ding +@Contact : dm18@mails.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random +import torch +import numpy as np +import torch.nn.functional as F + + +def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-65504): + # This function has been mostly taken from huggingface conversational ai code at + # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313 + + if top_k > 0: + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = filter_value + + if top_p > 0.0: + # convert to 1D + logits = logits.view(logits.size()[1]).contiguous() + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold + sorted_indices_to_remove = cumulative_probs > top_p + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + indices_to_remove = sorted_indices[sorted_indices_to_remove] + logits[indices_to_remove] = filter_value + # going back to 2D + logits = logits.view(1, -1).contiguous() + + return logits + + +class CoglmStrategy: + def __init__(self, invalid_slices=[], temperature=1., top_k=200, eps=1e-4, top_p=0.0, end_tokens=None, temperature2=0.89): + self.invalid_slices = invalid_slices + self.temperature = temperature + self.temperature2 = temperature2 + self.topk = top_k + self.top_p = top_p + self.eps = eps + if end_tokens is None: + end_tokens = [] + self.end_tokens = end_tokens + self._is_done = False + self.outlier_count_down = torch.zeros(16) + self.vis_list = [[]for i in range(16)] + self.cluster_labels = torch.tensor(np.load('cluster_label2.npy'), device='cuda', dtype=torch.long) + self.start_pos = -1 + self.white_cluster = [] + # self.fout = open('tmp.txt', 'w') + + @property + def is_done(self) -> bool: + return self._is_done + + def forward(self, logits, tokens, mems, temperature=None, temperature2=None): + if temperature is None: + temperature = self.temperature + if temperature2 is None: + temperature2 = self.temperature2 + logits = logits / temperature + for invalid_slice in self.invalid_slices: + logits[..., invalid_slice] = -65504 + + rprobs = F.softmax(logits.float(), dim=-1) + c = self.cluster_labels.expand(*rprobs.shape) + cprobs = torch.zeros(logits.shape[0], 500, device=logits.device).scatter_add_(1, c, rprobs) + # self.fout.write(str(tokens.shape[-1])+ ' ' + str(cprobs.topk(10)) + '\n') + # self.fout.flush() + best_scores, best_clusters = cprobs.topk(self.topk) + bz = logits.shape[0] + for i in range(bz): + selected_cluster = best_clusters[i][torch.multinomial(best_scores[i] / best_scores[i].sum(), num_samples=1)] + logits[i, self.cluster_labels != selected_cluster] = -65504 + + # logits = top_k_logits(logits, self.topk, self.top_p) + probs = F.softmax(logits.float()/temperature2, dim=-1) # float is essetial, due to a bug in Pytorch + pred = torch.multinomial(probs, num_samples=1) + + if pred.numel() == 1 and pred.item() in self.end_tokens: + self._is_done = True + tokens = torch.cat((tokens, pred.view(tokens.shape[0], 1)), dim=1) + return tokens, mems + + def finalize(self, tokens, mems): + self._is_done = False + return tokens, mems \ No newline at end of file diff --git a/src/videogen_hub/pipelines/cogvideo/cogvideo_src/cogvideo_pipeline.py b/src/videogen_hub/pipelines/cogvideo/cogvideo_src/cogvideo_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..158e0785a1d7d3f353385dd066dfaeaeefeb165e --- /dev/null +++ b/src/videogen_hub/pipelines/cogvideo/cogvideo_src/cogvideo_pipeline.py @@ -0,0 +1,1341 @@ +# -*- encoding: utf-8 -*- +""" +@File : cogvideo_pipeline.py +@Time : 2022/07/15 11:24:56 +@Author : Wenyi Hong +@Version : 1.0 +@Contact : hwy22@mails.tsinghua.edu.cn +""" + +# here put the import lib + +import os +import sys +import torch +import argparse +import time +from torchvision.utils import save_image +import stat +from videogen_hub.depend.icetk import icetk as tokenizer +import logging, sys + +import torch.distributed as dist + +tokenizer.add_special_tokens( + ["", "", ""] +) + + +from SwissArmyTransformer import get_args +from SwissArmyTransformer.data_utils import BinaryDataset, make_loaders +from SwissArmyTransformer.generation.sampling_strategies import BaseStrategy +from SwissArmyTransformer.generation.utils import ( + timed_name, + save_multiple_images, + generate_continually, +) +from SwissArmyTransformer.resources import auto_create + +from .models.cogvideo_cache_model import CogVideoCacheModel +from .coglm_strategy import CoglmStrategy + + +def get_masks_and_position_ids_stage1(data, textlen, framelen): + # Extract batch size and sequence length. + tokens = data + seq_length = len(data[0]) + # Attention mask (lower triangular). + attention_mask = torch.ones( + (1, textlen + framelen, textlen + framelen), device=data.device + ) + attention_mask[:, :textlen, textlen:] = 0 + attention_mask[:, textlen:, textlen:].tril_() + attention_mask.unsqueeze_(1) + # Unaligned version + position_ids = torch.zeros(seq_length, dtype=torch.long, device=data.device) + torch.arange( + textlen, out=position_ids[:textlen], dtype=torch.long, device=data.device + ) + torch.arange( + 512, + 512 + seq_length - textlen, + out=position_ids[textlen:], + dtype=torch.long, + device=data.device, + ) + position_ids = position_ids.unsqueeze(0) + + return tokens, attention_mask, position_ids + + +def get_masks_and_position_ids_stage2(data, textlen, framelen): + # Extract batch size and sequence length. + tokens = data + seq_length = len(data[0]) + + # Attention mask (lower triangular). + attention_mask = torch.ones( + (1, textlen + framelen, textlen + framelen), device=data.device + ) + attention_mask[:, :textlen, textlen:] = 0 + attention_mask[:, textlen:, textlen:].tril_() + attention_mask.unsqueeze_(1) + + # Unaligned version + position_ids = torch.zeros(seq_length, dtype=torch.long, device=data.device) + torch.arange( + textlen, out=position_ids[:textlen], dtype=torch.long, device=data.device + ) + frame_num = (seq_length - textlen) // framelen + assert frame_num == 5 + torch.arange( + 512, + 512 + framelen, + out=position_ids[textlen : textlen + framelen], + dtype=torch.long, + device=data.device, + ) + torch.arange( + 512 + framelen * 2, + 512 + framelen * 3, + out=position_ids[textlen + framelen : textlen + framelen * 2], + dtype=torch.long, + device=data.device, + ) + torch.arange( + 512 + framelen * (frame_num - 1), + 512 + framelen * frame_num, + out=position_ids[textlen + framelen * 2 : textlen + framelen * 3], + dtype=torch.long, + device=data.device, + ) + torch.arange( + 512 + framelen * 1, + 512 + framelen * 2, + out=position_ids[textlen + framelen * 3 : textlen + framelen * 4], + dtype=torch.long, + device=data.device, + ) + torch.arange( + 512 + framelen * 3, + 512 + framelen * 4, + out=position_ids[textlen + framelen * 4 : textlen + framelen * 5], + dtype=torch.long, + device=data.device, + ) + + position_ids = position_ids.unsqueeze(0) + + return tokens, attention_mask, position_ids + + +def my_update_mems( + hiddens, mems_buffers, mems_indexs, limited_spatial_channel_mem, text_len, frame_len +): + if hiddens is None: + return None, mems_indexs + mem_num = len(hiddens) + ret_mem = [] + with torch.no_grad(): + for id in range(mem_num): + if hiddens[id][0] is None: + ret_mem.append(None) + else: + if ( + id == 0 + and limited_spatial_channel_mem + and mems_indexs[id] + hiddens[0][0].shape[1] >= text_len + frame_len + ): + if mems_indexs[id] == 0: + for layer, hidden in enumerate(hiddens[id]): + mems_buffers[id][layer, :, :text_len] = hidden.expand( + mems_buffers[id].shape[1], -1, -1 + )[:, :text_len] + new_mem_len_part2 = ( + mems_indexs[id] + hiddens[0][0].shape[1] - text_len + ) % frame_len + if new_mem_len_part2 > 0: + for layer, hidden in enumerate(hiddens[id]): + mems_buffers[id][ + layer, :, text_len : text_len + new_mem_len_part2 + ] = hidden.expand(mems_buffers[id].shape[1], -1, -1)[ + :, -new_mem_len_part2: + ] + mems_indexs[id] = text_len + new_mem_len_part2 + else: + for layer, hidden in enumerate(hiddens[id]): + mems_buffers[id][ + layer, + :, + mems_indexs[id] : mems_indexs[id] + hidden.shape[1], + ] = hidden.expand(mems_buffers[id].shape[1], -1, -1) + mems_indexs[id] += hidden.shape[1] + ret_mem.append(mems_buffers[id][:, :, : mems_indexs[id]]) + return ret_mem, mems_indexs + + +def my_save_multiple_images(imgs, path, subdir, debug=True): + # imgs: list of tensor images + if debug: + imgs = torch.cat(imgs, dim=0) + print("\nSave to: ", path, flush=True) + save_image(imgs, path, normalize=True) + else: + print("\nSave to: ", path, flush=True) + single_frame_path = os.path.join(path, subdir) + os.makedirs(single_frame_path, exist_ok=True) + for i in range(len(imgs)): + save_image( + imgs[i], + os.path.join(single_frame_path, f'{str(i).rjust(4,"0")}.jpg'), + normalize=True, + ) + os.chmod( + os.path.join(single_frame_path, f'{str(i).rjust(4,"0")}.jpg'), + stat.S_IRWXO + stat.S_IRWXG + stat.S_IRWXU, + ) + save_image( + torch.cat(imgs, dim=0), + os.path.join(single_frame_path, f"frame_concat.jpg"), + normalize=True, + ) + os.chmod( + os.path.join(single_frame_path, f"frame_concat.jpg"), + stat.S_IRWXO + stat.S_IRWXG + stat.S_IRWXU, + ) + + +def calc_next_tokens_frame_begin_id(text_len, frame_len, total_len): + # The fisrt token's position id of the frame that the next token belongs to; + if total_len < text_len: + return None + return (total_len - text_len) // frame_len * frame_len + text_len + + +def my_filling_sequence( + model, + args, + seq, + batch_size, + get_masks_and_position_ids, + text_len, + frame_len, + strategy=BaseStrategy(), + strategy2=BaseStrategy(), + mems=None, + log_text_attention_weights=0, # default to 0: no artificial change + mode_stage1=True, + enforce_no_swin=False, + guider_seq=None, + guider_text_len=0, + guidance_alpha=1, + limited_spatial_channel_mem=False, # 空间通道的存储限制在本帧内 + **kw_args, +): + """ + seq: [2, 3, 5, ..., -1(to be generated), -1, ...] + mems: [num_layers, batch_size, len_mems(index), mem_hidden_size] + cache, should be first mems.shape[1] parts of context_tokens. + mems are the first-level citizens here, but we don't assume what is memorized. + input mems are used when multi-phase generation. + """ + if guider_seq is not None: + logging.debug("Using Guidance In Inference") + if limited_spatial_channel_mem: + logging.debug("Limit spatial-channel's mem to current frame") + assert len(seq.shape) == 2 + + # building the initial tokens, attention_mask, and position_ids + actual_context_length = 0 + + while seq[-1][actual_context_length] >= 0: # the last seq has least given tokens + actual_context_length += 1 # [0, context_length-1] are given + assert actual_context_length > 0 + current_frame_num = (actual_context_length - text_len) // frame_len + assert current_frame_num >= 0 + context_length = text_len + current_frame_num * frame_len + + tokens, attention_mask, position_ids = get_masks_and_position_ids( + seq, text_len, frame_len + ) + tokens = tokens[..., :context_length] + input_tokens = tokens.clone() + + if guider_seq is not None: + guider_index_delta = text_len - guider_text_len + guider_tokens, guider_attention_mask, guider_position_ids = ( + get_masks_and_position_ids(guider_seq, guider_text_len, frame_len) + ) + guider_tokens = guider_tokens[..., : context_length - guider_index_delta] + guider_input_tokens = guider_tokens.clone() + + for fid in range(current_frame_num): + input_tokens[:, text_len + 400 * fid] = tokenizer[""] + if guider_seq is not None: + guider_input_tokens[:, guider_text_len + 400 * fid] = tokenizer[ + "" + ] + + attention_mask = attention_mask.type_as(next(model.parameters())) # if fp16 + # initialize generation + counter = context_length - 1 # Last fixed index is ``counter'' + index = 0 # Next forward starting index, also the length of cache. + mems_buffers_on_GPU = False + mems_indexs = [0, 0] + mems_len = [ + (400 + 74) if limited_spatial_channel_mem else 5 * 400 + 74, + 5 * 400 + 74, + ] + mems_buffers = [ + torch.zeros( + args.num_layers, + batch_size, + mem_len, + args.hidden_size * 2, + dtype=next(model.parameters()).dtype, + ) + for mem_len in mems_len + ] + + if guider_seq is not None: + guider_attention_mask = guider_attention_mask.type_as( + next(model.parameters()) + ) # if fp16 + guider_mems_buffers = [ + torch.zeros( + args.num_layers, + batch_size, + mem_len, + args.hidden_size * 2, + dtype=next(model.parameters()).dtype, + ) + for mem_len in mems_len + ] + guider_mems_indexs = [0, 0] + guider_mems = None + + torch.cuda.empty_cache() + # step-by-step generation + while counter < len(seq[0]) - 1: + # we have generated counter+1 tokens + # Now, we want to generate seq[counter + 1], + # token[:, index: counter+1] needs forwarding. + if index == 0: + group_size = ( + 2 + if (input_tokens.shape[0] == batch_size and not mode_stage1) + else batch_size + ) + + logits_all = None + for batch_idx in range(0, input_tokens.shape[0], group_size): + logits, *output_per_layers = model( + input_tokens[batch_idx : batch_idx + group_size, index:], + position_ids[..., index : counter + 1], + attention_mask, # TODO memlen + mems=mems, + text_len=text_len, + frame_len=frame_len, + counter=counter, + log_text_attention_weights=log_text_attention_weights, + enforce_no_swin=enforce_no_swin, + **kw_args, + ) + logits_all = ( + torch.cat((logits_all, logits), dim=0) + if logits_all is not None + else logits + ) + mem_kv01 = [ + [o["mem_kv"][0] for o in output_per_layers], + [o["mem_kv"][1] for o in output_per_layers], + ] + next_tokens_frame_begin_id = calc_next_tokens_frame_begin_id( + text_len, frame_len, mem_kv01[0][0].shape[1] + ) + for id, mem_kv in enumerate(mem_kv01): + for layer, mem_kv_perlayer in enumerate(mem_kv): + if limited_spatial_channel_mem and id == 0: + mems_buffers[id][ + layer, batch_idx : batch_idx + group_size, :text_len + ] = mem_kv_perlayer.expand( + min(group_size, input_tokens.shape[0] - batch_idx), + -1, + -1, + )[ + :, :text_len + ] + mems_buffers[id][ + layer, + batch_idx : batch_idx + group_size, + text_len : text_len + + mem_kv_perlayer.shape[1] + - next_tokens_frame_begin_id, + ] = mem_kv_perlayer.expand( + min(group_size, input_tokens.shape[0] - batch_idx), + -1, + -1, + )[ + :, next_tokens_frame_begin_id: + ] + else: + mems_buffers[id][ + layer, + batch_idx : batch_idx + group_size, + : mem_kv_perlayer.shape[1], + ] = mem_kv_perlayer.expand( + min(group_size, input_tokens.shape[0] - batch_idx), + -1, + -1, + ) + mems_indexs[0], mems_indexs[1] = ( + mem_kv01[0][0].shape[1], + mem_kv01[1][0].shape[1], + ) + if limited_spatial_channel_mem: + mems_indexs[0] -= next_tokens_frame_begin_id - text_len + + mems = [mems_buffers[id][:, :, : mems_indexs[id]] for id in range(2)] + logits = logits_all + + # Guider + if guider_seq is not None: + guider_logits_all = None + for batch_idx in range(0, guider_input_tokens.shape[0], group_size): + guider_logits, *guider_output_per_layers = model( + guider_input_tokens[ + batch_idx : batch_idx + group_size, + max(index - guider_index_delta, 0) :, + ], + guider_position_ids[ + ..., + max(index - guider_index_delta, 0) : counter + + 1 + - guider_index_delta, + ], + guider_attention_mask, + mems=guider_mems, + text_len=guider_text_len, + frame_len=frame_len, + counter=counter - guider_index_delta, + log_text_attention_weights=log_text_attention_weights, + enforce_no_swin=enforce_no_swin, + **kw_args, + ) + guider_logits_all = ( + torch.cat((guider_logits_all, guider_logits), dim=0) + if guider_logits_all is not None + else guider_logits + ) + guider_mem_kv01 = [ + [o["mem_kv"][0] for o in guider_output_per_layers], + [o["mem_kv"][1] for o in guider_output_per_layers], + ] + for id, guider_mem_kv in enumerate(guider_mem_kv01): + for layer, guider_mem_kv_perlayer in enumerate(guider_mem_kv): + if limited_spatial_channel_mem and id == 0: + guider_mems_buffers[id][ + layer, + batch_idx : batch_idx + group_size, + :guider_text_len, + ] = guider_mem_kv_perlayer.expand( + min(group_size, input_tokens.shape[0] - batch_idx), + -1, + -1, + )[ + :, :guider_text_len + ] + guider_next_tokens_frame_begin_id = ( + calc_next_tokens_frame_begin_id( + guider_text_len, + frame_len, + guider_mem_kv_perlayer.shape[1], + ) + ) + guider_mems_buffers[id][ + layer, + batch_idx : batch_idx + group_size, + guider_text_len : guider_text_len + + guider_mem_kv_perlayer.shape[1] + - guider_next_tokens_frame_begin_id, + ] = guider_mem_kv_perlayer.expand( + min(group_size, input_tokens.shape[0] - batch_idx), + -1, + -1, + )[ + :, guider_next_tokens_frame_begin_id: + ] + else: + guider_mems_buffers[id][ + layer, + batch_idx : batch_idx + group_size, + : guider_mem_kv_perlayer.shape[1], + ] = guider_mem_kv_perlayer.expand( + min(group_size, input_tokens.shape[0] - batch_idx), + -1, + -1, + ) + guider_mems_indexs[0], guider_mems_indexs[1] = ( + guider_mem_kv01[0][0].shape[1], + guider_mem_kv01[1][0].shape[1], + ) + if limited_spatial_channel_mem: + guider_mems_indexs[0] -= ( + guider_next_tokens_frame_begin_id - guider_text_len + ) + guider_mems = [ + guider_mems_buffers[id][:, :, : guider_mems_indexs[id]] + for id in range(2) + ] + guider_logits = guider_logits_all + else: + if not mems_buffers_on_GPU: + if not mode_stage1: + torch.cuda.empty_cache() + for idx, mem in enumerate(mems): + mems[idx] = mem.to(next(model.parameters()).device) + if guider_seq is not None: + for idx, mem in enumerate(guider_mems): + guider_mems[idx] = mem.to(next(model.parameters()).device) + else: + torch.cuda.empty_cache() + for idx, mem_buffer in enumerate(mems_buffers): + mems_buffers[idx] = mem_buffer.to( + next(model.parameters()).device + ) + mems = [ + mems_buffers[id][:, :, : mems_indexs[id]] for id in range(2) + ] + if guider_seq is not None: + for idx, guider_mem_buffer in enumerate(guider_mems_buffers): + guider_mems_buffers[idx] = guider_mem_buffer.to( + next(model.parameters()).device + ) + guider_mems = [ + guider_mems_buffers[id][:, :, : guider_mems_indexs[id]] + for id in range(2) + ] + mems_buffers_on_GPU = True + + logits, *output_per_layers = model( + input_tokens[:, index:], + position_ids[..., index : counter + 1], + attention_mask, # TODO memlen + mems=mems, + text_len=text_len, + frame_len=frame_len, + counter=counter, + log_text_attention_weights=log_text_attention_weights, + enforce_no_swin=enforce_no_swin, + limited_spatial_channel_mem=limited_spatial_channel_mem, + **kw_args, + ) + mem_kv0, mem_kv1 = [o["mem_kv"][0] for o in output_per_layers], [ + o["mem_kv"][1] for o in output_per_layers + ] + + if guider_seq is not None: + guider_logits, *guider_output_per_layers = model( + guider_input_tokens[:, max(index - guider_index_delta, 0) :], + guider_position_ids[ + ..., + max(index - guider_index_delta, 0) : counter + + 1 + - guider_index_delta, + ], + guider_attention_mask, + mems=guider_mems, + text_len=guider_text_len, + frame_len=frame_len, + counter=counter - guider_index_delta, + log_text_attention_weights=0, + enforce_no_swin=enforce_no_swin, + limited_spatial_channel_mem=limited_spatial_channel_mem, + **kw_args, + ) + guider_mem_kv0, guider_mem_kv1 = [ + o["mem_kv"][0] for o in guider_output_per_layers + ], [o["mem_kv"][1] for o in guider_output_per_layers] + + if not mems_buffers_on_GPU: + torch.cuda.empty_cache() + for idx, mem_buffer in enumerate(mems_buffers): + mems_buffers[idx] = mem_buffer.to(next(model.parameters()).device) + if guider_seq is not None: + for idx, guider_mem_buffer in enumerate(guider_mems_buffers): + guider_mems_buffers[idx] = guider_mem_buffer.to( + next(model.parameters()).device + ) + mems_buffers_on_GPU = True + + mems, mems_indexs = my_update_mems( + [mem_kv0, mem_kv1], + mems_buffers, + mems_indexs, + limited_spatial_channel_mem, + text_len, + frame_len, + ) + if guider_seq is not None: + guider_mems, guider_mems_indexs = my_update_mems( + [guider_mem_kv0, guider_mem_kv1], + guider_mems_buffers, + guider_mems_indexs, + limited_spatial_channel_mem, + guider_text_len, + frame_len, + ) + + counter += 1 + index = counter + + logits = logits[:, -1].expand(batch_size, -1) # [batch size, vocab size] + tokens = tokens.expand(batch_size, -1) + if guider_seq is not None: + guider_logits = guider_logits[:, -1].expand(batch_size, -1) + guider_tokens = guider_tokens.expand(batch_size, -1) + + if seq[-1][counter].item() < 0: + # sampling + guided_logits = ( + guider_logits + (logits - guider_logits) * guidance_alpha + if guider_seq is not None + else logits + ) + if mode_stage1 and counter < text_len + 400: + tokens, mems = strategy.forward(guided_logits, tokens, mems) + else: + tokens, mems = strategy2.forward(guided_logits, tokens, mems) + if guider_seq is not None: + guider_tokens = torch.cat((guider_tokens, tokens[:, -1:]), dim=1) + + if seq[0][counter].item() >= 0: + for si in range(seq.shape[0]): + if seq[si][counter].item() >= 0: + tokens[si, -1] = seq[si, counter] + if guider_seq is not None: + guider_tokens[si, -1] = guider_seq[ + si, counter - guider_index_delta + ] + + else: + tokens = torch.cat( + ( + tokens, + seq[:, counter : counter + 1] + .clone() + .expand(tokens.shape[0], 1) + .to(device=tokens.device, dtype=tokens.dtype), + ), + dim=1, + ) + if guider_seq is not None: + guider_tokens = torch.cat( + ( + guider_tokens, + guider_seq[ + :, + counter + - guider_index_delta : counter + + 1 + - guider_index_delta, + ] + .clone() + .expand(guider_tokens.shape[0], 1) + .to(device=guider_tokens.device, dtype=guider_tokens.dtype), + ), + dim=1, + ) + + input_tokens = tokens.clone() + if guider_seq is not None: + guider_input_tokens = guider_tokens.clone() + if (index - text_len - 1) // 400 < ( + input_tokens.shape[-1] - text_len - 1 + ) // 400: + boi_idx = ((index - text_len - 1) // 400 + 1) * 400 + text_len + while boi_idx < input_tokens.shape[-1]: + input_tokens[:, boi_idx] = tokenizer[""] + if guider_seq is not None: + guider_input_tokens[:, boi_idx - guider_index_delta] = tokenizer[ + "" + ] + boi_idx += 400 + + if strategy.is_done: + break + return strategy.finalize(tokens, mems) + + +class InferenceModel_Sequential(CogVideoCacheModel): + def __init__(self, args, transformer=None, parallel_output=True): + super().__init__( + args, + transformer=transformer, + parallel_output=parallel_output, + window_size=-1, + cogvideo_stage=1, + ) + + # TODO: check it + + def final_forward(self, logits, **kwargs): + logits_parallel = logits + logits_parallel = torch.nn.functional.linear( + logits_parallel.float(), + self.transformer.word_embeddings.weight[:20000].float(), + ) + return logits_parallel + + +class InferenceModel_Interpolate(CogVideoCacheModel): + def __init__(self, args, transformer=None, parallel_output=True): + super().__init__( + args, + transformer=transformer, + parallel_output=parallel_output, + window_size=10, + cogvideo_stage=2, + ) + + # TODO: check it + + def final_forward(self, logits, **kwargs): + logits_parallel = logits + logits_parallel = torch.nn.functional.linear( + logits_parallel.float(), + self.transformer.word_embeddings.weight[:20000].float(), + ) + return logits_parallel + + +def main(args): + assert int(args.stage_1) + int(args.stage_2) + int(args.both_stages) == 1 + rank_id = args.device % args.parallel_size + generate_frame_num = args.generate_frame_num + + if args.stage_1 or args.both_stages: + model_stage1, args = InferenceModel_Sequential.from_pretrained( + args, "cogvideo-stage1" + ) + model_stage1.eval() + if args.both_stages: + model_stage1 = model_stage1.cpu() + + if args.stage_2 or args.both_stages: + model_stage2, args = InferenceModel_Interpolate.from_pretrained( + args, "cogvideo-stage2" + ) + model_stage2.eval() + if args.both_stages: + model_stage2 = model_stage2.cpu() + + invalid_slices = [slice(tokenizer.num_image_tokens, None)] + strategy_cogview2 = CoglmStrategy(invalid_slices, temperature=1.0, top_k=16) + strategy_cogvideo = CoglmStrategy( + invalid_slices, + temperature=args.temperature, + top_k=args.top_k, + temperature2=args.coglm_temperature2, + ) + if not args.stage_1: + from sr_pipeline import DirectSuperResolution + + dsr_path = auto_create( + "cogview2-dsr", path=None + ) # path=os.getenv('SAT_HOME', '~/.sat_models') + dsr = DirectSuperResolution(args, dsr_path, max_bz=12, onCUDA=False) + + def process_stage2( + model, + seq_text, + duration, + video_raw_text=None, + video_guidance_text="视频", + parent_given_tokens=None, + conddir=None, + outputdir=None, + gpu_rank=0, + gpu_parallel_size=1, + ): + stage2_starttime = time.time() + use_guidance = args.use_guidance_stage2 + if args.both_stages: + move_start_time = time.time() + logging.debug("moving stage-2 model to cuda") + model = model.cuda() + logging.debug( + "moving in stage-2 model takes time: {:.2f}".format( + time.time() - move_start_time + ) + ) + + try: + if parent_given_tokens is None: + assert conddir is not None + parent_given_tokens = torch.load( + os.path.join(conddir, "frame_tokens.pt"), map_location="cpu" + ) + sample_num_allgpu = parent_given_tokens.shape[0] + sample_num = sample_num_allgpu // gpu_parallel_size + assert sample_num * gpu_parallel_size == sample_num_allgpu + parent_given_tokens = parent_given_tokens[ + gpu_rank * sample_num : (gpu_rank + 1) * sample_num + ] + except: + logging.critical("No frame_tokens found in interpolation, skip") + return False + + # CogVideo Stage2 Generation + while ( + duration >= 0.5 + ): # TODO: You can change the boundary to change the frame rate + parent_given_tokens_num = parent_given_tokens.shape[1] + generate_batchsize_persample = (parent_given_tokens_num - 1) // 2 + generate_batchsize_total = generate_batchsize_persample * sample_num + total_frames = generate_frame_num + frame_len = 400 + enc_text = tokenizer.encode(seq_text) + enc_duration = tokenizer.encode(str(float(duration)) + "秒") + seq = ( + enc_duration + + [tokenizer[""]] + + enc_text + + [tokenizer[""]] + + [-1] * 400 * generate_frame_num + ) + text_len = len(seq) - frame_len * generate_frame_num - 1 + + logging.info( + "[Stage2: Generating Frames, Frame Rate {:d}]\nraw text: {:s}".format( + int(4 / duration), tokenizer.decode(enc_text) + ) + ) + + # generation + seq = ( + torch.cuda.LongTensor(seq, device=args.device) + .unsqueeze(0) + .repeat(generate_batchsize_total, 1) + ) + for sample_i in range(sample_num): + for i in range(generate_batchsize_persample): + seq[sample_i * generate_batchsize_persample + i][ + text_len + 1 : text_len + 1 + 400 + ] = parent_given_tokens[sample_i][2 * i] + seq[sample_i * generate_batchsize_persample + i][ + text_len + 1 + 400 : text_len + 1 + 800 + ] = parent_given_tokens[sample_i][2 * i + 1] + seq[sample_i * generate_batchsize_persample + i][ + text_len + 1 + 800 : text_len + 1 + 1200 + ] = parent_given_tokens[sample_i][2 * i + 2] + + if use_guidance: + guider_seq = ( + enc_duration + + [tokenizer[""]] + + tokenizer.encode(video_guidance_text) + + [tokenizer[""]] + + [-1] * 400 * generate_frame_num + ) + guider_text_len = len(guider_seq) - frame_len * generate_frame_num - 1 + guider_seq = ( + torch.cuda.LongTensor(guider_seq, device=args.device) + .unsqueeze(0) + .repeat(generate_batchsize_total, 1) + ) + for sample_i in range(sample_num): + for i in range(generate_batchsize_persample): + guider_seq[sample_i * generate_batchsize_persample + i][ + text_len + 1 : text_len + 1 + 400 + ] = parent_given_tokens[sample_i][2 * i] + guider_seq[sample_i * generate_batchsize_persample + i][ + text_len + 1 + 400 : text_len + 1 + 800 + ] = parent_given_tokens[sample_i][2 * i + 1] + guider_seq[sample_i * generate_batchsize_persample + i][ + text_len + 1 + 800 : text_len + 1 + 1200 + ] = parent_given_tokens[sample_i][2 * i + 2] + video_log_text_attention_weights = 0 + else: + guider_seq = None + guider_text_len = 0 + video_log_text_attention_weights = 1.4 + + mbz = args.max_inference_batch_size + + assert generate_batchsize_total < mbz or generate_batchsize_total % mbz == 0 + output_list = [] + start_time = time.time() + for tim in range(max(generate_batchsize_total // mbz, 1)): + input_seq = ( + seq[: min(generate_batchsize_total, mbz)].clone() + if tim == 0 + else seq[mbz * tim : mbz * (tim + 1)].clone() + ) + guider_seq2 = ( + ( + guider_seq[: min(generate_batchsize_total, mbz)].clone() + if tim == 0 + else guider_seq[mbz * tim : mbz * (tim + 1)].clone() + ) + if guider_seq is not None + else None + ) + output_list.append( + my_filling_sequence( + model, + args, + input_seq, + batch_size=min(generate_batchsize_total, mbz), + get_masks_and_position_ids=get_masks_and_position_ids_stage2, + text_len=text_len, + frame_len=frame_len, + strategy=strategy_cogview2, + strategy2=strategy_cogvideo, + log_text_attention_weights=video_log_text_attention_weights, + mode_stage1=False, + guider_seq=guider_seq2, + guider_text_len=guider_text_len, + guidance_alpha=args.guidance_alpha, + limited_spatial_channel_mem=True, + )[0] + ) + logging.info( + "Duration {:.2f}, Taken time {:.2f}\n".format( + duration, time.time() - start_time + ) + ) + + output_tokens = torch.cat(output_list, dim=0) + output_tokens = output_tokens[ + :, text_len + 1 : text_len + 1 + (total_frames) * 400 + ].reshape(sample_num, -1, 400 * total_frames) + output_tokens_merge = torch.cat( + ( + output_tokens[:, :, : 1 * 400], + output_tokens[:, :, 400 * 3 : 4 * 400], + output_tokens[:, :, 400 * 1 : 2 * 400], + output_tokens[:, :, 400 * 4 : (total_frames) * 400], + ), + dim=2, + ).reshape(sample_num, -1, 400) + + output_tokens_merge = torch.cat( + (output_tokens_merge, output_tokens[:, -1:, 400 * 2 : 3 * 400]), dim=1 + ) + duration /= 2 + parent_given_tokens = output_tokens_merge + + if args.both_stages: + move_start_time = time.time() + logging.debug("moving stage 2 model to cpu") + model = model.cpu() + torch.cuda.empty_cache() + logging.debug( + "moving out model2 takes time: {:.2f}".format( + time.time() - move_start_time + ) + ) + + logging.info( + "CogVideo Stage2 completed. Taken time {:.2f}\n".format( + time.time() - stage2_starttime + ) + ) + + # decoding + # imgs = [torch.nn.functional.interpolate(tokenizer.decode(image_ids=seq.tolist()), size=(480, 480)) for seq in output_tokens_merge] + # os.makedirs(output_dir_full_path, exist_ok=True) + # my_save_multiple_images(imgs, output_dir_full_path,subdir="frames", debug=False) + # torch.save(output_tokens_merge.cpu(), os.path.join(output_dir_full_path, 'frame_token.pt')) + # os.system(f"gifmaker -i '{output_dir_full_path}'/frames/0*.jpg -o '{output_dir_full_path}/{str(float(duration))}_concat.gif' -d 0.2") + + # direct super-resolution by CogView2 + logging.info("[Direct super-resolution]") + dsr_starttime = time.time() + enc_text = tokenizer.encode(seq_text) + frame_num_per_sample = parent_given_tokens.shape[1] + parent_given_tokens_2d = parent_given_tokens.reshape(-1, 400) + text_seq = ( + torch.cuda.LongTensor(enc_text, device=args.device) + .unsqueeze(0) + .repeat(parent_given_tokens_2d.shape[0], 1) + ) + sred_tokens = dsr(text_seq, parent_given_tokens_2d) + decoded_sr_videos = [] + + for sample_i in range(sample_num): + decoded_sr_imgs = [] + for frame_i in range(frame_num_per_sample): + decoded_sr_img = tokenizer.decode( + image_ids=sred_tokens[frame_i + sample_i * frame_num_per_sample][ + -3600: + ] + ) + decoded_sr_imgs.append( + torch.nn.functional.interpolate(decoded_sr_img, size=(480, 480)) + ) + decoded_sr_videos.append(decoded_sr_imgs) + + for sample_i in range(sample_num): + my_save_multiple_images( + decoded_sr_videos[sample_i], + outputdir, + subdir=f"frames/{sample_i+sample_num*gpu_rank}", + debug=False, + ) + os.system( + f"gifmaker -i '{outputdir}'/frames/'{sample_i+sample_num*gpu_rank}'/0*.jpg -o '{outputdir}/{sample_i+sample_num*gpu_rank}.gif' -d 0.125" + ) + + logging.info( + "Direct super-resolution completed. Taken time {:.2f}\n".format( + time.time() - dsr_starttime + ) + ) + + return True + + def process_stage1( + model, + seq_text, + duration, + video_raw_text=None, + video_guidance_text="视频", + image_text_suffix="", + outputdir=None, + batch_size=1, + ): + process_start_time = time.time() + use_guide = args.use_guidance_stage1 + if args.both_stages: + move_start_time = time.time() + logging.debug("moving stage 1 model to cuda") + model = model.cuda() + logging.debug( + "moving in model1 takes time: {:.2f}".format( + time.time() - move_start_time + ) + ) + + if video_raw_text is None: + video_raw_text = seq_text + mbz = ( + args.stage1_max_inference_batch_size + if args.stage1_max_inference_batch_size > 0 + else args.max_inference_batch_size + ) + assert batch_size < mbz or batch_size % mbz == 0 + frame_len = 400 + + # generate the first frame: + enc_text = tokenizer.encode(seq_text + image_text_suffix) + seq_1st = ( + enc_text + [tokenizer[""]] + [-1] * 400 + ) # IV!! # test local!!! # test randboi!!! + logging.info( + "[Generating First Frame with CogView2]Raw text: {:s}".format( + tokenizer.decode(enc_text) + ) + ) + text_len_1st = len(seq_1st) - frame_len * 1 - 1 + + seq_1st = torch.cuda.LongTensor(seq_1st, device=args.device).unsqueeze(0) + output_list_1st = [] + for tim in range(max(batch_size // mbz, 1)): + start_time = time.time() + output_list_1st.append( + my_filling_sequence( + model, + args, + seq_1st.clone(), + batch_size=min(batch_size, mbz), + get_masks_and_position_ids=get_masks_and_position_ids_stage1, + text_len=text_len_1st, + frame_len=frame_len, + strategy=strategy_cogview2, + strategy2=strategy_cogvideo, + log_text_attention_weights=1.4, + enforce_no_swin=True, + mode_stage1=True, + )[0] + ) + logging.info( + "[First Frame]Taken time {:.2f}\n".format(time.time() - start_time) + ) + output_tokens_1st = torch.cat(output_list_1st, dim=0) + given_tokens = output_tokens_1st[ + :, text_len_1st + 1 : text_len_1st + 401 + ].unsqueeze( + 1 + ) # given_tokens.shape: [bs, frame_num, 400] + + # generate subsequent frames: + total_frames = generate_frame_num + enc_duration = tokenizer.encode(str(float(duration)) + "秒") + if use_guide: + video_raw_text = video_raw_text + " 视频" + enc_text_video = tokenizer.encode(video_raw_text) + seq = ( + enc_duration + + [tokenizer[""]] + + enc_text_video + + [tokenizer[""]] + + [-1] * 400 * generate_frame_num + ) + guider_seq = ( + enc_duration + + [tokenizer[""]] + + tokenizer.encode(video_guidance_text) + + [tokenizer[""]] + + [-1] * 400 * generate_frame_num + ) + logging.info( + "[Stage1: Generating Subsequent Frames, Frame Rate {:.1f}]\nraw text: {:s}".format( + 4 / duration, tokenizer.decode(enc_text_video) + ) + ) + + text_len = len(seq) - frame_len * generate_frame_num - 1 + guider_text_len = len(guider_seq) - frame_len * generate_frame_num - 1 + seq = ( + torch.cuda.LongTensor(seq, device=args.device) + .unsqueeze(0) + .repeat(batch_size, 1) + ) + guider_seq = ( + torch.cuda.LongTensor(guider_seq, device=args.device) + .unsqueeze(0) + .repeat(batch_size, 1) + ) + + for given_frame_id in range(given_tokens.shape[1]): + seq[ + :, + text_len + + 1 + + given_frame_id * 400 : text_len + + 1 + + (given_frame_id + 1) * 400, + ] = given_tokens[:, given_frame_id] + guider_seq[ + :, + guider_text_len + + 1 + + given_frame_id * 400 : guider_text_len + + 1 + + (given_frame_id + 1) * 400, + ] = given_tokens[:, given_frame_id] + output_list = [] + + if use_guide: + video_log_text_attention_weights = 0 + else: + guider_seq = None + video_log_text_attention_weights = 1.4 + + for tim in range(max(batch_size // mbz, 1)): + start_time = time.time() + input_seq = ( + seq[: min(batch_size, mbz)].clone() + if tim == 0 + else seq[mbz * tim : mbz * (tim + 1)].clone() + ) + guider_seq2 = ( + ( + guider_seq[: min(batch_size, mbz)].clone() + if tim == 0 + else guider_seq[mbz * tim : mbz * (tim + 1)].clone() + ) + if guider_seq is not None + else None + ) + output_list.append( + my_filling_sequence( + model, + args, + input_seq, + batch_size=min(batch_size, mbz), + get_masks_and_position_ids=get_masks_and_position_ids_stage1, + text_len=text_len, + frame_len=frame_len, + strategy=strategy_cogview2, + strategy2=strategy_cogvideo, + log_text_attention_weights=video_log_text_attention_weights, + guider_seq=guider_seq2, + guider_text_len=guider_text_len, + guidance_alpha=args.guidance_alpha, + limited_spatial_channel_mem=True, + mode_stage1=True, + )[0] + ) + + output_tokens = torch.cat(output_list, dim=0)[:, 1 + text_len :] + + if args.both_stages: + move_start_time = time.time() + logging.debug("moving stage 1 model to cpu") + model = model.cpu() + torch.cuda.empty_cache() + logging.debug( + "moving in model1 takes time: {:.2f}".format( + time.time() - move_start_time + ) + ) + + # decoding + imgs, sred_imgs, txts = [], [], [] + for seq in output_tokens: + decoded_imgs = [ + torch.nn.functional.interpolate( + tokenizer.decode(image_ids=seq.tolist()[i * 400 : (i + 1) * 400]), + size=(480, 480), + ) + for i in range(total_frames) + ] + imgs.append(decoded_imgs) # only the last image (target) + + assert len(imgs) == batch_size + save_tokens = ( + output_tokens[:, : +total_frames * 400].reshape(-1, total_frames, 400).cpu() + ) + if outputdir is not None: + for clip_i in range(len(imgs)): + # os.makedirs(output_dir_full_paths[clip_i], exist_ok=True) + my_save_multiple_images( + imgs[clip_i], outputdir, subdir=f"frames/{clip_i}", debug=False + ) + os.system( + f"gifmaker -i '{outputdir}'/frames/'{clip_i}'/0*.jpg -o '{outputdir}/{clip_i}.gif' -d 0.25" + ) + torch.save(save_tokens, os.path.join(outputdir, "frame_tokens.pt")) + + logging.info( + "CogVideo Stage1 completed. Taken time {:.2f}\n".format( + time.time() - process_start_time + ) + ) + + return save_tokens + + # ====================================================================================================== + + if args.stage_1 or args.both_stages: + if args.input_source != "interactive": + with open(args.input_source, "r") as fin: + promptlist = fin.readlines() + promptlist = [p.strip() for p in promptlist] + else: + promptlist = None + + now_qi = -1 + while True: + now_qi += 1 + + if promptlist is not None: # with input-source + if args.multi_gpu: + if now_qi % dist.get_world_size() != dist.get_rank(): + continue + rk = dist.get_rank() + else: + rk = 0 + raw_text = promptlist[now_qi] + raw_text = raw_text.strip() + print(f"Working on Line No. {now_qi} on {rk}... [{raw_text}]") + else: # interactive + raw_text = input("\nPlease Input Query (stop to exit) >>> ") + raw_text = raw_text.strip() + if not raw_text: + print("Query should not be empty!") + continue + if raw_text == "stop": + return + + try: + path = os.path.join(args.output_path, f"{now_qi}_{raw_text}") + parent_given_tokens = process_stage1( + model_stage1, + raw_text, + duration=4.0, + video_raw_text=raw_text, + video_guidance_text="视频", + image_text_suffix=" 高清摄影", + outputdir=path if args.stage_1 else None, + batch_size=args.batch_size, + ) + if args.both_stages: + process_stage2( + model_stage2, + raw_text, + duration=2.0, + video_raw_text=raw_text + " 视频", + video_guidance_text="视频", + parent_given_tokens=parent_given_tokens, + outputdir=path, + gpu_rank=0, + gpu_parallel_size=1, + ) # TODO: 修改 + except (ValueError, FileNotFoundError) as e: + print(e) + continue + + elif args.stage_2: + sample_dirs = os.listdir(args.output_path) + for sample in sample_dirs: + raw_text = sample.split("_")[-1] + path = os.path.join(args.output_path, sample, "Interp") + parent_given_tokens = torch.load( + os.path.join(args.output_path, sample, "frame_tokens.pt") + ) + + process_stage2( + raw_text, + duration=2.0, + video_raw_text=raw_text + " 视频", + video_guidance_text="视频", + parent_given_tokens=parent_given_tokens, + outputdir=path, + gpu_rank=0, + gpu_parallel_size=1, + ) # TODO: 修改 + + else: + assert False + + +if __name__ == "__main__": + logging.basicConfig(stream=sys.stderr, level=logging.DEBUG) + + py_parser = argparse.ArgumentParser(add_help=False) + py_parser.add_argument("--generate-frame-num", type=int, default=5) + py_parser.add_argument("--coglm-temperature2", type=float, default=0.89) + # py_parser.add_argument("--interp-duration", type=float, default=-1) # -1是顺序生成,0是超分,0.5/1/2是插帧 + # py_parser.add_argument("--total-duration", type=float, default=4.0) # 整个的时间 + py_parser.add_argument("--use-guidance-stage1", action="store_true") + py_parser.add_argument("--use-guidance-stage2", action="store_true") + py_parser.add_argument("--guidance-alpha", type=float, default=3.0) + py_parser.add_argument( + "--stage-1", action="store_true" + ) # stage 1: sequential generation + py_parser.add_argument("--stage-2", action="store_true") # stage 2: interp + dsr + py_parser.add_argument( + "--both-stages", action="store_true" + ) # stage 1&2: sequential generation; interp + dsr + py_parser.add_argument("--parallel-size", type=int, default=1) + py_parser.add_argument( + "--stage1-max-inference-batch-size", type=int, default=-1 + ) # -1: use max-inference-batch-size + py_parser.add_argument("--multi-gpu", action="store_true") + + CogVideoCacheModel.add_model_specific_args(py_parser) + + known, args_list = py_parser.parse_known_args() + args = get_args(args_list) + args = argparse.Namespace(**vars(args), **vars(known)) + args.layout = [int(x) for x in args.layout.split(",")] + args.do_train = False + + torch.cuda.set_device(args.device) + + with torch.no_grad(): + main(args) diff --git a/src/videogen_hub/pipelines/cogvideo/cogvideo_src/models/__init__.py b/src/videogen_hub/pipelines/cogvideo/cogvideo_src/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/cogvideo/cogvideo_src/models/cogvideo_cache_model.py b/src/videogen_hub/pipelines/cogvideo/cogvideo_src/models/cogvideo_cache_model.py new file mode 100644 index 0000000000000000000000000000000000000000..ca391847f05ea1d85cc1f67daa6ee01c30ad7e04 --- /dev/null +++ b/src/videogen_hub/pipelines/cogvideo/cogvideo_src/models/cogvideo_cache_model.py @@ -0,0 +1,695 @@ +# -*- encoding: utf-8 -*- +''' +@File : cogvideo_cache_model.py +@Time : 2022/07/15 11:22:19 +@Author : Wenyi Hong +@Version : 1.0 +@Contact : hwy22@mails.tsinghua.edu.cn +''' + +# here put the import lib + +from multiprocessing import context +from tkinter import E +import torch +from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin + +from SwissArmyTransformer.mpu.utils import split_tensor_along_last_dim +from SwissArmyTransformer.model.transformer import unscaled_init_method +from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear +import torch.nn.functional as F +from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker +import math + + +class PositionEmbeddingMixin(BaseMixin): + def __init__(self, additional_sequence_length, hidden_size, + init_method_std=0.02, reinit_slice=slice(512, 912), + ): + super(PositionEmbeddingMixin, self).__init__() + self.reinit_slice = reinit_slice + self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size) + torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std) + + def reinit(self, parent_model=None): + old_weights = self.transformer.position_embeddings.weight.data[self.reinit_slice] + old_len, hidden_size = old_weights.shape + assert hidden_size == self.position_embeddings.weight.shape[-1] + self.position_embeddings.weight.data.view(-1, old_len, hidden_size).copy_(old_weights) + + +def window_partition(x, window_size): + """ + Args: + x: (B, framenum, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, frame_num, window_size, window_size, C) + """ + B, framenum, H, W, C = x.shape + x = x.view(B, framenum, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(-1, framenum, window_size, window_size, C) + return windows + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, frame_num, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, frame_num, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + framenum = windows.shape[1] + x = windows.view(B, H // window_size, W // window_size, framenum, window_size, window_size, -1) + x = x.permute(0, 3, 1, 4, 2, 5, 6).contiguous().view(B, framenum, H, W, -1) + return x + +class WindowAttentionMixin(BaseMixin): + def __init__(self, num_layers, + hidden_size, + frame_resolution, + window_size, + shift_size, + n_head, + frame_num, + init_method=unscaled_init_method(0.02), + output_layer_init_method=unscaled_init_method(0.02), + time_dim_attend_length=0 + ): + super(WindowAttentionMixin, self).__init__() + self.num_layers = num_layers # replace attention in the LAST n layers + self.query_key_value = torch.nn.ModuleList( + [ColumnParallelLinear(hidden_size, 3*hidden_size,stride=3, + gather_output=False,init_method=init_method) + for layer_id in range(num_layers) + ]) + self.dense = torch.nn.ModuleList( + [RowParallelLinear( + hidden_size, + hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method, + bias=True, + module=self, + name="dense") + for layer_id in range(num_layers) + ]) + + self.n_head = n_head + self.window_size = window_size + self.frame_resolution = frame_resolution + self.frame_len = frame_resolution * frame_resolution + self.time_dim_attend_length = time_dim_attend_length + assert frame_resolution % window_size == 0 + assert 0 < shift_size < window_size + nW = (self.frame_resolution // self.window_size) ** 2 + ws_squre = self.window_size * self.window_size + + # odd non-shift, even shift + img_mask = torch.zeros((1, 1, frame_resolution, frame_resolution, 1)) + h_slices = (slice(0, -shift_size), + slice(-shift_size, None)) + w_slices = (slice(0, -shift_size), + slice(-shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, :, h, w, :] = cnt + cnt += 1 + mask_windows = window_partition(img_mask, self.window_size) # nW, 1, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + sub_attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) #[nW, self.window_size * self.window_size, self.window_size * self.window_size] + sub_attn_mask = sub_attn_mask.masked_fill(sub_attn_mask != 0, float(0.0)).masked_fill(sub_attn_mask == 0, float(1.00)) + attn_mask = sub_attn_mask.repeat(1, frame_num, frame_num) + attn_mask = attn_mask.tril() + + causal_mask = torch.ones(ws_squre*frame_num, ws_squre*frame_num) + causal_mask = causal_mask.tril() + + self.shift_sizes = [0, shift_size] + self.attn_mask = attn_mask + self.causal_mask = causal_mask + self.mask_initialized = False + + self.attn_distribution = torch.nn.ParameterList([ + torch.nn.Parameter(torch.zeros(hidden_size)) + for _ in range(num_layers) + ]) + + def reinit(self, *pre_mixins): + start_layer = len(self.transformer.layers) - self.num_layers + assert start_layer >= 0 + for layer_id in range(self.num_layers): + old_attention = self.transformer.layers[start_layer + layer_id].attention + self.query_key_value[layer_id].weight.data.copy_(old_attention.query_key_value.weight.data) + self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data) + + def attention_extra_NAR_inference(self, frame_hidden_state, layer_id, attn_dropout=None, memkv_text=None, stage=1): + # frame_hidden_state [batchsize, frame_num*frame_size, n_head*hiddensize_perhead] + if not self.mask_initialized: + self.attn_mask = self.attn_mask.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype) + self.causal_mask = self.causal_mask.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype) + self.mask_initialized = True + b0, s1, h0 = frame_hidden_state.shape + h = h0 // self.n_head + frame_len = self.frame_resolution * self.frame_resolution + frame_num = s1 // frame_len + if stage == 2: + assert frame_num == 3 + assert frame_num*frame_len == s1 + wind_square = self.window_size * self.window_size + nW = frame_len // wind_square + bswin = b0 * nW + + if memkv_text is not None: + s0 = memkv_text.shape[-2] + k_text = memkv_text[..., :h0].expand(b0, -1, -1).reshape(b0, s0, self.n_head, h).permute(0, 2, 1, 3) + v_text = memkv_text[..., h0:].expand(b0, -1, -1).reshape(b0, s0, self.n_head, h).permute(0, 2, 1, 3) + + # shift + frame_hidden_state = frame_hidden_state.reshape(b0, frame_num, self.frame_resolution, self.frame_resolution, h0) + if self.shift_sizes[layer_id%2] > 0: + frame_hidden_state = torch.roll(frame_hidden_state, shifts=(-self.shift_sizes[layer_id%2], -self.shift_sizes[layer_id%2]), dims=(2,3)) + # window partition + frame_hidden_state = window_partition(frame_hidden_state, self.window_size).reshape(bswin, frame_num*wind_square, h0) + qkv = self.query_key_value[layer_id](frame_hidden_state).reshape(bswin, frame_num*wind_square, 3, self.n_head, h)\ + .permute(2, 0, 3, 1, 4) #[3, bswin, n_head, frame_num*wind_size*wind_size, h] + q, k, v = qkv[0], qkv[1], qkv[2] + attn = torch.matmul(q / math.sqrt(h), k.transpose(-1, -2)) + + if stage == 1: + if self.shift_sizes[layer_id%2] > 0: + attn = torch.mul(attn.view(bswin // nW, nW, self.n_head, frame_num*wind_square, frame_num*wind_square), + self.attn_mask[:,:frame_num*wind_square, :frame_num*wind_square].unsqueeze(1).unsqueeze(0))\ + - 10000.0 * (1.0 - self.attn_mask[:,:frame_num*wind_square, :frame_num*wind_square].unsqueeze(1).unsqueeze(0)) + attn = attn.view(bswin, self.n_head, frame_num*wind_square, frame_num*wind_square) + else: + attn = torch.mul(attn, self.causal_mask[:frame_num*wind_square, :frame_num*wind_square].unsqueeze(0).unsqueeze(0))\ + - 10000.0 * (1.0 - self.causal_mask[:frame_num*wind_square, :frame_num*wind_square].unsqueeze(0).unsqueeze(0)) + + if memkv_text is None: + attn = F.softmax(attn, dim=-1) + if attn_dropout is not None: + with get_cuda_rng_tracker().fork(): + attn = attn_dropout(attn) + context_swin = torch.matmul(attn, v).permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0) + else: + attn_frame2text = torch.matmul(q.reshape(b0, -1, self.n_head, frame_num*wind_square, h) / math.sqrt(h), k_text.unsqueeze(1).transpose(-1, -2)) + attn_frame2text = attn_frame2text.reshape(bswin, self.n_head, frame_num*wind_square, s0) + attn = torch.cat((attn, attn_frame2text), dim=-1) + attn = F.softmax(attn, dim=-1) + + if attn_dropout is not None: + with get_cuda_rng_tracker().fork(): + attn = attn_dropout(attn) + + context_swin = (torch.matmul(attn[..., :-s0], v) + + torch.matmul(attn[..., -s0:].reshape(b0, -1, self.n_head,frame_num*wind_square, s0), v_text.unsqueeze(1))\ + .reshape(bswin, self.n_head, frame_num*wind_square, h))\ + .permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0) + + context_swin = window_reverse(context_swin, self.window_size, self.frame_resolution, self.frame_resolution) + + # reverse cycle shift + if self.shift_sizes[layer_id%2] > 0: + context_swin = torch.roll(context_swin, shifts=(self.shift_sizes[layer_id%2], self.shift_sizes[layer_id%2]), dims=(2,3)) + ret_context = context_swin.reshape(b0, s1, h0) + + # for mem + memk = k.permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0) + memv = v.permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0) + memk = window_reverse(memk, self.window_size, self.frame_resolution, self.frame_resolution) + memv = window_reverse(memv, self.window_size, self.frame_resolution, self.frame_resolution) + if self.shift_sizes[layer_id%2] > 0: + memk = torch.roll(memk, shifts=(self.shift_sizes[layer_id%2], self.shift_sizes[layer_id%2]), dims=(2,3)) + memv = torch.roll(memv, shifts=(self.shift_sizes[layer_id%2], self.shift_sizes[layer_id%2]), dims=(2,3)) + memk, memv = memk.reshape(b0, s1, h0), memv.reshape(b0, s1, h0) + + ret_mem = torch.cat((memk, memv), dim=-1) + return ret_context, ret_mem + + def attention_extra_AR_inference(self, frame_hidden_state, memkv, pos, layer_id, log_text_attention_weights=0, attn_dropout=None, memkv_text=None, stage=1): + # frame_hidden_state [batchsize, 1, n_head*hiddensize_perhead] + # memkv [batchsize, pos, hidden_size*2] (include frames only) + # if memkv_text is not None: will attend to text + # pos: token's pos + b0, sin, h0 = frame_hidden_state.shape + h = h0 // self.n_head + assert sin == 1 + this_qkv = self.query_key_value[layer_id](frame_hidden_state) + thisq, thisk, thisv = this_qkv[..., :h0], this_qkv[..., h0:2*h0], this_qkv[..., 2*h0:] + s1 = memkv.shape[1] if memkv is not None else 0 + frame_len = self.frame_resolution * self.frame_resolution + frame_num_before = s1 // frame_len + + + if memkv is not None: + pos_inframe = pos - frame_num_before * frame_len + + xpos = pos_inframe // self.frame_resolution # pos = xpos*self.frame_resolution + ypos + ypos = pos_inframe % self.frame_resolution + # [start, end) + if self.shift_sizes[layer_id%2] > 0: + xstart = ((xpos+self.shift_sizes[layer_id%2]) // self.window_size) * self.window_size - self.shift_sizes[layer_id%2] + ystart = ((ypos+self.shift_sizes[layer_id%2]) // self.window_size) * self.window_size - self.shift_sizes[layer_id%2] + xend = xstart + self.window_size + yend = ystart + self.window_size + xstart, ystart = max(0, xstart), max(0, ystart) + xend, yend = min(xend, self.frame_resolution), min(yend, self.frame_resolution) + else: + xstart = (xpos // self.window_size) * self.window_size + ystart = (ypos // self.window_size) * self.window_size + xend, yend = xstart + self.window_size, ystart+self.window_size + + # select index + selected_index = list() + if frame_num_before > 0: + # frames before + frame_attended_start = max(0, frame_num_before-self.time_dim_attend_length+1) if self.time_dim_attend_length > 0 else 0 + for x in range(xstart, xend): + for y in range(ystart, yend): + selected_index.append(x*self.frame_resolution+y+frame_len*frame_attended_start) + cnt_per_frame = len(selected_index) + for _ in range((frame_num_before-frame_attended_start-1)*cnt_per_frame): + selected_index.append(selected_index[-cnt_per_frame]+frame_len) + + # the last frame + for x in range(xstart, xend): + for y in range(ystart, yend): + tmppos = x*self.frame_resolution+y + frame_num_before * frame_len + if tmppos < pos: + selected_index.append(tmppos) + else: + break + cnt_all = len(selected_index)+1 + selected_index = torch.tensor(selected_index, device=memkv.device) + used_memkv = torch.index_select(memkv, 1, selected_index) + used_k, used_v = used_memkv[..., :h0], used_memkv[..., h0:] + used_k = torch.cat((used_k.expand(thisk.shape[0], -1, -1), thisk), dim=-2) + used_v = torch.cat((used_v.expand(thisv.shape[0], -1, -1), thisv), dim=-2) + if memkv_text is not None: + cnt_all += memkv_text.shape[-2] + used_k = torch.cat((memkv_text[..., :h0].expand(thisk.shape[0], -1, -1), used_k), dim=-2) + used_v = torch.cat((memkv_text[..., h0:].expand(thisv.shape[0], -1, -1), used_v), dim=-2) + used_k = used_k.reshape(b0, cnt_all, self.n_head, h).permute(0, 2, 1, 3) + used_v = used_v.reshape(b0, cnt_all, self.n_head, h).permute(0, 2, 1, 3) + else: + used_k = thisk + used_v = thisv + + if memkv_text is not None: + used_k = torch.cat((memkv_text[..., :h0].expand(thisk.shape[0], -1, -1), used_k), dim=-2) + used_v = torch.cat((memkv_text[..., h0:].expand(thisv.shape[0], -1, -1), used_v), dim=-2) + used_k = used_k.reshape(b0, 1+memkv_text.shape[-2], self.n_head, h).permute(0, 2, 1, 3) + used_v = used_v.reshape(b0, 1+memkv_text.shape[-2], self.n_head, h).permute(0, 2, 1, 3) + else: + used_k = used_k.reshape(b0, 1, self.n_head, h).permute(0, 2, 1, 3) + used_v = used_v.reshape(b0, 1, self.n_head, h).permute(0, 2, 1, 3) + + thisq = thisq.reshape(b0, 1, self.n_head, h).permute(0, 2, 1, 3) # [b0, n_head, 1, h] + attn = torch.matmul(thisq / math.sqrt(h), used_k.transpose(-1, -2)) + if memkv_text is not None: + attn[..., :memkv_text.shape[-2]] += log_text_attention_weights + attn = F.softmax(attn, dim=-1) + context_swin = torch.matmul(attn, used_v).permute(0, 2, 1, 3).reshape(b0, 1, h0) + + return context_swin, this_qkv[..., h0:] + +class FullAttentionMixin(BaseMixin): + def __init__(self, num_layers, + hidden_size, + frame_resolution, + n_head, + frame_num, + init_method=unscaled_init_method(0.02), + output_layer_init_method=unscaled_init_method(0.02), + **kwargs, + ): + super(FullAttentionMixin, self).__init__() + self.num_layers = num_layers # replace attention in the LAST n layers + self.query_key_value = torch.nn.ModuleList( + [ColumnParallelLinear(hidden_size, 3*hidden_size,stride=3, + gather_output=False,init_method=init_method) + for layer_id in range(num_layers) + ]) + self.dense = torch.nn.ModuleList( + [RowParallelLinear( + hidden_size, + hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method, + bias=True, + module=self, + name="dense") + for layer_id in range(num_layers) + ]) + + self.n_head = n_head + self.frame_resolution = frame_resolution + self.frame_len = frame_resolution * frame_resolution + + self.attn_distribution = torch.nn.ParameterList([ + torch.nn.Parameter(torch.zeros(hidden_size)) + for _ in range(num_layers) + ]) + + def reinit(self, *pre_mixins): + start_layer = len(self.transformer.layers) - self.num_layers + assert start_layer >= 0 + for layer_id in range(self.num_layers): + old_attention = self.transformer.layers[start_layer + layer_id].attention + self.query_key_value[layer_id].weight.data.copy_(old_attention.query_key_value.weight.data) + self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data) + + + def attention_extra_NAR_inference(self, frame_hidden_state, layer_id, attn_dropout=None, memkv_text=None, stage=1): + # frame_hidden_state [batchsize, frame_num*frame_size, n_head*hiddensize_perhead] + assert stage == 1 + + b0, s1, h0 = frame_hidden_state.shape + h = h0 // self.n_head + frame_len = self.frame_resolution * self.frame_resolution + frame_num = s1 // frame_len + assert frame_num*frame_len == s1 + + if memkv_text is not None: + s0 = memkv_text.shape[-2] + k_text = memkv_text[..., :h0].expand(b0, -1, -1).reshape(b0, s0, self.n_head, h).permute(0, 2, 1, 3) + v_text = memkv_text[..., h0:].expand(b0, -1, -1).reshape(b0, s0, self.n_head, h).permute(0, 2, 1, 3) + qkv = self.query_key_value[layer_id](frame_hidden_state).reshape(b0, s1, 3, self.n_head, h)\ + .permute(2, 0, 3, 1, 4) #[3, b0, n_head, s1, h] + q, k, v = qkv[0], qkv[1], qkv[2] + attn = torch.matmul(q / math.sqrt(h), k.transpose(-1, -2)) + attn = attn - 10000.0 * (1.0-torch.ones(b0, self.n_head, s1, s1, device=attn.device, dtype=attn.dtype).tril()) + + if memkv_text is None: + attn = F.softmax(attn, dim=-1) + if attn_dropout is not None: + with get_cuda_rng_tracker().fork(): + attn = attn_dropout(attn) + context_swin = torch.matmul(attn, v).permute(0, 2, 1, 3).reshape(b0, s1, h0) + else: + attn_frame2text = torch.matmul(q / math.sqrt(h), k_text.transpose(-1, -2)) #[b0, s1, s0] + attn = torch.cat((attn, attn_frame2text), dim=-1) + attn = F.softmax(attn, dim=-1) + if attn_dropout is not None: + with get_cuda_rng_tracker().fork(): + attn = attn_dropout(attn) + context_swin = (torch.matmul(attn[..., :-s0], v) + torch.matmul(attn[..., -s0:], v_text))\ + .permute(0, 2, 1, 3).reshape(b0, s1, h0) + + # for mem + memk = k.permute(0, 2, 1, 3).reshape(b0, s1, h0) + memv = v.permute(0, 2, 1, 3).reshape(b0, s1, h0) + ret_mem = torch.cat((memk, memv), dim=-1) + + return context_swin, ret_mem + + def attention_extra_AR_inference(self, frame_hidden_state, memkv, pos, layer_id, log_text_attention_weights=0, attn_dropout=None, memkv_text=None, stage=1): + # pos: current token's pos + b0, sin, h0 = frame_hidden_state.shape + h = h0 // self.n_head + assert sin == 1 + assert stage == 1 + + this_qkv = self.query_key_value[layer_id](frame_hidden_state) + thisq, thisk, thisv = this_qkv[..., :h0], this_qkv[..., h0:2*h0], this_qkv[..., 2*h0:] + + if memkv is not None: + used_k, used_v = memkv[..., :h0], memkv[..., h0:] + used_k = torch.cat((used_k.expand(thisk.shape[0], -1, -1), thisk), dim=-2) + used_v = torch.cat((used_v.expand(thisv.shape[0], -1, -1), thisv), dim=-2) + else: + used_k, used_v = thisk, thisv + + if memkv_text is not None: + used_k = torch.cat((memkv_text[..., :h0].expand(thisk.shape[0], -1, -1), used_k), dim=-2) + used_v = torch.cat((memkv_text[..., h0:].expand(thisv.shape[0], -1, -1), used_v), dim=-2) + + used_k = used_k.reshape(b0, -1, self.n_head, h).permute(0, 2, 1, 3) + used_v = used_v.reshape(b0, -1, self.n_head, h).permute(0, 2, 1, 3) + thisq = thisq.reshape(b0, 1, self.n_head, h).permute(0, 2, 1, 3) # [b0, n_head, 1, h] + attn = torch.matmul(thisq / math.sqrt(h), used_k.transpose(-1, -2)) + if memkv_text is not None: + attn[..., :memkv_text.shape[-2]] += log_text_attention_weights + attn = F.softmax(attn, dim=-1) + + context_swin = torch.matmul(attn, used_v).permute(0, 2, 1, 3).reshape(b0, 1, h0) + + return context_swin, this_qkv[..., h0:] + + +def attention_localframe_and_text_NAR(q0, k0, v0, attention_mask, + n_head, text_len, frame_len, frame_num, + attention_dropout=None, log_text_attention_weights=0, stage=1, **kwargs): + b, s0, h0 = q0.shape + s1 = s0 - text_len + h = h0 // n_head + assert q0.shape[1] == v0.shape[1] == k0.shape[1] == text_len+frame_len*frame_num + # attention_mask.shape [4, b or 1, 1, text_len+frame_len, text_len+frame_len] + if stage == 2: + assert frame_num == 3 + + q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + k0 = k0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + k0T = k0.transpose(-1, -2) + + score_any2text = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T[..., :text_len]) + score_any2text += log_text_attention_weights + score_any2text_part1 = torch.mul(score_any2text[..., :text_len, :], attention_mask[..., :text_len, :text_len]) \ + - 10000.0 * (1.0 - attention_mask[..., :text_len, :text_len]) + # context for text + attention_probs_text = F.softmax(score_any2text_part1, dim=-1) + if attention_dropout is not None: + with get_cuda_rng_tracker().fork(): + attention_probs_text = attention_dropout(attention_probs_text) + context_text2text = torch.matmul(attention_probs_text, v0[..., :text_len, :]) + context_text2text = context_text2text.transpose(1, 2).reshape(b, text_len, h0) + + if frame_num > 0: + score_any2text_part2 = score_any2text[..., text_len:, :] + + # score: frame local + q0_frame = q0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h) + v0_frame = v0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h) + k0T_frame = k0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h).transpose(-1, -2) + score_frame_local0 = torch.matmul(q0_frame / math.sqrt(q0_frame.shape[-1]), k0T_frame) + if stage == 1: + score_frame_local0 = torch.mul(score_frame_local0, attention_mask[..., text_len:, text_len:].unsqueeze(1)) \ + - 10000.0 * (1.0 - attention_mask[..., text_len:, text_len:].unsqueeze(1)) + + # context for frame + score_frame_all = torch.cat((score_any2text_part2, + score_frame_local0.view(b, n_head, s1, frame_len)), dim=-1) + attention_probs_frame = F.softmax(score_frame_all, dim=-1) + if attention_dropout is not None: + with get_cuda_rng_tracker().fork(): + attention_probs_frame = attention_dropout(attention_probs_frame) + context_frame2text = torch.matmul(attention_probs_frame[..., :text_len], v0[..., :text_len, :]) # [b, n_head, s1, h] + context_frame_local0 = torch.matmul(attention_probs_frame[..., text_len:text_len+frame_len].\ + view(b, n_head, frame_num, frame_len, frame_len), v0_frame).view(b, n_head, s1, h) + + context_frame = (context_frame2text + context_frame_local0).transpose(1, 2).reshape(b, s1, h0) + else: + context_frame = None + + return context_text2text, context_frame + +def attention_localframe_and_text_AR(q0, k0, v0, n_head, text_len, frame_len, frame_num, + attention_dropout=None, log_text_attention_weights=0, layer_id=None, limited_spatial_channel_mem=False, stage=1, **kwargs): + # limited_spatial_channel_mem=True means: mems in spatial channel is consisted of {mem_text, mem_current_frame} + b, s0, h0 = k0.shape + frame_num_before = (s0-text_len-1) // frame_len # frame_num == frame_num_before or frame_num == frame_num_before+1 + h = h0 // n_head + assert q0.shape[1] == 1 + assert v0.shape[1] == k0.shape[1] + + q0 = q0.reshape(b, 1, n_head, h).permute(0, 2, 1, 3) + v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + k0T = k0.reshape(b, s0, n_head, h).permute(0, 2, 3, 1) + + if limited_spatial_channel_mem: + assert frame_num_before == 0 + assert stage == 1 # not implemented for stage-2 yet + score = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T) + score[..., :text_len] += log_text_attention_weights + attention_probs_frame = F.softmax(score, dim=-1) + context_frame = torch.matmul(attention_probs_frame, v0).transpose(1, 2).reshape(b, 1, h0) + + else: + score_token2text = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T[..., :text_len]) + score_token2text += log_text_attention_weights + score_frame_local0 = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T[..., text_len+frame_num_before*frame_len:]) + score_frame_all = torch.cat((score_token2text, + score_frame_local0), dim=-1) + attention_probs_frame = F.softmax(score_frame_all, dim=-1) + + context_token2text = torch.matmul(attention_probs_frame[..., :text_len], v0[..., :text_len, :]) # [b, n_head, s1, h] + context_frame_local0 = torch.matmul(attention_probs_frame[..., text_len:], \ + v0[:, :, text_len+frame_num_before*frame_len:, :]) + context_frame = (context_token2text + context_frame_local0).transpose(1, 2).reshape(b, 1, h0) + + return context_frame + + +class CogVideoCacheModel(BaseModel): + def __init__(self, args, transformer=None, parallel_output=True, window_size=None, cogvideo_stage=None): + super().__init__(args, transformer=transformer, parallel_output=parallel_output) + self.layout = args.layout # [64, 64+1024, 64+6*1024] + self.stage = cogvideo_stage if cogvideo_stage is not None else args.cogvideo_stage # 1 or 2 + self.n_head = args.num_attention_heads + self.window_size = window_size if window_size is not None else args.window_size + + frame_resolution = int(math.sqrt(self.layout[1]-self.layout[0])) + self.add_mixin('extra_position_embedding', PositionEmbeddingMixin( + args.additional_seqlen, args.hidden_size + )) + + if self.stage == 1: + self.add_mixin('attention_plus', FullAttentionMixin( + num_layers=args.num_layers, + hidden_size=args.hidden_size, + frame_resolution=frame_resolution, + n_head=args.num_attention_heads, + frame_num=(args.layout[2]-args.layout[0])//(args.layout[1]-args.layout[0]), + )) + else: + self.add_mixin('attention_plus', WindowAttentionMixin( + num_layers=args.num_layers, + hidden_size=args.hidden_size, + frame_resolution=frame_resolution, + window_size=self.window_size, + shift_size=self.window_size//2, + n_head=args.num_attention_heads, + frame_num=(args.layout[2]-args.layout[0])//(args.layout[1]-args.layout[0]), + )) + + + @classmethod + def add_model_specific_args(cls, parser): + group = parser.add_argument_group('VideoSwinLocalModel', 'video swin local model configurations') + group.add_argument("--layout", type=str, default='64, 464, 2064') + group.add_argument("--window-size", type=int, default=10) # 优先级在直接参数赋值之后 + group.add_argument("--additional-seqlen", type=int, default=2000) + group.add_argument("--cogvideo-stage", type=int, default=1, choices=[1,2]) # 优先级在直接参数赋值之后 + return parser + + def disable_untrainable_params(self): + pass + + def position_embedding_forward(self, position_ids, **kw_args): + if position_ids.shape[-1] > 1: + if self.stage == 1: + if position_ids[0,-1] >= (512+400): + frame_num = position_ids.shape[-1] // 400 + position_embeddings = torch.cat( + ( + self.transformer.position_embeddings(position_ids[..., :-400*(frame_num-1)]), + self.get_mixin('extra_position_embedding').position_embeddings(position_ids[..., -400*(frame_num-1):]-(512+400)) + ), + dim=-2 + ) + else: + position_embeddings = self.transformer.position_embeddings(position_ids) + else: + # given 3, interpolate 2 + position_embeddings = torch.cat( + ( + self.transformer.position_embeddings(position_ids[..., :-800]), + self.get_mixin('extra_position_embedding').position_embeddings(position_ids[..., -800:]-(512+400)) + ), + dim=-2 + ) + else: + if position_ids[0, 0] >= (512+400): + position_embeddings = self.get_mixin('extra_position_embedding').position_embeddings(position_ids-(512+400)) + else: + position_embeddings = self.transformer.position_embeddings(position_ids) + return position_embeddings + + def attention_forward(self, hidden_states, mask, layer_id, mems=None, log_text_attention_weights=0, text_len=0, frame_len=0, counter=0, enforce_no_swin=False, limited_spatial_channel_mem=False, **kw_args): + attn_module = self.transformer.layers[layer_id].attention + hidden_size = hidden_states.shape[-1] + + # base model qkv + if mems is None: + mixed_raw_layer = attn_module.query_key_value(hidden_states) + q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer, 3) + assert (q0.shape[1]-text_len) % frame_len == 0 + memkv0 = torch.cat((k0, v0), dim=-1) + context_text, context_frame_local_text = attention_localframe_and_text_NAR( + q0, k0, v0, + mask, + n_head=attn_module.num_attention_heads_per_partition, + text_len=text_len, + frame_len=frame_len, + frame_num=(q0.shape[1]-text_len)//frame_len, + log_text_attention_weights=log_text_attention_weights, + stage=self.stage + ) + + # change: self.swin_attend_to_text默认为True: + memkv1_text = self.get_mixin('attention_plus').query_key_value[layer_id](hidden_states[..., :text_len, :])[..., hidden_size:] + output_text = attn_module.dense(context_text) + + if (q0.shape[1]-text_len)//frame_len > 0: + assert (q0.shape[1]-text_len) % frame_len == 0 + context_frame_swin, memkv1_frame = self.get_mixin('attention_plus').attention_extra_NAR_inference( + hidden_states[:,text_len:], layer_id, memkv_text=memkv1_text, stage=self.stage) + if not enforce_no_swin: + attn_distrib = torch.sigmoid(self.get_mixin('attention_plus').attn_distribution[layer_id]) + attn_distrib = attn_distrib.unsqueeze(0).unsqueeze(0) + output_frame = torch.mul(attn_module.dense(context_frame_local_text), attn_distrib)\ + +torch.mul(self.get_mixin('attention_plus').dense[layer_id](context_frame_swin), 1-attn_distrib) + else: + output_frame = attn_module.dense(context_frame_local_text[..., :frame_len, :]) + output = torch.cat((output_text, output_frame), dim=-2) + memkv1 = torch.cat((memkv1_text, memkv1_frame), dim=-2) if memkv1_text is not None else memkv1_frame + else: + output = output_text + memkv1 = memkv1_text + kw_args['output_this_layer']['mem_kv'] = (memkv0, memkv1) + + + else: + mixed_raw_layer = attn_module.query_key_value(hidden_states) + q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer, 3) + new_memkv0 = torch.cat((k0, v0), dim=-1) + old_k0, old_v0 = mems[0][layer_id][..., :hidden_size], mems[0][layer_id][..., hidden_size:] + + context_frame_local_text = attention_localframe_and_text_AR( + q0, + torch.cat((old_k0.expand(k0.shape[0], -1, -1), k0), dim=-2), + torch.cat((old_v0.expand(v0.shape[0], -1, -1), v0), dim=-2), + n_head=attn_module.num_attention_heads_per_partition, + text_len=text_len, + frame_len=frame_len, + frame_num=None, + log_text_attention_weights=log_text_attention_weights, + layer_id=layer_id, + limited_spatial_channel_mem=limited_spatial_channel_mem, + ) + + old_memkv1 = mems[1][layer_id] if mems[1] is not None else None + + context_frame_swin, new_memkv1 = self.get_mixin('attention_plus').attention_extra_AR_inference(hidden_states, + old_memkv1[..., text_len:, :] if old_memkv1.shape[-2]>text_len else None, + counter-text_len, + layer_id, + memkv_text=old_memkv1[..., :text_len, :], + log_text_attention_weights=log_text_attention_weights) + if not enforce_no_swin: + attn_distrib = torch.sigmoid(self.get_mixin('attention_plus').attn_distribution[layer_id]) + attn_distrib = attn_distrib.unsqueeze(0).unsqueeze(0) + output = torch.mul(attn_module.dense(context_frame_local_text), attn_distrib)\ + +torch.mul(self.get_mixin('attention_plus').dense[layer_id](context_frame_swin), 1-attn_distrib) + else: + output = attn_module.dense(context_frame_local_text) + + kw_args['output_this_layer']['mem_kv'] = (new_memkv0, new_memkv1) + + return output \ No newline at end of file diff --git a/src/videogen_hub/pipelines/cogvideo/cogvideo_src/models/cogvideo_model.py b/src/videogen_hub/pipelines/cogvideo/cogvideo_src/models/cogvideo_model.py new file mode 100644 index 0000000000000000000000000000000000000000..dfbc136161f922de6a420cddaa2b1c3287a7eea8 --- /dev/null +++ b/src/videogen_hub/pipelines/cogvideo/cogvideo_src/models/cogvideo_model.py @@ -0,0 +1,543 @@ +# -*- encoding: utf-8 -*- +''' +@File : cogvideo_model.py +@Time : 2022/07/11 16:12:05 +@Author : Wenyi Hong +@Version : 1.0 +@Contact : hwy22@mails.tsinghua.edu.cn +''' + +# here put the import lib + +import torch +from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin + +from SwissArmyTransformer.mpu.utils import split_tensor_along_last_dim +from SwissArmyTransformer.model.transformer import unscaled_init_method +from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear +import torch.nn.functional as F +from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker +import math + +class PositionEmbeddingMixin(BaseMixin): + def __init__(self, additional_sequence_length, hidden_size, + init_method_std=0.02, reinit_slice=slice(512, 912), + ): + super(PositionEmbeddingMixin, self).__init__() + self.reinit_slice = reinit_slice + self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size) + torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std) + + def reinit(self, parent_model=None): + old_weights = self.transformer.position_embeddings.weight.data[self.reinit_slice] + old_len, hidden_size = old_weights.shape + assert hidden_size == self.position_embeddings.weight.shape[-1] + self.position_embeddings.weight.data.view(-1, old_len, hidden_size).copy_(old_weights) + +def window_partition(x, window_size): + """ + Args: + x: (B, framenum, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, frame_num, window_size, window_size, C) + """ + B, framenum, H, W, C = x.shape + x = x.view(B, framenum, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(-1, framenum, window_size, window_size, C) + return windows + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, frame_num, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, frame_num, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + framenum = windows.shape[1] + x = windows.view(B, H // window_size, W // window_size, framenum, window_size, window_size, -1) + x = x.permute(0, 3, 1, 4, 2, 5, 6).contiguous().view(B, framenum, H, W, -1) + return x + +class WindowAttentionMixin(BaseMixin): + def __init__(self, num_layers, + hidden_size, + frame_resolution, + window_size, + shift_size, + n_head, + frame_num, + init_method=unscaled_init_method(0.02), + output_layer_init_method=unscaled_init_method(0.02), + ): + super(WindowAttentionMixin, self).__init__() + self.num_layers = num_layers # replace attention in the LAST n layers + self.query_key_value = torch.nn.ModuleList( + [ColumnParallelLinear(hidden_size, 3*hidden_size,stride=3, + gather_output=False,init_method=init_method) + for layer_id in range(num_layers) + ]) + self.dense = torch.nn.ModuleList( + [RowParallelLinear( + hidden_size, + hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method, + bias=True, + module=self, + name="dense", + ) + for layer_id in range(num_layers) + ]) + + self.n_head = n_head + self.window_size = window_size + self.frame_resolution = frame_resolution + self.frame_len = frame_resolution * frame_resolution + assert frame_resolution % window_size == 0 + assert 0 < shift_size < window_size + nW = (self.frame_resolution // self.window_size) ** 2 + ws_squre = self.window_size * self.window_size + + # odd non-shift, even shift + img_mask = torch.zeros((1, 1, frame_resolution, frame_resolution, 1)) + h_slices = (slice(0, -shift_size), + slice(-shift_size, None)) + w_slices = (slice(0, -shift_size), + slice(-shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, :, h, w, :] = cnt + cnt += 1 + mask_windows = window_partition(img_mask, self.window_size) # nW, 1, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + sub_attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) #[nW, self.window_size * self.window_size, self.window_size * self.window_size] + sub_attn_mask = sub_attn_mask.masked_fill(sub_attn_mask != 0, float(0.0)).masked_fill(sub_attn_mask == 0, float(1.00)) + attn_mask = sub_attn_mask.repeat(1, frame_num, frame_num) + + self.attn_mask_sequential = attn_mask.clone().tril() + self.causal_mask_sequential = torch.ones(1, ws_squre*frame_num, ws_squre*frame_num).tril() + + self.causal_mask_interp = torch.ones(1, ws_squre*frame_num, ws_squre*frame_num) + self.attn_mask_interp = attn_mask.clone() + + # bi-dir + for bi_idx in range(0, frame_num, 2): + for uni_idx in range(1, frame_num, 2): + self.attn_mask_interp[:, bi_idx*ws_squre:(bi_idx+1)*ws_squre, uni_idx*ws_squre:(uni_idx+1)*ws_squre] = 0 + self.causal_mask_interp[:, bi_idx*ws_squre:(bi_idx+1)*ws_squre, uni_idx*ws_squre:(uni_idx+1)*ws_squre] = 0 + # uni-dir + for uni_idx in range(1, frame_num, 2): + self.attn_mask_interp[:, ws_squre*uni_idx:ws_squre*(uni_idx+1), ws_squre*uni_idx:ws_squre*(uni_idx+1)].tril_() + self.causal_mask_interp[:, ws_squre*uni_idx:ws_squre*(uni_idx+1), ws_squre*uni_idx:ws_squre*(uni_idx+1)].tril_() + for uni_idx2 in range(uni_idx+2, frame_num, 2): + self.attn_mask_interp[:, ws_squre*uni_idx:ws_squre*(uni_idx+1), ws_squre*uni_idx2:ws_squre*(uni_idx2+1)] = 0 + self.causal_mask_interp[:, ws_squre*uni_idx:ws_squre*(uni_idx+1), ws_squre*uni_idx2:ws_squre*(uni_idx2+1)] = 0 + + # expand dim + self.attn_mask_sequential = self.attn_mask_sequential[None, None, :, None] + self.attn_mask_interp = self.attn_mask_interp[None, None, :, None] + self.causal_mask_sequential = self.causal_mask_sequential[None, None, :, None] + self.causal_mask_interp = self.causal_mask_interp[None, None, :, None] + + self.shift_sizes = [0, shift_size] + # self.register_buffer("attn_mask", attn_mask) + # self.register_buffer("causal_mask", causal_mask) + self.mask_initialized = False + + self.attn_distribution = torch.nn.ParameterList([ + torch.nn.Parameter(torch.zeros(hidden_size)) + for _ in range(num_layers) + ]) + + def reinit(self, *pre_mixins): + start_layer = len(self.transformer.layers) - self.num_layers + assert start_layer >= 0 + for layer_id in range(self.num_layers): + old_attention = self.transformer.layers[start_layer + layer_id].attention + self.query_key_value[layer_id].weight.data.copy_(old_attention.query_key_value.weight.data) + self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data) + + def attention_extra(self, frame_hidden_state, layer_id, attn_dropout, text_hidden_state=None, + text_attn_mask=None, mode_sequential=True): + # pb relax + swin_pb_relax = True + alpha = 16 + + # frame_hidden_state [batchsize, frame_num*frame_size, n_head*hiddensize_perhead] + if not self.mask_initialized: + self.attn_mask_sequential = self.attn_mask_sequential.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype) + self.causal_mask_sequential = self.causal_mask_sequential.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype) + self.attn_mask_interp = self.attn_mask_interp.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype) + self.causal_mask_interp = self.causal_mask_interp.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype) + self.mask_initialized = True + b0, s1, h0 = frame_hidden_state.shape + h = h0 // self.n_head + frame_len = self.frame_resolution * self.frame_resolution + frame_num = s1 // frame_len + assert frame_num*frame_len == s1 + wind_square = self.window_size * self.window_size + nW = frame_len // wind_square + bswin = b0 * nW + + causal_mask = self.causal_mask_sequential if mode_sequential else self.causal_mask_interp + attn_mask = self.attn_mask_sequential if mode_sequential else self.attn_mask_interp + if text_hidden_state is not None: + s0 = text_hidden_state.shape[1] + qkv_text = self.query_key_value[layer_id](text_hidden_state).reshape(b0, s0, 3, self.n_head, h).permute(2, 0, 3, 1, 4) #[3, b0, n_head, s0, h] + q_text, k_text, v_text = qkv_text[0], qkv_text[1], qkv_text[2] + + # shift + frame_hidden_state = frame_hidden_state.reshape(b0, frame_num, self.frame_resolution, self.frame_resolution, h0) + if self.shift_sizes[layer_id%2] > 0: + frame_hidden_state = torch.roll(frame_hidden_state, shifts=(-self.shift_sizes[layer_id%2], -self.shift_sizes[layer_id%2]), dims=(2,3)) + # window partition + frame_hidden_state = window_partition(frame_hidden_state, self.window_size).reshape(bswin, frame_num*wind_square, h0) + qkv = self.query_key_value[layer_id](frame_hidden_state).reshape(bswin, frame_num*wind_square, 3, self.n_head, h)\ + .permute(2, 0, 3, 1, 4) #[3, bswin, n_head, frame_num*wind_size*wind_size, h] + q, k, v = qkv[0], qkv[1], qkv[2] + + # pb-relax + if swin_pb_relax: + attn = torch.matmul(q / (math.sqrt(h)*alpha), k.transpose(-1, -2)) + else: + attn = torch.matmul(q / math.sqrt(h), k.transpose(-1, -2)) + + if self.shift_sizes[layer_id%2] > 0: + # attn = attn.view(bswin // nW, nW, self.n_head, frame_num*wind_square, frame_num*wind_square) + self.attn_mask.unsqueeze(1).unsqueeze(0) + attn = torch.mul(attn.view(bswin // nW, nW, self.n_head, frame_num*wind_square, frame_num*wind_square), attn_mask)\ + - 10000.0 * (1.0 - attn_mask) + attn = attn.view(bswin, self.n_head, frame_num*wind_square, frame_num*wind_square) + else: + attn = torch.mul(attn.view(bswin // nW, nW, self.n_head, frame_num*wind_square, frame_num*wind_square), causal_mask)\ + - 10000.0 * (1.0 - causal_mask) + attn = attn.view(bswin, self.n_head, frame_num*wind_square, frame_num*wind_square) + if swin_pb_relax: + swin_pb_relax_const = torch.max(attn.reshape(bswin, self.n_head, -1), dim=-1, keepdim=True)[0].detach().unsqueeze(-1) + attn = (attn - swin_pb_relax_const)*alpha + + if text_hidden_state is None: + attn = F.softmax(attn, dim=-1) + if attn_dropout is not None: + with get_cuda_rng_tracker().fork(): + attn = attn_dropout(attn) + context_swin = torch.matmul(attn, v).permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0) + else: + assert text_attn_mask is not None + text_attn_mask = text_attn_mask.unsqueeze(2).unsqueeze(2) + # pb-relax + if swin_pb_relax: + attn_frame2text = torch.matmul(q.reshape(b0, -1, self.n_head, frame_num*wind_square, h) / (math.sqrt(h)*alpha), k_text.unsqueeze(1).transpose(-1, -2)) + attn_frame2text = (attn_frame2text-swin_pb_relax_const.reshape(b0, -1, self.n_head, 1, 1))*alpha + else: + attn_frame2text = torch.matmul(q.reshape(b0, -1, self.n_head, frame_num*wind_square, h) / math.sqrt(h), k_text.unsqueeze(1).transpose(-1, -2)) + + attn_frame2text = torch.mul(text_attn_mask, attn_frame2text) - 10000.0 * (1.0 - text_attn_mask) + attn_frame2text = attn_frame2text.reshape(bswin, self.n_head, frame_num*wind_square, s0) + attn = torch.cat((attn, attn_frame2text), dim=-1) + attn = F.softmax(attn, dim=-1) + + if attn_dropout is not None: + with get_cuda_rng_tracker().fork(): + attn = attn_dropout(attn) + + context_swin = (torch.matmul(attn[..., :-s0], v) + + torch.matmul(attn[..., -s0:].reshape(b0, -1, self.n_head,frame_num*wind_square, s0), v_text.unsqueeze(1))\ + .reshape(bswin, self.n_head, frame_num*wind_square, h))\ + .permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0) + + context_swin = window_reverse(context_swin, self.window_size, self.frame_resolution, self.frame_resolution) + # reverse cycle shift + if self.shift_sizes[layer_id%2] > 0: + context_swin = torch.roll(context_swin, shifts=(self.shift_sizes[layer_id%2], self.shift_sizes[layer_id%2]), dims=(2,3)) + context_swin = context_swin.reshape(b0, s1, h0) + + return context_swin + + +class FullAttentionMixin(BaseMixin): + def __init__(self, num_layers, + hidden_size, + frame_resolution, + n_head, + frame_num, + init_method=unscaled_init_method(0.02), + output_layer_init_method=unscaled_init_method(0.02), + ): + super(FullAttentionMixin, self).__init__() + self.num_layers = num_layers # replace attention in the LAST n layers + self.query_key_value = torch.nn.ModuleList( + [ColumnParallelLinear(hidden_size, 3*hidden_size,stride=3, + gather_output=False,init_method=init_method) + for layer_id in range(num_layers) + ]) + self.dense = torch.nn.ModuleList( + [RowParallelLinear( + hidden_size, + hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method, + bias=True, + module=self, + name="dense",) + for layer_id in range(num_layers) + ]) + + self.n_head = n_head + self.frame_resolution = frame_resolution + self.frame_len = frame_resolution * frame_resolution + self.causal_mask = torch.ones(1, 1, self.frame_len*frame_num, self.frame_len*frame_num).tril() + + self.mask_initialized = False + + self.attn_distribution = torch.nn.ParameterList([ + torch.nn.Parameter(torch.zeros(hidden_size)) + for _ in range(num_layers) + ]) + + def reinit(self, *pre_mixins): + start_layer = len(self.transformer.layers) - self.num_layers + assert start_layer >= 0 + for layer_id in range(self.num_layers): + base_attention = self.transformer.layers[start_layer + layer_id].attention + self.query_key_value[layer_id].weight.data.copy_(base_attention.query_key_value.weight.data) + self.query_key_value[layer_id].bias.data.copy_(base_attention.query_key_value.bias.data) + + def attention_extra(self, frame_hidden_state, layer_id, attn_dropout, text_hidden_state=None, + text_attn_mask=None, mode_sequential=False): + # pb relax + # frame_hidden_state [batchsize, frame_num*frame_size, n_head*hiddensize_perhead] + assert mode_sequential == True # only + swin_pb_relax = True + alpha = 16 + + if not self.mask_initialized: + self.causal_mask = self.causal_mask.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype) + self.mask_initialized = True + b0, s1, h0 = frame_hidden_state.shape + h = h0 // self.n_head + frame_len = self.frame_resolution * self.frame_resolution + frame_num = s1 // frame_len + assert frame_num*frame_len == s1 + + qkv = self.query_key_value[layer_id](frame_hidden_state).reshape(b0, s1, 3, self.n_head, h)\ + .permute(2, 0, 3, 1, 4) #[3, b0, n_head, s1, h] + q, k, v = qkv[0], qkv[1], qkv[2] + + # frames-to-frames + if swin_pb_relax: + attn = torch.matmul(q / (math.sqrt(h)*alpha), k.transpose(-1, -2)) + else: + attn = torch.matmul(q / math.sqrt(h), k.transpose(-1, -2)) + attn = torch.mul(attn, self.causal_mask) - 10000.0 * (1.0 - self.causal_mask) + if swin_pb_relax: + swin_pb_relax_const = torch.max(attn.reshape(b0, self.n_head, -1), dim=-1, keepdim=True)[0].detach().unsqueeze(-1) + attn = (attn - swin_pb_relax_const)*alpha + + if text_hidden_state is None: + attn = F.softmax(attn, dim=-1) + if attn_dropout is not None: + with get_cuda_rng_tracker().fork(): + attn = attn_dropout(attn) + context_swin = torch.matmul(attn, v).permute(0, 2, 1, 3).reshape(b0, s1, h0) + else: + # frame-to-text + assert text_attn_mask is not None + s0 = text_hidden_state.shape[1] + qkv_text = self.query_key_value[layer_id](text_hidden_state).reshape(b0, s0, 3, self.n_head, h).permute(2, 0, 3, 1, 4) #[3, b0, n_head, s0, h] + q_text, k_text, v_text = qkv_text[0], qkv_text[1], qkv_text[2] + text_attn_mask = text_attn_mask.unsqueeze(2) + if swin_pb_relax: + attn_frame2text = torch.matmul(q.reshape(b0, self.n_head, s1, h) / (math.sqrt(h)*alpha), k_text.transpose(-1, -2)) + attn_frame2text = (attn_frame2text-swin_pb_relax_const.reshape(b0, self.n_head, 1, 1))*alpha + else: + attn_frame2text = torch.matmul(q.reshape(b0, self.n_head, s1, h) / math.sqrt(h), k_text.transpose(-1, -2)) + attn_frame2text = torch.mul(text_attn_mask, attn_frame2text) - 10000.0 * (1.0 - text_attn_mask) + attn_frame2text = attn_frame2text.reshape(b0, self.n_head, s1, s0) + + attn = torch.cat((attn, attn_frame2text), dim=-1) + attn = F.softmax(attn, dim=-1) + + if attn_dropout is not None: + with get_cuda_rng_tracker().fork(): + attn = attn_dropout(attn) + + context_frame = (torch.matmul(attn[..., :-s0], v) + + torch.matmul(attn[..., -s0:].reshape(b0, self.n_head,s1, s0), v_text))\ + .permute(0, 2, 1, 3).reshape(b0, s1, h0) + + return context_frame + + +def attention_localframe_and_text(q0, k0, v0, attention_mask_totxt, attention_mask_local, + n_head, text_len, frame_len, frame_num, attention_dropout=None, layer_id=0, **kwargs): + b, s0, h0 = q0.shape + s1 = s0 - text_len + h = h0 // n_head + assert q0.shape[1] == v0.shape[1] == k0.shape[1] == text_len+frame_len*frame_num + # attention_mask_totxt [b, 1, 1, text_len] + # attention_mask_local [1, 1, frame_num, frame_len, frame_len] + # attention_mask: [1, 1, text_len+frame_len, text_len+frame_len] + + q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + k0 = k0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + k0T = k0.transpose(-1, -2) + + # score: any2text + score_any2text = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T[..., :text_len]) + score_any2text_part1 = torch.mul(score_any2text[..., :text_len, :], attention_mask_totxt) \ + - 10000.0 * (1.0 - attention_mask_totxt) + score_any2text_part2 = torch.mul(score_any2text[..., text_len:, :], attention_mask_totxt) - \ + 10000.0 * (1.0 - attention_mask_totxt) + + # score: frame local + q0_frame = q0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h) + v0_frame = v0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h) + k0T_frame = k0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h).transpose(-1, -2) + score_frame_local0 = torch.matmul(q0_frame / math.sqrt(q0_frame.shape[-1]), k0T_frame) + score_frame_local0 = torch.mul(score_frame_local0, attention_mask_local) \ + - 10000.0 * (1.0 - attention_mask_local) + + # context for frame + score_frame_all = torch.cat((score_any2text_part2, + score_frame_local0.view(b, n_head, s1, frame_len)), dim=-1) + attention_probs_frame = F.softmax(score_frame_all, dim=-1) + + if attention_dropout is not None: + with get_cuda_rng_tracker().fork(): + attention_probs_frame = attention_dropout(attention_probs_frame) + + context_frame2text = torch.matmul(attention_probs_frame[..., :text_len], v0[..., :text_len, :]) # [b, n_head, s1, h] + context_frame_local0 = torch.matmul(attention_probs_frame[..., text_len:text_len+frame_len].\ + view(b, n_head, frame_num, frame_len, frame_len), v0_frame).view(b, n_head, s1, h) + context_frame = (context_frame2text + context_frame_local0).transpose(1, 2).reshape(b, s1, h0) + + # context for text + attention_probs_text = F.softmax(score_any2text_part1, dim=-1) + if attention_dropout is not None: + with get_cuda_rng_tracker().fork(): + attention_probs_text = attention_dropout(attention_probs_text) + context_text2text = torch.matmul(attention_probs_text, v0[..., :text_len, :]) + context_text2text = context_text2text.transpose(1, 2).reshape(b, text_len, h0) + + return context_text2text, context_frame + + +class CogVideoModel(BaseModel): + def __init__(self, args, transformer=None, parallel_output=True): + super().__init__(args, transformer=transformer, parallel_output=parallel_output) + self.stage = args.cogvideo_stage # 1 or 2 + self.mode_sequential = True if self.stage==1 else False + self.layout = args.layout # [64, 64+400, 64+5*400] + self.n_head = args.num_attention_heads + frame_resolution = int(math.sqrt(self.layout[1]-self.layout[0])) + frame_num = (args.layout[2]-args.layout[0])//(args.layout[1]-args.layout[0]) + frame_len = self.layout[1]-self.layout[0] + + self.add_mixin('extra_position_embedding', PositionEmbeddingMixin( + args.additional_seqlen, args.hidden_size + )) + + if args.window_size == -1: + # full attention + assert self.stage == 1 + self.add_mixin('attention_plus', FullAttentionMixin( + num_layers=args.num_layers, + hidden_size=args.hidden_size, + frame_resolution=frame_resolution, + n_head=args.num_attention_heads, + frame_num=frame_num, + )) + else: + self.add_mixin('attention_plus', WindowAttentionMixin( + num_layers=args.num_layers, + hidden_size=args.hidden_size, + frame_resolution=frame_resolution, + window_size=args.window_size, + shift_size=args.window_size//2, + n_head=args.num_attention_heads, + frame_num=frame_num, + )) + # attention_mask_local + self.attention_mask_local_sequential = torch.ones(1, 1, frame_num, frame_len, frame_len).tril().unsqueeze(0) + self.attention_mask_local_interp = torch.ones(1, 1, frame_num, frame_len, frame_len) + + for idx in range(1, frame_num, 2): + self.attention_mask_local_interp[:, :, idx:idx+1].tril_() + self.attention_mask_local_interp = self.attention_mask_local_interp.unsqueeze(0) + self.mask_initialized = False + + @classmethod + def add_model_specific_args(cls, parser): + group = parser.add_argument_group('CogVideoModel', 'CogVideo model configurations') + group.add_argument("--layout", type=str, default='64, 464, 2064', help='text_len, textlen+frame_len, textlen+frame_len*frame_num') + group.add_argument("--window-size", type=int, default=10, help="swin attention's window size in temperal channel, -1 represents full attention") + group.add_argument("--additional-seqlen", type=int, default=2000) + group.add_argument("--cogvideo-stage", type=int, default=1, choices=[1,2]) + return parser + + def disable_untrainable_params(self): + self.transformer.requires_grad_(False) + + def position_embedding_forward(self, position_ids, **kw_args): + position = position_ids[..., :(64+400)] + position_plus = position_ids[..., (64+400):] + position_embeddings = torch.cat( + ( + self.transformer.position_embeddings(position), + self.get_mixin('extra_position_embedding').position_embeddings(position_plus-(512+400)) + ), + dim=-2 + ) + return position_embeddings + + def attention_forward(self, hidden_states, mask, layer_id, **kw_args): + # mask.shape=[bs, 1, 1, 64] + if not self.mask_initialized: + self.attention_mask_local_sequential = self.attention_mask_local_sequential.to(device=hidden_states.device, dtype=hidden_states.dtype) + self.attention_mask_local_interp = self.attention_mask_local_interp.to(device=hidden_states.device, dtype=hidden_states.dtype) + self.mask_initialized = True + + attn_module = self.transformer.layers[layer_id].attention + hidden_size = hidden_states.shape[-1] + bs = hidden_states.shape[0] + + # base model qkv + mixed_raw_layer = attn_module.query_key_value(hidden_states) + q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer, 3) + dropout_fn = self.transformer.layers[layer_id].attention.attention_dropout if self.training else None + + attention_mask_local = self.attention_mask_local_sequential if self.mode_sequential else self.attention_mask_local_interp + context_text, context_frame_local_text = attention_localframe_and_text( + q0, k0, v0, + attention_mask_totxt=mask, + attention_mask_local=attention_mask_local, + n_head=attn_module.num_attention_heads_per_partition, + text_len=self.layout[0], + frame_len=self.layout[1]-self.layout[0], + frame_num=(self.layout[2]-self.layout[0])//(self.layout[1]-self.layout[0]), + attention_dropout=dropout_fn, + layer_id=layer_id, + ) + + context_frame_swin = self.get_mixin('attention_plus').attention_extra( + hidden_states[:, self.layout[0]:], layer_id, dropout_fn, + text_hidden_state=hidden_states[:, :self.layout[0]], + text_attn_mask=mask[..., 0, :], + mode_sequential=self.mode_sequential) + + attn_distrib = torch.sigmoid(self.get_mixin('attention_plus').attn_distribution[layer_id]) + attn_distrib = attn_distrib.unsqueeze(0).unsqueeze(0) + + output_text = attn_module.dense(context_text) + output_frame = torch.mul(attn_module.dense(context_frame_local_text), attn_distrib)\ + +torch.mul(self.get_mixin('attention_plus').dense[layer_id](context_frame_swin), 1-attn_distrib) + output = torch.cat((output_text, output_frame), dim=-2) + + return output \ No newline at end of file diff --git a/src/videogen_hub/pipelines/cogvideo/cogvideo_src/pretrain_cogvideo.py b/src/videogen_hub/pipelines/cogvideo/cogvideo_src/pretrain_cogvideo.py new file mode 100644 index 0000000000000000000000000000000000000000..a64172d5f94578fb2dc6c4098d16fab0a464c2ec --- /dev/null +++ b/src/videogen_hub/pipelines/cogvideo/cogvideo_src/pretrain_cogvideo.py @@ -0,0 +1,184 @@ +# -*- encoding: utf-8 -*- +''' +@File : pretrain_cogvideo.py +@Time : 2021/10/06 00:58:32 +@Author : Wenyi Hong +@Contact : hwy22@mails.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random +import torch +import argparse +import numpy as np +from videogen_hub.depend.icetk import icetk as tokenizer +tokenizer.add_special_tokens(['', '', '']) + +from models.cogvideo_model import CogVideoModel +from SwissArmyTransformer import mpu, get_args +from SwissArmyTransformer.training.deepspeed_training import training_main +from SwissArmyTransformer.data_utils import BinaryDataset + +def get_masks_and_position_ids_video(data, attention_mask_totxt=None, args=None): + # Extract batch size and sequence length. + batch_size, seq_length = data.size() + assert attention_mask_totxt is not None + layout = args.layout + assert seq_length == layout[-1] + n_pads = layout[0] - attention_mask_totxt.sum(dim=-1).long() + frame_len = layout[1]-layout[0] + position_ids = torch.zeros(batch_size, layout[2], dtype=torch.long, + device=data.device) + for i in range(batch_size): + torch.arange(layout[0] - n_pads[i], out=position_ids[i, n_pads[i]:layout[0]], + dtype=torch.long, device=data.device) + torch.arange(512, 512+layout[2]-layout[0], + out=position_ids[i, layout[0]:], dtype=torch.long, device=data.device) + return position_ids + + +def get_batch(data_iterator, args, timers): + # Items and their type. + keys = ['text', 'loss_mask', 'attention_mask_totxt'] + datatype = torch.int64 + + # Broadcast data. + timers('data loader').start() + if data_iterator is not None: + data = next(data_iterator) + else: + data = None + timers('data loader').stop() + + data_b = mpu.broadcast_data(keys, data, datatype) + # Unpack. + tokens_ = data_b['text'].long() + loss_mask = data_b['loss_mask'].float() + attention_mask_totxt = data_b['attention_mask_totxt'].float() + + labels = tokens_[:, 1:].clone().contiguous() + loss_mask = loss_mask[:, 1:].contiguous() + tokens = tokens_[:, :-1].clone().contiguous() + + for idx in range(args.layout[0], args.layout[2], 400): + tokens[:, idx] = tokenizer[''] + # Get the masks and postition ids. + position_ids = get_masks_and_position_ids_video( + tokens, + attention_mask_totxt=attention_mask_totxt, + args=args + ) + attention_mask_totxt = attention_mask_totxt.unsqueeze(1).unsqueeze(1) + # Convert + if args.fp16: + attention_mask_totxt = attention_mask_totxt.half() + return tokens, labels, loss_mask, attention_mask_totxt, position_ids + + +def forward_step(data_iterator, model, args, timers): + """Forward step.""" + + # Get the batch. + timers('batch generator').start() + tokens, labels, loss_mask, attention_mask_totxt, position_ids = get_batch( + data_iterator, args, timers) + timers('batch generator').stop() + + # Forward model. + logits, *mems = model(tokens, position_ids, attention_mask_totxt) + # ======= hyper params =======# + perframe_len = 400 + text_len=64 + frame_num = 5 + logits_img_tokens = logits[:, text_len:, :tokenizer.num_image_tokens].float().contiguous() + losses = mpu.vocab_parallel_cross_entropy(logits_img_tokens, labels[:, text_len:]) + # scaling loss mask + loss_mask = loss_mask[:, text_len:].reshape(-1) + + losses_1d = losses.reshape(-1) * loss_mask + loss = torch.sum(losses_1d) / loss_mask.sum() + # ===================== Log partial losses ======================== # + log_loss_dict = {} + bs = losses.shape[0] + + if args.cogvideo_stage == 1: + for i in range(frame_num): + log_loss_dict[f'AR_f{i}_loss'] = losses[:, i*perframe_len:(i+1)*perframe_len].contiguous().reshape(-1).detach().sum() / max((perframe_len*bs), 1) + else: + for i in range(1, frame_num-1): + log_loss_dict[f'ITP_f{i}_loss'] = losses[:, i*perframe_len:(i+1)*perframe_len].contiguous().reshape(-1).detach().sum() / max((perframe_len*bs), 1) + + # ===================== END OF BLOCK ======================= # + return loss, log_loss_dict + + +def create_dataset_function(path, args): + dataset_layout = [64, 464, 2064] + input_layout = [64, 464, 2064] + # frame_num = 6 + # frame_interval = 2 # DEBUG!!! + def process_fn(row): + row = row.astype(np.int64) + text = row[:dataset_layout[0]] + frames = row[dataset_layout[0]:] + + if text[0] == tokenizer['']: + text = text[1:] # due to our way of data processing + if args.cogvideo_stage == 1: + text, loss_mask, frames = make_text_video_generation(text, frames) + else: + text, loss_mask, frames = mask_video_frame_interpolation(text, frames) + + n_pad = input_layout[0] - len(text) + parts = [ + np.array([tokenizer['']] * n_pad, dtype=np.int64), + text, + np.array([tokenizer['']], dtype=np.int64), + frames, + ] + ret = np.concatenate(parts, axis=0) + + attention_mask_totxt = np.array([0] * n_pad + [1] * (input_layout[0]-n_pad)) + return {'text': ret, + 'loss_mask': loss_mask, + 'attention_mask_totxt': attention_mask_totxt, + } + return BinaryDataset(path, process_fn, length_per_sample=dataset_layout[-1]) + +def make_text_video_generation(text, frames): + input_layout = [64, 464, 2064] + text = text[text!= tokenizer['']][:input_layout[0]] # dataset format: 1.0秒{text} ... + loss_mask = np.array([0] * (input_layout[1]+1) + [1] * (input_layout[2] - input_layout[1])) # 按照input的,之后loss_mask会左移一位 + return text, loss_mask, frames + +def mask_video_frame_interpolation(text, frames): + input_layout = [64, 464, 2064] + frame_len = input_layout[1]-input_layout[0] + # text format: 1.0秒 {text} + text = text[text!= tokenizer['']][:input_layout[0]] + loss_mask = np.array([0] * (input_layout[1]+1) + + [1] * (input_layout[1]-input_layout[0]) + + [0] * (input_layout[1]-input_layout[0]) + + [1] * (input_layout[1]-input_layout[0]) + + [0] * (input_layout[1]-input_layout[0]) )# 按照input的,之后loss_mask会左移一位 + + return text, loss_mask, frames + + + +if __name__ == '__main__': + py_parser = argparse.ArgumentParser(add_help=False) + py_parser.add_argument('--txt-loss-scale', type=float, default=1) + CogVideoModel.add_model_specific_args(py_parser) + + known, args_list = py_parser.parse_known_args() + + args = get_args(args_list) + args = argparse.Namespace(**vars(args), **vars(known)) + + args.layout = [int(x) for x in args.layout.split(',')] + + training_main(args, model_cls=CogVideoModel, forward_step_function=forward_step, create_dataset_function=create_dataset_function) diff --git a/src/videogen_hub/pipelines/cogvideo/cogvideo_src/requirements.txt b/src/videogen_hub/pipelines/cogvideo/cogvideo_src/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..cf8f2099d2c38581695fa1436407910d54b57a80 --- /dev/null +++ b/src/videogen_hub/pipelines/cogvideo/cogvideo_src/requirements.txt @@ -0,0 +1,4 @@ +SwissArmyTransformer==0.2.9 +icetk +gifmaker +torchvision diff --git a/src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/__init__.py b/src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..91bfa935fc366495bebd31f011ef2f59620d48f4 --- /dev/null +++ b/src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/__init__.py @@ -0,0 +1,17 @@ +# -*- encoding: utf-8 -*- +''' +@File : __init__.py +@Time : 2022/03/02 13:57:09 +@Author : Ming Ding +@Contact : dm18@mails.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random + +from .direct_sr import DirectSuperResolution +from .iterative_sr import IterativeSuperResolution +from .sr_group import SRGroup diff --git a/src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/cluster_label2.npy b/src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/cluster_label2.npy new file mode 100644 index 0000000000000000000000000000000000000000..1c27ec8f73830ac8611789750dbfd73a2a494920 --- /dev/null +++ b/src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/cluster_label2.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b87880fdbe89670f12844377b9cf97a9733b1f54e3a9b73cbb9835084c4e02ec +size 160128 diff --git a/src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/direct_sr.py b/src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/direct_sr.py new file mode 100644 index 0000000000000000000000000000000000000000..1b2414479084d062f3332548ef8e98faa5cd7112 --- /dev/null +++ b/src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/direct_sr.py @@ -0,0 +1,117 @@ +# -*- encoding: utf-8 -*- +''' +@File : direct_sr.py +@Time : 2022/03/02 13:58:11 +@Author : Ming Ding +@Contact : dm18@mails.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random +import torch + +# -*- encoding: utf-8 -*- +''' +@File : inference_cogview2.py +@Time : 2021/10/10 16:31:34 +@Author : Ming Ding +@Contact : dm18@mails.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random +from PIL import ImageEnhance, Image + +import torch +import argparse +from torchvision import transforms + +from SwissArmyTransformer import get_args +from SwissArmyTransformer.training.model_io import load_checkpoint +from .dsr_sampling import filling_sequence_dsr, IterativeEntfilterStrategy +from SwissArmyTransformer.generation.utils import timed_name, save_multiple_images, generate_continually + +from .dsr_model import DsrModel + +from videogen_hub.depend.icetk import icetk as tokenizer + +class DirectSuperResolution: + def __init__(self, args, path, max_bz=4, topk=6, onCUDA=False): + args.load = path + args.kernel_size = 5 + args.kernel_size2 = 5 + args.new_sequence_length = 4624 + args.layout = [96,496,4096] + + model = DsrModel(args) + if args.fp16: + model = model.half() + + load_checkpoint(model, args) # on cpu + model.eval() + self.model = model + self.onCUDA = onCUDA + if onCUDA: + self.model = self.model.cuda() + + invalid_slices = [slice(tokenizer.num_image_tokens, None)] + + self.strategy = IterativeEntfilterStrategy(invalid_slices, + temperature=1.0, topk=topk) # temperature not used # Temperature Freezed Here!! + self.max_bz = max_bz + + def __call__(self, text_tokens, image_tokens, enhance=False): + if len(text_tokens.shape) == 1: + text_tokens.unsqueeze_(0) + if len(image_tokens.shape) == 1: + image_tokens.unsqueeze_(0) + # ===================== Debug ======================== # + # new_image_tokens = [] + # for small_img in image_tokens: + # decoded = tokenizer.decode(image_ids=small_img) + # decoded = torch.nn.functional.interpolate(decoded, size=(480, 480)).squeeze(0) + # ndarr = decoded.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() + # image_pil_raw = ImageEnhance.Sharpness(Image.fromarray(ndarr)) + # small_img2 = tokenizer.encode(image_pil=image_pil_raw.enhance(1.5), image_size=480).view(-1) + # new_image_tokens.append(small_img2) + # image_tokens = torch.stack(new_image_tokens) + # return image_tokens + # ===================== END OF BLOCK ======================= # + if enhance: + new_image_tokens = [] + for small_img in image_tokens: + decoded = tokenizer.decode(image_ids=small_img).squeeze(0) + ndarr = decoded.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() + image_pil_raw = ImageEnhance.Sharpness(Image.fromarray(ndarr)) + small_img2 = tokenizer.encode(image_pil=image_pil_raw.enhance(1.), image_size=160).view(-1) + new_image_tokens.append(small_img2) + image_tokens = torch.stack(new_image_tokens) + + seq = torch.cat((text_tokens,image_tokens), dim=1) + seq1 = torch.tensor([tokenizer['']]*3601, device=image_tokens.device).unsqueeze(0).expand(text_tokens.shape[0], -1) + if not self.onCUDA: + print('Converting Dsr model...') + model = self.model.cuda() + else: + model = self.model + print('Direct super-resolution...') + output_list = [] + for tim in range(max((text_tokens.shape[0]+self.max_bz-1) // self.max_bz, 1)): + output1 = filling_sequence_dsr(model, + seq[tim*self.max_bz:(tim+1)*self.max_bz], + seq1[tim*self.max_bz:(tim+1)*self.max_bz], + warmup_steps=1, block_hw=(1, 0), + strategy=self.strategy + ) + output_list.extend(output1[1:]) + if not self.onCUDA: + print('Moving back Dsr to cpu...') + model = model.cpu() + torch.cuda.empty_cache() + return torch.cat(output_list, dim=0) \ No newline at end of file diff --git a/src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/dsr_model.py b/src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/dsr_model.py new file mode 100644 index 0000000000000000000000000000000000000000..d918d18922458548a2c75a77b8b6cea7e420612b --- /dev/null +++ b/src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/dsr_model.py @@ -0,0 +1,225 @@ +# -*- encoding: utf-8 -*- +''' +@File : cuda2d_model.py +@Time : 2021/10/02 01:36:32 +@Author : Ming Ding +@Contact : dm18@mails.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random +import torch +import torch.nn.functional as F + + +from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin + +from SwissArmyTransformer.model.transformer import split_tensor_along_last_dim, unscaled_init_method +from SwissArmyTransformer.mpu.utils import sqrt +from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker +from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear + +class PositionEmbeddingMixin(BaseMixin): + def __init__(self, additional_sequence_length, hidden_size, + init_method_std=0.02, reinit_slice=slice(512, 512+400) + ): + super(PositionEmbeddingMixin, self).__init__() + self.reinit_slice = reinit_slice + self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size) + torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std) + + def reinit(self, parent_model=None): + old_weights = self.transformer.position_embeddings.weight.data[self.reinit_slice] + old_len, hidden_size = old_weights.shape + assert hidden_size == self.position_embeddings.weight.shape[-1] + old_edge, new_edge = sqrt(old_len), sqrt(self.position_embeddings.weight.shape[-2]) + assert new_edge % old_edge == 0 + self.position_embeddings.weight.data.view(new_edge // old_edge, old_edge, new_edge // old_edge, old_edge, hidden_size).copy_(old_weights.view(1, old_edge, 1, old_edge, hidden_size)) + # self.position_embeddings.weight.data.view(-1, old_len, hidden_size).copy_(old_weights) + + +class AttentionMixin(BaseMixin): + def __init__(self, num_layers, + hidden_size, + init_method=unscaled_init_method(0.02), + output_layer_init_method=unscaled_init_method(0.02) + ): + super(AttentionMixin, self).__init__() + self.num_layers = num_layers # replace attention in the LAST n layers + self.query_key_value = torch.nn.ModuleList( + [ColumnParallelLinear(hidden_size, 3 * hidden_size, stride=3, + gather_output=False, init_method=init_method) + for layer_id in range(num_layers) + ]) + self.dense = torch.nn.ModuleList( + [RowParallelLinear(hidden_size, + hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method) + for layer_id in range(num_layers) + ]) + + def reinit(self, parent_model=None): + start_layer = len(self.transformer.layers) - self.num_layers + assert start_layer >= 0 + for layer_id in range(self.num_layers): + old_attention = self.transformer.layers[start_layer + layer_id].attention + self.query_key_value[layer_id].weight.data.copy_(old_attention.query_key_value.weight.data) + self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data) + self.dense[layer_id].weight.data.copy_(old_attention.dense.weight.data) + self.dense[layer_id].bias.data.copy_(old_attention.dense.bias.data) + +class DsrModel(BaseModel): + def __init__(self, args, transformer=None): + super().__init__(args, transformer=transformer) + self.original_sequence_length = args.max_sequence_length + additional_seqlen = args.new_sequence_length - args.max_sequence_length + self.add_mixin('extra_position_embedding', PositionEmbeddingMixin( + additional_seqlen, args.hidden_size + )) + self.add_mixin('attention_plus', AttentionMixin( + num_layers=args.num_layers, + hidden_size=args.hidden_size + )) + self.layout = args.layout + # [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1] 4095 {layout[2]} + self.kernel_size = args.kernel_size + self.kernel_size2 = args.kernel_size2 + self.log_attention_weights = None + + def position_embedding_forward(self, position_ids, **kw_args): + position = position_ids[..., :self.layout[1]] + position_plus = position_ids[..., self.layout[1]:] - self.original_sequence_length + position_embeddings = torch.cat( + ( + self.transformer.position_embeddings(position), + self.get_mixin('extra_position_embedding').position_embeddings(position_plus) + ), + dim=-2 + ) + return position_embeddings + + def attention_forward(self, hidden_states, mask, + layer_id=None, log_attention_weights=None, **kw_args): + attn_module = self.transformer.layers[layer_id].attention + # attention_plus on all layers + query_key_value_plus = self.get_mixin('attention_plus').query_key_value[layer_id] + dense_plus = self.get_mixin('attention_plus').dense[layer_id] + # split two parts + hidden_states_plus = hidden_states[:, self.layout[1]:] + hidden_states = hidden_states[:, :self.layout[1]] + # base model qkv + mixed_raw_layer = attn_module.query_key_value(hidden_states) + q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer, 3) + # cuda2d model qkv + mixed_raw_layer = query_key_value_plus(hidden_states_plus) + q1, k1, v1 = split_tensor_along_last_dim(mixed_raw_layer, 3) + + dropout_fn = attn_module.attention_dropout if self.training else None + + # cuda2d attention + context_layer0, context_layer1 = sparse_attention_2d_light( + q0, k0, v0, + q1, k1, v1, + mask, + n_head=attn_module.num_attention_heads_per_partition, + text_len=self.layout[0], + kernel_size=self.kernel_size, + kernel_size2=self.kernel_size2, + attention_dropout=dropout_fn, + log_attention_weights=log_attention_weights, + add_scalar=(kw_args['add_scalar'] if 'add_scalar' in kw_args else 0) + ) + + output_0 = attn_module.dense(context_layer0) + output_1 = dense_plus(context_layer1) + output = torch.cat((output_0, output_1), dim=1) + + return output + + def final_forward(self, logits, **kwargs): + logits_parallel = logits + logits_parallel = torch.nn.functional.linear(logits_parallel.float(), self.transformer.word_embeddings.weight[:20000].float()) + # logits_parallel = torch.nn.functional.linear(logits_parallel, self.transformer.word_embeddings.weight[:20000]) + return logits_parallel + + def disable_untrainable_params(self): + self.transformer.requires_grad_(False) + + @classmethod + def add_model_specific_args(cls, parser): + group = parser.add_argument_group('Cuda2dModel', 'cuda2d model configurations') + group.add_argument("--kernel-size", type=int, default=5) + group.add_argument("--kernel-size2", type=int, default=5) + group.add_argument("--layout", type=str, default='96,496,4096') + group.add_argument("--new-sequence-length", type=int, default=4096) + return parser + +def sparse_attention_2d_light(q0, k0, v0, q1, k1, v1, attention_mask, n_head, text_len, kernel_size=9, kernel_size2=7, attention_dropout=None, log_attention_weights = None, add_scalar=0, **kwargs): + ''' + q0, k0, v0: [batch_size, 1088, hidden_size] + q1, k1, v1: [batch_size, 4096, h2] + n_head: int + attention_mask: [batch_size, 1088, 1088] + ''' + from SwissArmyTransformer.ops.local_attention_function import f_similar, f_weighting + + b, s0, h0 = q0.shape + b, s1, h1 = q1.shape + h, l0, l1 = h0 // n_head, sqrt(s0-text_len), sqrt(s1) + + q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + k0T = k0.reshape(b, s0, n_head, h).permute(0, 2, 3, 1) + + # standard attention for level 0 + attention_scores = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T) + + if log_attention_weights is not None: + attention_scores += log_attention_weights + attention_scores = torch.mul(attention_scores, attention_mask) - \ + 10000.0 * (1.0 - attention_mask) + + attention_probs0 = F.softmax(attention_scores, dim=-1) + + # local attention for level 1 + q1 = (q1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1) / math.sqrt(h1//n_head)).contiguous().view(b*n_head, h1//n_head, l1, l1) + k1 = k1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1) + v1 = v1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1) + # scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, True) + scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, False) + + # cross attention + k0T = k0T[..., -l0**2:].reshape(b*n_head, h, l0, l0).contiguous() + scores_1_to_0 = f_similar(q1, k0T, kernel_size2, kernel_size2, False) # [b*n_head, l1, l1, field] + scores_1 = torch.cat( + ( + scores_1_to_0.view(b*n_head, -1, scores_1_to_0.shape[3]) + add_scalar, + scores_1_to_1.view(b*n_head, -1, scores_1_to_1.shape[3]) + ), + dim=-1) + attention_probs1 = F.softmax(scores_1, dim=-1) + + if attention_dropout is not None: + # with get_cuda_rng_tracker().fork(): + attention_probs0 = attention_dropout(attention_probs0) + attention_probs1 = attention_dropout(attention_probs1) + + # weighting for level 0 + context0 = torch.matmul(attention_probs0, v0) # [b, n_head, s0, h] + # weighting for level 1 + probs_1_to_1 = attention_probs1[:, :, -scores_1_to_1.shape[3]:].view_as(scores_1_to_1) + # context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size*2-1, kernel_size, True) + context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size*2-1, kernel_size, False) + + context1 = context1_to_1.view(b, n_head * h, l1**2) + # weighting for cross attention + probs_1_to_0 = attention_probs1[:, :, :scores_1_to_0.shape[3]].view_as(scores_1_to_0) + v0_part = v0[:, :, -l0**2:].transpose(-1, -2).contiguous().view(b*n_head, h, l0, l0) + context1_to_0 = f_weighting(v0_part, probs_1_to_0.contiguous(), kernel_size2, kernel_size2, False) + context1_to_0 = context1_to_0.view(b, n_head * h, l1**2) + context1 = context1 + context1_to_0 + return context0.transpose(1, 2).reshape(b, s0, h0), context1.transpose(-1, -2) \ No newline at end of file diff --git a/src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/dsr_sampling.py b/src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/dsr_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..3e5f588b1bcca6e406d5ef1d7fff512c210ec136 --- /dev/null +++ b/src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/dsr_sampling.py @@ -0,0 +1,204 @@ +# -*- encoding: utf-8 -*- +""" +@File : cuda2d_sampling.py +@Time : 2021/10/09 00:46:04 +@Author : Ming Ding +@Contact : dm18@mails.tsinghua.edu.cn +""" + +# here put the import lib +import os +import sys +import math +import random +from cv2 import reduce +import torch + +import torch +import torch.nn.functional as F +import numpy as np + + +def top_k_logits_(logits, top_k=0, filter_value=-float("Inf")): + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = filter_value + return logits + + +class IterativeEntfilterStrategy: + def __init__(self, invalid_slices=[], temperature=1.0, topk=6): + self.invalid_slices = invalid_slices + self.temperature = temperature + self.topk = topk + device = "cpu" + if torch.cuda.is_available(): + device = "cuda" + self.cluster_labels = torch.tensor( + np.load("cluster_label2.npy"), device=device, dtype=torch.long + ) + + def forward( + self, + logits_, + tokens, + temperature=None, + entfilter=None, + filter_topk=5, + temperature2=None, + ): + # In interative strategy, logits are of shape [batch_size, seq_length, hidden_size] + if temperature is None: + temperature = self.temperature + + logits = logits_.float() / temperature + for invalid_slice in self.invalid_slices: + logits[..., invalid_slice] = -float("Inf") + logits = logits.view(-1, logits.shape[-1]) + + rprobs = F.softmax(logits.float(), dim=-1) + c = self.cluster_labels.expand(*rprobs.shape) + cprobs = torch.zeros(logits.shape[0], 500, device=logits.device).scatter_add_( + 1, c, rprobs + ) + + best_scores, best_clusters = cprobs.topk(self.topk) + bz = logits.shape[0] + best_scores = best_scores / best_scores.sum(dim=-1, keepdim=True) + sampled_ids = torch.multinomial(best_scores, num_samples=1) + selected_clusters = torch.gather(best_clusters, dim=1, index=sampled_ids) + selected_mask = ( + self.cluster_labels.unsqueeze(0).expand(bz, -1) != selected_clusters + ) # cluster_labels [1, 20000] \in [0,500) + logits[selected_mask] = -65504 + # for i in range(bz): + # selected_cluster = best_clusters[i][torch.multinomial(best_scores[i] / best_scores[i].sum(), num_samples=1)] + # logits[i, self.cluster_labels != selected_cluster] = -65504 + + # logits = top_k_logits(logits, self.topk, self.top_p) + probs = F.softmax( + logits.float() / 0.6, dim=-1 + ) # float is essetial, due to a bug in Pytorch + pred = torch.multinomial(probs, num_samples=1).view(*logits_.shape[:2]) + + assert tokens.shape[1] == pred.shape[1] + 1 + tokens = torch.cat((tokens[:, :1], pred), dim=1) + return tokens + + +def filling_sequence_dsr( + model, + seq0, + seq1, + warmup_steps=3, + block_hw=(4, 4), + strategy=IterativeEntfilterStrategy(topk=10), +): + """ + seq: [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1] + 4095 {layout[2]} final_token. + Attention: + The sampling temperature are changing, temporally we hard code them here. + The temperature in the strategy is not used. + """ + assert hasattr(model, "layout") + layout = model.layout + assert ( + len(seq0.shape) == 2 and len(seq1.shape) == 2 and seq0.shape[0] == seq1.shape[0] + ) + assert len(layout) == 3 + assert seq1.shape[1] == layout[-1] - layout[-2] + 1 + assert (seq1 >= 0).all() and (seq0 >= 0).all() + device = seq0.device + # concat and pad sequences + batch_size = seq0.shape[0] + n_pad = layout[1] - seq0.shape[1] + assert n_pad > 0, "You should truncate long input before filling." + seq = torch.cat( + ( + torch.tensor([0] * n_pad, device=device, dtype=seq0.dtype) + .unsqueeze(0) + .expand(batch_size, n_pad), + seq0, + seq1, + ), + dim=1, + ) # [b, layout[-1]+1] + assert seq.shape[1] == layout[-1] + 1 + + # build initial tokens, attention_mask, and position_ids + tokens = seq.clone() + attention_mask = torch.ones(layout[1], layout[1]).to(device) + attention_mask[: layout[0], layout[0] :] = 0 + attention_mask[n_pad:, :n_pad] = 0 + attention_mask = attention_mask.type_as(next(model.parameters())) # if fp16 + position_ids = torch.cat( + ( + torch.zeros(n_pad, dtype=torch.long), + torch.arange(0, layout[0] - n_pad), + torch.arange(513, 513 + layout[1] - layout[0]), + torch.arange(1024, 1024 + layout[2] - layout[1]), + ) + ).to(device) + log_attention_weights = torch.zeros(layout[1], layout[1], device=device).type_as( + next(model.parameters()) + ) + log_attention_weights[layout[0] :, n_pad : layout[0]] = 0.0 + + # prepare for interation + unfixed = tokens < 0 # just init an all-False tensor + unfixed[:, -layout[-1] + layout[-2] :] = True + + ll, rr = block_hw + edge_len = int(math.sqrt(layout[-1] - layout[-2]) + 1e-4) + num_steps = warmup_steps + ll - 1 + rr + # interative refining + + # unfixed[..., -(layout[-1] - layout[-2]):].view( + # batch_size, edge_len//ll, ll, edge_len//rr, rr)[:, :, :, :, -1] = False + + ret = [] + ret.append(tokens[:, layout[-2] + 1 :].clone()) + for step_cnt in range(1, num_steps + 1): + if step_cnt <= warmup_steps: + logits, *_dump = model( + tokens[:, :-1], + position_ids, + attention_mask, + log_attention_weights=log_attention_weights, + ) + real_temp = 1.0 + new_tokens = strategy.forward(logits, tokens, real_temp) + tokens[unfixed] = new_tokens[unfixed] + else: + logits, *_dump = model( + tokens[:, :-1], + position_ids, + attention_mask, + log_attention_weights=log_attention_weights, + ) + real_temp = 1.0 + new_tokens = strategy.forward( + logits, + tokens, + real_temp, + entfilter=1.3, + filter_topk=5, + temperature2=0.6, + ) + # tokens[unfixed] = new_tokens[unfixed] + # fixed tokens (update unfixed) + unfixed2 = tokens > 10000000 + for x in range(min(ll, step_cnt - warmup_steps)): + y = step_cnt - warmup_steps - x - 1 + if y < rr: + unfixed[..., -(layout[-1] - layout[-2]) :].view( + batch_size, edge_len // ll, ll, edge_len // rr, rr + )[:, :, x, :, y] = False + unfixed2[..., -(layout[-1] - layout[-2]) :].view( + batch_size, edge_len // ll, ll, edge_len // rr, rr + )[:, :, x, :, y] = True + tokens[unfixed2] = new_tokens[unfixed2] + + ret.append(tokens[:, layout[-2] + 1 :].clone()) + + return ret diff --git a/src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/iterative_sr.py b/src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/iterative_sr.py new file mode 100644 index 0000000000000000000000000000000000000000..2608f6ac5e164f4b0356eec1837a2b37c109c25c --- /dev/null +++ b/src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/iterative_sr.py @@ -0,0 +1,118 @@ +# -*- encoding: utf-8 -*- +''' +@File : iterative_sr.py +@Time : 2022/03/02 15:57:45 +@Author : Ming Ding +@Contact : dm18@mails.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random + +# here put the import lib +import os +import sys +import math +import random +from PIL import ImageEnhance, Image + +import torch +import argparse +from torchvision import transforms + +from SwissArmyTransformer.training.model_io import load_checkpoint +from SwissArmyTransformer import get_args +from .itersr_sampling import filling_sequence_itersr, IterativeEntfilterStrategy +from SwissArmyTransformer.generation.utils import timed_name, save_multiple_images, generate_continually + +from .itersr_model import ItersrModel + +from videogen_hub.depend.icetk import icetk as tokenizer + +class IterativeSuperResolution: + def __init__(self, args, path, max_bz=4, shared_transformer=None): + args.load = path + args.kernel_size = 5 + args.kernel_size2 = 5 + args.new_sequence_length = 4624 + args.layout = [16,3616] + + model = ItersrModel(args, transformer=shared_transformer) + if args.fp16: + model = model.half() + + load_checkpoint(model, args) # on cpu + model.eval() + self.model = model.cuda() + + # save cpu weights + self.saved_weights = dict((k,v.cpu()) + for k, v in model.named_parameters() + if 'transformer' in k + ) + + invalid_slices = [slice(tokenizer.num_image_tokens, None)] + + self.strategy = IterativeEntfilterStrategy(invalid_slices, + temperature=args.temp_all_itersr, topk=args.topk_itersr) + self.max_bz = max_bz + + def _restore_transformer_from_cpu(self, non_blocking=False): + for k, v in self.model.named_parameters(): + if k in self.saved_weights: + v.copy_(self.saved_weights[k]) + + def __call__(self, text_tokens, image_tokens, enhance=False, input_mask=None): + if len(text_tokens.shape) == 1: + text_tokens.unsqueeze_(0) + text_tokens = text_tokens.clone()[..., :16] + if len(image_tokens.shape) == 1: + image_tokens.unsqueeze_(0) + if enhance: + new_image_tokens = [] + for big_img in image_tokens: + decoded = tokenizer.decode(image_ids=big_img).squeeze(0) + ndarr = decoded.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() + image_pil_raw = ImageEnhance.Sharpness(Image.fromarray(ndarr)) + big_img2 = tokenizer.encode(image_pil=image_pil_raw.enhance(1.5), image_size=480).view(-1) + new_image_tokens.append(big_img2) + image_tokens = torch.stack(new_image_tokens) + print('Converting Itersr model...') + self._restore_transformer_from_cpu() + model = self.model + print('iterative super-resolution...') + output_list = [] + for tim in range(max(text_tokens.shape[0] // self.max_bz, 1)): + big_img = image_tokens[tim*self.max_bz:(tim+1)*self.max_bz] + text_seq = text_tokens[tim*self.max_bz:(tim+1)*self.max_bz] + mask_raw = torch.tensor( + [ + -1, 0, 1, 2, 3, 4, + 0, -1, 2, -1, -2, 5, + 1, -2, 3, 4, 5, 6, + 2, 3, 4, 5, -1, 1, + 3, -1, -2, 0, -1, 2, + 4, 5, 6, 1, 3, -2 + ] + ).view(1, 6, 1, 6).expand(10, 6, 10, 6).reshape(-1).contiguous() + + topks = [60, 40, 40, 40, 20, 20, 10] + + for mask_ratio in range(1, 7): + self.strategy.topk = topks[mask_ratio] + mask = (mask_raw.to(big_img.device) >= mask_ratio) + if input_mask is not None: + mask = mask & input_mask + big_img.masked_fill_(mask, tokenizer['']) + seq1 = big_img + output1 = filling_sequence_itersr(model, text_seq, seq1, + warmup_steps=1, block_hw=(1, 0), + strategy=self.strategy + ) + big_img = output1 + print(f'Iter {mask_ratio} times.') + output_list.append(output1.clone()) + return torch.cat(output_list, dim=0) \ No newline at end of file diff --git a/src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/itersr_model.py b/src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/itersr_model.py new file mode 100644 index 0000000000000000000000000000000000000000..40981bcbae1b18381fc9172b9702c9d0905f6524 --- /dev/null +++ b/src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/itersr_model.py @@ -0,0 +1,232 @@ +# -*- encoding: utf-8 -*- +''' +@File : itersr_model.py +@Time : 2021/10/02 01:36:32 +@Author : Ming Ding +@Contact : dm18@mails.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random +import torch +import torch.nn.functional as F + + +from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin + +from SwissArmyTransformer.mpu.utils import sqrt +from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker +from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear +from SwissArmyTransformer.model.transformer import unscaled_init_method, split_tensor_along_last_dim + +class PositionEmbeddingMixin(BaseMixin): + def __init__(self, additional_sequence_length, hidden_size, + init_method_std=0.02, reinit_slice=slice(512, 512+400) + ): + super(PositionEmbeddingMixin, self).__init__() + self.reinit_slice = reinit_slice + self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size) + torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std) + + def reinit(self, parent_model=None): + old_weights = self.transformer.position_embeddings.weight.data[self.reinit_slice] + old_len, hidden_size = old_weights.shape + assert hidden_size == self.position_embeddings.weight.shape[-1] + old_edge, new_edge = sqrt(old_len), sqrt(self.position_embeddings.weight.shape[-2]) + assert new_edge % old_edge == 0 + self.position_embeddings.weight.data.view(new_edge // old_edge, old_edge, new_edge // old_edge, old_edge, hidden_size).copy_(old_weights.view(1, old_edge, 1, old_edge, hidden_size)) + +class ItersrModel(BaseModel): + def __init__(self, args, transformer=None): + super().__init__(args, transformer=transformer) + self.original_sequence_length = args.max_sequence_length + additional_seqlen = args.new_sequence_length - args.max_sequence_length + self.add_mixin('extra_position_embedding', PositionEmbeddingMixin( + additional_seqlen, args.hidden_size + )) + # self.add_mixin('attention_plus', AttentionMixin( + # num_layers=args.num_layers, + # hidden_size=args.hidden_size + # )) + self.layout = args.layout + # [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1] 4095 {layout[2]} + self.kernel_size = args.kernel_size + self.kernel_size2 = args.kernel_size2 + self.log_attention_weights = None + + def position_embedding_forward(self, position_ids, **kw_args): + position = position_ids[..., :self.layout[0]] + position_plus = position_ids[..., self.layout[0]:] - self.original_sequence_length + position_embeddings = torch.cat( + ( + self.transformer.position_embeddings(position), + self.get_mixin('extra_position_embedding').position_embeddings(position_plus) + ), + dim=-2 + ) + return position_embeddings + + def attention_forward(self, hidden_states, mask, + layer_id=None, log_attention_weights=None, **kw_args): + attn_module = self.transformer.layers[layer_id].attention + # base model qkv + mixed_raw_layer = attn_module.query_key_value(hidden_states) + q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer[:, :self.layout[0]], 3) + # cuda2d model qkv + q1, k1, v1 = split_tensor_along_last_dim(mixed_raw_layer[:, self.layout[0]:], 3) + + dropout_fn = attn_module.attention_dropout if self.training else None + + # cuda2d attention + context_layer = sparse_attention_2d_text( + q0, k0, v0, + q1, k1, v1, + mask, + n_head=attn_module.num_attention_heads_per_partition, + text_len=self.layout[0], + kernel_size=self.kernel_size, + attention_dropout=dropout_fn, + log_attention_weights=log_attention_weights, + ) + + output = attn_module.dense(context_layer) + + return output + + def final_forward(self, logits, **kwargs): + logits_parallel = logits + logits_parallel = torch.nn.functional.linear(logits_parallel, self.transformer.word_embeddings.weight[:20000]).float() + # logits_parallel = torch.nn.functional.linear(logits_parallel, self.transformer.word_embeddings.weight[:20000]) + return logits_parallel + + # def disable_untrainable_params(self): + # self.transformer.requires_grad_(False) + + @classmethod + def add_model_specific_args(cls, parser): + group = parser.add_argument_group('Cuda2dModel', 'cuda2d model configurations') + group.add_argument("--kernel-size", type=int, default=5) + group.add_argument("--kernel-size2", type=int, default=5) + group.add_argument("--layout", type=str, default='16,3616') + group.add_argument("--new-sequence-length", type=int, default=4096) + return parser + +def sparse_attention_2d_text(q0, k0, v0, q1, k1, v1, attention_mask, n_head, text_len, kernel_size=9, attention_dropout=None, log_attention_weights = None, **kwargs): + ''' + q0, k0, v0: [batch_size, 16, hidden_size] + q1, k1, v1: [batch_size, 3600, hidden_size] + n_head: int + attention_mask: [batch_size, 16] + ''' + from SwissArmyTransformer.ops.local_attention_function import f_similar, f_weighting + b, s0, h0 = q0.shape + b, s1, h1 = q1.shape + h, l1 = h0 // n_head, sqrt(s1) + assert attention_mask.shape[-1] == s0, f"Mask Shape: {attention_mask.shape}" + + q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + k0T = k0.reshape(b, s0, n_head, h).permute(0, 2, 3, 1) + + # standard attention for level 0 + attention_scores = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T) + + attention_scores = torch.mul(attention_scores, attention_mask) - \ + 10000.0 * (1.0 - attention_mask) + + attention_probs0 = F.softmax(attention_scores, dim=-1) + + # local attention for level 1 + q1 = (q1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1) / math.sqrt(h1//n_head)).contiguous().view(b*n_head, h1//n_head, l1, l1) + k1 = k1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1) + v1 = v1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1) + scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, False) + + # cross attention + scores_1_to_0 = torch.matmul(q1.view(b, n_head, h, s1).transpose(-1, -2), k0T) + if log_attention_weights is not None: + scores_1_to_0 += log_attention_weights + scores_1_to_0 = torch.mul(scores_1_to_0, attention_mask) - \ + 10000.0 * (1.0 - attention_mask) + scores_1 = torch.cat( + ( + scores_1_to_0.view(b*n_head, s1, s0), + scores_1_to_1.view(b*n_head, -1, scores_1_to_1.shape[3]) + ), + dim=-1) + attention_probs1 = F.softmax(scores_1, dim=-1) + + if attention_dropout is not None: + with get_cuda_rng_tracker().fork(): + attention_probs1 = attention_dropout(attention_probs1) + + # weighting for level 0 + context0 = torch.matmul(attention_probs0, v0) # [b, n_head, s0, h] + # weighting for level 1 + probs_1_to_1 = attention_probs1[:, :, -scores_1_to_1.shape[3]:].view_as(scores_1_to_1) + context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size*2-1, kernel_size, False) + + context1 = context1_to_1.view(b, n_head, h, l1**2) + # weighting for cross attention + probs_1_to_0 = attention_probs1[:, :, :scores_1_to_0.shape[3]].view(b, n_head, -1, scores_1_to_0.shape[3]) + + context1_to_0 = torch.matmul(probs_1_to_0, v0) + context1 = context1.transpose(-1, -2) + context1_to_0 + + output = torch.cat((context0, context1), dim=2).transpose(1, 2).reshape(b, s0+s1, h0) + + return output + +def sparse_attention_2d_notext(q0, k0, v0, q1, k1, v1, attention_mask, n_head, text_len, kernel_size=9, attention_dropout=None, log_attention_weights = None, **kwargs): + ''' + q0, k0, v0: [batch_size, 16, hidden_size] + q1, k1, v1: [batch_size, 3600, hidden_size] + n_head: int + attention_mask: [batch_size, 16] + ''' + from SwissArmyTransformer.mpu.local_attention_function import f_similar, f_weighting + b, s0, h0 = q0.shape + b, s1, h1 = q1.shape + h, l1 = h0 // n_head, sqrt(s1) + assert len(attention_mask.shape) == 4 and attention_mask.shape[-1] == s0, f"Mask Shape: {attention_mask.shape}" + + q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + k0T = k0.reshape(b, s0, n_head, h).permute(0, 2, 3, 1) + + # standard attention for level 0 + attention_scores = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T) + + attention_scores = torch.mul(attention_scores, attention_mask) - \ + 10000.0 * (1.0 - attention_mask) + + attention_probs0 = F.softmax(attention_scores, dim=-1) + + # local attention for level 1 + q1 = (q1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1) / math.sqrt(h1//n_head)).contiguous().view(b*n_head, h1//n_head, l1, l1) + k1 = k1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1) + v1 = v1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1) + scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, False) + + attention_probs1 = F.softmax(scores_1_to_1, dim=-1) + + if attention_dropout is not None: + with get_cuda_rng_tracker().fork(): + attention_probs1 = attention_dropout(attention_probs1) + + # weighting for level 0 + context0 = torch.matmul(attention_probs0, v0) # [b, n_head, s0, h] + # weighting for level 1 + probs_1_to_1 = attention_probs1 + context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size*2-1, kernel_size, False) + + context1 = context1_to_1.view(b, n_head, h, l1**2) + # weighting for cross attention + context1 = context1.transpose(-1, -2) + + output = torch.cat((context0, context1), dim=2).transpose(1, 2).reshape(b, s0+s1, h0) + + return output \ No newline at end of file diff --git a/src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/itersr_sampling.py b/src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/itersr_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..a5e9583b56e6638cd009f05b89087348045f05ca --- /dev/null +++ b/src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/itersr_sampling.py @@ -0,0 +1,168 @@ +# -*- encoding: utf-8 -*- +''' +@File : itersr_sampling.py +@Time : 2022/03/03 14:24:28 +@Author : Ming Ding +@Contact : dm18@mails.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random +import numpy as np + +import torch +import torch.nn.functional as F +from videogen_hub.depend.icetk import icetk as tokenizer + +def top_k_logits_(logits, top_k=0, filter_value=-float('Inf')): + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = filter_value + return logits + +# class IterativeEntfilterStrategy: +# def __init__(self, invalid_slices=[], temperature=1., topk=10): +# self.invalid_slices = invalid_slices +# self.temperature = temperature +# self.topk = topk +# self.cluster_labels = torch.tensor(np.load('cluster_label.npy'), device='cuda', dtype=torch.long) + + +# def forward(self, logits_, tokens, temperature=None, entfilter=None, filter_topk=5, temperature2=None): +# # In interative strategy, logits are of shape [batch_size, seq_length, hidden_size] +# if temperature is None: +# temperature = self.temperature + +# logits = logits_.float() / temperature +# for invalid_slice in self.invalid_slices: +# logits[..., invalid_slice] = -float('Inf') +# logits = logits.view(-1, logits.shape[-1]) + +# rprobs = F.softmax(logits.float(), dim=-1) +# c = self.cluster_labels.expand(*rprobs.shape) +# cprobs = torch.zeros(logits.shape[0], 500, device=logits.device).scatter_add_(1, c, rprobs) + +# best_scores, best_clusters = cprobs.topk(self.topk) +# bz = logits.shape[0] +# best_scores = best_scores / best_scores.sum(dim=-1, keepdim=True) +# sampled_ids = torch.multinomial(best_scores, num_samples=1) +# selected_clusters = torch.gather(best_clusters, dim=1, index=sampled_ids) +# selected_mask = (self.cluster_labels.unsqueeze(0).expand(bz, -1) != selected_clusters) # cluster_labels [1, 20000] \in [0,500) +# logits[selected_mask] = -65504 +# # for i in range(bz): +# # selected_cluster = best_clusters[i][torch.multinomial(best_scores[i] / best_scores[i].sum(), num_samples=1)] +# # logits[i, self.cluster_labels != selected_cluster] = -65504 + +# # logits = top_k_logits(logits, self.topk, self.top_p) +# probs = F.softmax(logits.float(), dim=-1) # float is essetial, due to a bug in Pytorch +# pred = torch.multinomial(probs, num_samples=1).view(*logits_.shape[:2]) + +# assert tokens.shape[1] == pred.shape[1] +# tokens = pred +# return tokens + +class IterativeEntfilterStrategy: + def __init__(self, invalid_slices=[], temperature=1., topk=10): + self.invalid_slices = invalid_slices + self.temperature = temperature + self.topk = topk + + def forward(self, logits, tokens, temperature=None, entfilter=None, filter_topk=5, temperature2=None): + # In interative strategy, logits are of shape [batch_size, seq_length, hidden_size] + if temperature is None: + temperature = self.temperature + # check entropy filter + # if entfilter is not None: + # assert temperature2 is not None + # topraw = (torch.topk(logits, filter_topk, dim=-1)[0]).softmax(dim=-1) + # ent = -(topraw * topraw.log()).sum(dim=-1) # [batch_size, seq_length] + # temperature = torch.tensor([[[temperature - temperature2]]], device=logits.device).expand(*logits.shape[:2], 1) * (ent > entfilter).unsqueeze(-1) + temperature2 + + logits = logits.float() / temperature + for invalid_slice in self.invalid_slices: + logits[..., invalid_slice] = -float('Inf') + + # debiased topk + # probs = F.softmax(logits, dim=-1) + # tk_value, tk_idx = torch.topk(probs, self.topk, dim=-1) + # pred = torch.multinomial(probs.view(-1, logits.shape[-1]), num_samples=1).view(*logits.shape[:2], 1) + # edge_idx = tk_idx[:, :, -1:] + # edge_value = tk_value[:, :, -1:] + # edge_mask = probs.gather(dim=-1, index=pred) < edge_value + # pred[edge_mask] = edge_idx[edge_mask] # replace outliers as the "filter_topk"-th token + # pred.squeeze_(-1) # [batch_size, seq_length] + + top_k_logits_(logits, self.topk) + probs = F.softmax(logits, dim=-1) + pred = torch.multinomial(probs.view(-1, logits.shape[-1]), num_samples=1).view(*logits.shape[:2], 1) + pred.squeeze_(-1) + + assert tokens.shape[1] == pred.shape[1] + tokens = pred + return tokens + +def filling_sequence_itersr( + model, + seq0, + seq1, + warmup_steps=3, + block_hw=(4, 4), + strategy=IterativeEntfilterStrategy(topk=10), + ): + ''' + seq: [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1] + 4095 {layout[2]} final_token. + Attention: + The sampling temperature are changing, temporally we hard code them here. + The temperature in the strategy is not used. + ''' + assert hasattr(model, 'layout') + layout = model.layout + + device = seq0.device + # concat and pad sequences + batch_size = seq0.shape[0] + n_pad = layout[0] - seq0.shape[1] + assert n_pad >= 0, "You should truncate long input before filling." + seq = torch.cat(( + torch.tensor([0]*n_pad, device=device, dtype=seq0.dtype) + .unsqueeze(0).expand(batch_size, n_pad), + seq0, seq1), dim=1) # [b, layout[-1]+1] + assert seq.shape[1] == layout[-1] + + # build initial tokens, attention_mask, and position_ids + tokens = seq.clone() + attention_mask = torch.ones(layout[0]).to(device) + attention_mask[:n_pad] = 0 + attention_mask = attention_mask.unsqueeze(0).type_as(next(model.parameters())) # if fp16 + position_ids = torch.cat(( + torch.zeros(n_pad, dtype=torch.long), + torch.arange(0, layout[0] - n_pad), + torch.arange(1024, 1024+layout[1]-layout[0]))).to(device) + log_attention_weights = torch.zeros(layout[0], device=device).type_as(next(model.parameters())) + log_attention_weights[n_pad:layout[0]] = 0. + log_attention_weights = log_attention_weights.unsqueeze(0) + + # prepare for interation + unfixed = (tokens == tokenizer['']) + ll, rr = block_hw + edge_len = int(math.sqrt(layout[-1] - layout[-2]) + 1e-4) + num_steps = 1 + # interative refining + + # unfixed[..., -(layout[-1] - layout[-2]):].view( + # batch_size, edge_len//ll, ll, edge_len//rr, rr)[:, :, :, :, -1] = False + + + ret = [] + # ret.append(tokens[:, layout[-2]:-1].clone()) + for step_cnt in range(1, num_steps+1): + logits, *_dump = model(tokens, position_ids, attention_mask, log_attention_weights=log_attention_weights) + real_temp = 1. + new_tokens = strategy.forward(logits, tokens, real_temp) + tokens[unfixed] = new_tokens[unfixed] + + ret.append(tokens[:, layout[-2]:].clone()) + return torch.cat(ret, dim=0) \ No newline at end of file diff --git a/src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/sr_group.py b/src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/sr_group.py new file mode 100644 index 0000000000000000000000000000000000000000..1ec51b67fce1551d29d70018551226725262ae20 --- /dev/null +++ b/src/videogen_hub/pipelines/cogvideo/cogvideo_src/sr_pipeline/sr_group.py @@ -0,0 +1,49 @@ +# -*- encoding: utf-8 -*- +''' +@File : sr_group.py +@Time : 2022/04/02 01:17:21 +@Author : Ming Ding +@Contact : dm18@mails.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random + +import numpy as np +import torch +import torch.nn.functional as F +from SwissArmyTransformer.resources import auto_create +from .direct_sr import DirectSuperResolution +from .iterative_sr import IterativeSuperResolution + +class SRGroup: + def __init__(self, args, home_path=None,): + dsr_path = auto_create('cogview2-dsr', path=home_path) + itersr_path = auto_create('cogview2-itersr', path=home_path) + dsr = DirectSuperResolution(args, dsr_path) + itersr = IterativeSuperResolution(args, itersr_path, shared_transformer=dsr.model.transformer) + self.dsr = dsr + self.itersr = itersr + + def sr_base(self, img_tokens, txt_tokens): + assert img_tokens.shape[-1] == 400 and len(img_tokens.shape) == 2 + batch_size = img_tokens.shape[0] + txt_len = txt_tokens.shape[-1] + if len(txt_tokens.shape) == 1: + txt_tokens = txt_tokens.unsqueeze(0).expand(batch_size, txt_len) + sred_tokens = self.dsr(txt_tokens, img_tokens) + iter_tokens = self.itersr(txt_tokens, sred_tokens[:, -3600:].clone()) + return iter_tokens[-batch_size:] + + # def sr_patch(self, img_tokens, txt_tokens): + # assert img_tokens.shape[-1] == 3600 and len(img_tokens.shape) == 2 + # batch_size = img_tokens.shape[0] * 9 + # txt_len = txt_tokens.shape[-1] + # if len(txt_tokens.shape) == 1: + # txt_tokens = txt_tokens.unsqueeze(0).expand(batch_size, txt_len) + # img_tokens = img_tokens.view(img_tokens.shape[0], 3, 20, 3, 20).permute(0, 1, 3, 2, 4).reshape(batch_size, 400) + # iter_tokens = self.sr_base(img_tokens, txt_tokens) + # return iter_tokens \ No newline at end of file diff --git a/src/videogen_hub/pipelines/consisti2v/LICENSE b/src/videogen_hub/pipelines/consisti2v/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..25a1149b7da8c528c24a2a411c49248082a117c9 --- /dev/null +++ b/src/videogen_hub/pipelines/consisti2v/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 TIGER Lab + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/src/videogen_hub/pipelines/consisti2v/__init__.py b/src/videogen_hub/pipelines/consisti2v/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/consisti2v/configs/__init__.py b/src/videogen_hub/pipelines/consisti2v/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/consisti2v/configs/inference/__init__.py b/src/videogen_hub/pipelines/consisti2v/configs/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/consisti2v/configs/inference/inference.yaml b/src/videogen_hub/pipelines/consisti2v/configs/inference/inference.yaml new file mode 100644 index 0000000000000000000000000000000000000000..62a46d68fcb745850d1864bc4c02c2b94e0dff6a --- /dev/null +++ b/src/videogen_hub/pipelines/consisti2v/configs/inference/inference.yaml @@ -0,0 +1,48 @@ +output_dir: "samples/inference" +output_name: "i2v" + +pretrained_model_path: "TIGER-Lab/ConsistI2V" +unet_path: null +unet_ckpt_prefix: "module." +pipeline_pretrained_path: null + +sampling_kwargs: + height: 256 + width: 256 + n_frames: 16 + steps: 50 + ddim_eta: 0.0 + guidance_scale_txt: 7.5 + guidance_scale_img: 1.0 + guidance_rescale: 0.0 + num_videos_per_prompt: 1 + frame_stride: 3 + +unet_additional_kwargs: + variant: null + n_temp_heads: 8 + augment_temporal_attention: true + temp_pos_embedding: "rotary" # "rotary" or "sinusoidal" + first_frame_condition_mode: "concat" + use_frame_stride_condition: true + noise_sampling_method: "pyoco_mixed" # "vanilla" or "pyoco_mixed" or "pyoco_progressive" + noise_alpha: 1.0 + +noise_scheduler_kwargs: + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: "linear" + steps_offset: 1 + clip_sample: false + rescale_betas_zero_snr: false # true if using zero terminal snr + timestep_spacing: "leading" # "trailing" if using zero terminal snr + prediction_type: "epsilon" # "v_prediction" if using zero terminal snr + +frameinit_kwargs: + enable: true + camera_motion: null + noise_level: 850 + filter_params: + method: 'gaussian' + d_s: 0.25 + d_t: 0.25 \ No newline at end of file diff --git a/src/videogen_hub/pipelines/consisti2v/configs/inference/inference_autoregress.yaml b/src/videogen_hub/pipelines/consisti2v/configs/inference/inference_autoregress.yaml new file mode 100644 index 0000000000000000000000000000000000000000..acc730e9f155920aca557ee89a61767b361a14e3 --- /dev/null +++ b/src/videogen_hub/pipelines/consisti2v/configs/inference/inference_autoregress.yaml @@ -0,0 +1,49 @@ +output_dir: "samples/inference" +output_name: "long_video" + +pretrained_model_path: "TIGER-Lab/ConsistI2V" +unet_path: null +unet_ckpt_prefix: "module." +pipeline_pretrained_path: null + +sampling_kwargs: + height: 256 + width: 256 + n_frames: 16 + steps: 50 + ddim_eta: 0.0 + guidance_scale_txt: 7.5 + guidance_scale_img: 1.0 + guidance_rescale: 0.0 + num_videos_per_prompt: 1 + frame_stride: 3 + autoregress_steps: 3 + +unet_additional_kwargs: + variant: null + n_temp_heads: 8 + augment_temporal_attention: true + temp_pos_embedding: "rotary" # "rotary" or "sinusoidal" + first_frame_condition_mode: "concat" + use_frame_stride_condition: true + noise_sampling_method: "pyoco_mixed" # "vanilla" or "pyoco_mixed" or "pyoco_progressive" + noise_alpha: 1.0 + +noise_scheduler_kwargs: + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: "linear" + steps_offset: 1 + clip_sample: false + rescale_betas_zero_snr: false # true if using zero terminal snr + timestep_spacing: "leading" # "trailing" if using zero terminal snr + prediction_type: "epsilon" # "v_prediction" if using zero terminal snr + + +frameinit_kwargs: + enable: true + noise_level: 850 + filter_params: + method: 'gaussian' + d_s: 0.25 + d_t: 0.25 \ No newline at end of file diff --git a/src/videogen_hub/pipelines/consisti2v/configs/prompts/__init__.py b/src/videogen_hub/pipelines/consisti2v/configs/prompts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/consisti2v/configs/prompts/default.yaml b/src/videogen_hub/pipelines/consisti2v/configs/prompts/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..796e05e332253c98c63a2563872005950ffd2ae9 --- /dev/null +++ b/src/videogen_hub/pipelines/consisti2v/configs/prompts/default.yaml @@ -0,0 +1,16 @@ +seeds: random + +prompts: + - "timelapse at the snow land with aurora in the sky." + - "fireworks." + - "clown fish swimming through the coral reef." + - "melting ice cream dripping down the cone." + +n_prompts: + - "" + +path_to_first_frames: + - "assets/example/example_01.png" + - "assets/example/example_02.png" + - "assets/example/example_03.png" + - "assets/example/example_04.png" \ No newline at end of file diff --git a/src/videogen_hub/pipelines/consisti2v/configs/training/__init__.py b/src/videogen_hub/pipelines/consisti2v/configs/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/consisti2v/configs/training/training.yaml b/src/videogen_hub/pipelines/consisti2v/configs/training/training.yaml new file mode 100644 index 0000000000000000000000000000000000000000..942f29d61a615fd050b1900c56ea949d2f80f2c8 --- /dev/null +++ b/src/videogen_hub/pipelines/consisti2v/configs/training/training.yaml @@ -0,0 +1,92 @@ +output_dir: "checkpoints" +pretrained_model_path: "stabilityai/stable-diffusion-2-1-base" + +noise_scheduler_kwargs: + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: "linear" + steps_offset: 1 + clip_sample: false + rescale_betas_zero_snr: false # true if using zero terminal snr + timestep_spacing: "leading" # "trailing" if using zero terminal snr + prediction_type: "epsilon" # "v_prediction" if using zero terminal snr + +train_data: + dataset: "joint" + pexels_config: + enable: false + json_path: null + caption_json_path: null + video_folder: null + webvid_config: + enable: true + json_path: "/path/to/webvid/annotation" + video_folder: "/path/to/webvid/data" + sample_size: 256 + sample_duration: null + sample_fps: null + sample_stride: [1, 5] + sample_n_frames: 16 + +validation_data: + prompts: + - "timelapse at the snow land with aurora in the sky." + - "fireworks." + - "clown fish swimming through the coral reef." + - "melting ice cream dripping down the cone." + + path_to_first_frames: + - "assets/example/example_01.jpg" + - "assets/example/example_02.jpg" + - "assets/example/example_03.jpg" + - "assets/example/example_04.jpg" + + num_inference_steps: 50 + ddim_eta: 0.0 + guidance_scale_txt: 7.5 + guidance_scale_img: 1.0 + guidance_rescale: 0.0 + frame_stride: 3 + +trainable_modules: + - "all" + # - "conv3ds." + # - "tempo_attns." + +resume_from_checkpoint: null + +unet_additional_kwargs: + variant: null + n_temp_heads: 8 + augment_temporal_attention: true + temp_pos_embedding: "rotary" # "rotary" or "sinusoidal" + first_frame_condition_mode: "concat" + use_frame_stride_condition: true + noise_sampling_method: "pyoco_mixed" # "vanilla" or "pyoco_mixed" or "pyoco_progressive" + noise_alpha: 1.0 + +cfg_random_null_text_ratio: 0.1 +cfg_random_null_img_ratio: 0.1 + +use_ema: false +ema_decay: 0.9999 + +learning_rate: 5.e-5 +train_batch_size: 3 +gradient_accumulation_steps: 1 +max_grad_norm: 0.5 + +max_train_epoch: -1 +max_train_steps: 200000 +checkpointing_epochs: -1 +checkpointing_steps: 2000 +validation_steps: 1000 + +seed: 42 +mixed_precision: "bf16" +num_workers: 32 +enable_xformers_memory_efficient_attention: true + +is_image: false +is_debug: false diff --git a/src/videogen_hub/pipelines/consisti2v/consisti2v/__init__.py b/src/videogen_hub/pipelines/consisti2v/consisti2v/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/consisti2v/consisti2v/data/__init__.py b/src/videogen_hub/pipelines/consisti2v/consisti2v/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/consisti2v/consisti2v/data/dataset.py b/src/videogen_hub/pipelines/consisti2v/consisti2v/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a2d59efde71c74ae0dd15b449bfba188d685e2ed --- /dev/null +++ b/src/videogen_hub/pipelines/consisti2v/consisti2v/data/dataset.py @@ -0,0 +1,315 @@ +import os, io, csv, math, random +import json +import numpy as np +from einops import rearrange +from decord import VideoReader + +import torch +import torchvision.transforms as transforms +from torch.utils.data.dataset import Dataset + +from diffusers.utils import logging + +logger = logging.get_logger(__name__) + +class WebVid10M(Dataset): + def __init__( + self, + json_path, video_folder=None, + sample_size=256, sample_stride=4, sample_n_frames=16, + is_image=False, + **kwargs, + ): + logger.info(f"loading annotations from {json_path} ...") + with open(json_path, 'rb') as json_file: + json_list = list(json_file) + self.dataset = [json.loads(json_str) for json_str in json_list] + self.length = len(self.dataset) + logger.info(f"data scale: {self.length}") + + self.video_folder = video_folder + self.sample_stride = sample_stride if isinstance(sample_stride, int) else tuple(sample_stride) + self.sample_n_frames = sample_n_frames + self.is_image = is_image + + sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) + self.pixel_transforms = transforms.Compose([ + transforms.RandomHorizontalFlip(), + transforms.Resize(sample_size[0], antialias=None), + transforms.CenterCrop(sample_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + + def get_batch(self, idx): + video_dict = self.dataset[idx] + video_relative_path, name = video_dict['file'], video_dict['text'] + + if self.video_folder is not None: + if video_relative_path[0] == '/': + video_dir = os.path.join(self.video_folder, os.path.basename(video_relative_path)) + else: + video_dir = os.path.join(self.video_folder, video_relative_path) + else: + video_dir = video_relative_path + video_reader = VideoReader(video_dir) + video_length = len(video_reader) + + if not self.is_image: + if isinstance(self.sample_stride, int): + stride = self.sample_stride + elif isinstance(self.sample_stride, tuple): + stride = random.randint(self.sample_stride[0], self.sample_stride[1]) + clip_length = min(video_length, (self.sample_n_frames - 1) * stride + 1) + start_idx = random.randint(0, video_length - clip_length) + batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int) + else: + frame_difference = random.randint(2, self.sample_n_frames) + clip_length = min(video_length, (frame_difference - 1) * self.sample_stride + 1) + start_idx = random.randint(0, video_length - clip_length) + batch_index = [start_idx, start_idx + clip_length - 1] + + pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() + pixel_values = pixel_values / 255. + del video_reader + + return pixel_values, name + + def __len__(self): + return self.length + + def __getitem__(self, idx): + while True: + try: + pixel_values, name = self.get_batch(idx) + break + + except Exception as e: + idx = random.randint(0, self.length-1) + + pixel_values = self.pixel_transforms(pixel_values) + sample = dict(pixel_values=pixel_values, text=name) + return sample + + +class Pexels(Dataset): + def __init__( + self, + json_path, caption_json_path, video_folder=None, + sample_size=256, sample_duration=1, sample_fps=8, + is_image=False, + **kwargs, + ): + logger.info(f"loading captions from {caption_json_path} ...") + with open(caption_json_path, 'rb') as caption_json_file: + caption_json_list = list(caption_json_file) + self.caption_dict = {json.loads(json_str)['id']: json.loads(json_str)['text'] for json_str in caption_json_list} + + logger.info(f"loading annotations from {json_path} ...") + with open(json_path, 'rb') as json_file: + json_list = list(json_file) + dataset = [json.loads(json_str) for json_str in json_list] + + self.dataset = [] + for data in dataset: + data['text'] = self.caption_dict[data['id']] + if data['height'] / data['width'] < 0.625: + self.dataset.append(data) + self.length = len(self.dataset) + logger.info(f"data scale: {self.length}") + + self.video_folder = video_folder + self.sample_duration = sample_duration + self.sample_fps = sample_fps + self.sample_n_frames = sample_duration * sample_fps + self.is_image = is_image + + sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) + self.pixel_transforms = transforms.Compose([ + transforms.RandomHorizontalFlip(), + transforms.Resize(sample_size[0], antialias=None), + transforms.CenterCrop(sample_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + + def get_batch(self, idx): + video_dict = self.dataset[idx] + video_relative_path, name = video_dict['file'], video_dict['text'] + fps = video_dict['fps'] + + if self.video_folder is not None: + if video_relative_path[0] == '/': + video_dir = os.path.join(self.video_folder, os.path.basename(video_relative_path)) + else: + video_dir = os.path.join(self.video_folder, video_relative_path) + else: + video_dir = video_relative_path + video_reader = VideoReader(video_dir) + video_length = len(video_reader) + + if not self.is_image: + clip_length = min(video_length, math.ceil(fps * self.sample_duration)) + start_idx = random.randint(0, video_length - clip_length) + batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int) + else: + frame_difference = random.randint(2, self.sample_n_frames) + sample_stride = math.ceil((fps * self.sample_duration) / (self.sample_n_frames - 1) - 1) + clip_length = min(video_length, (frame_difference - 1) * sample_stride + 1) + start_idx = random.randint(0, video_length - clip_length) + batch_index = [start_idx, start_idx + clip_length - 1] + + pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() + pixel_values = pixel_values / 255. + del video_reader + + return pixel_values, name + + def __len__(self): + return self.length + + def __getitem__(self, idx): + while True: + try: + pixel_values, name = self.get_batch(idx) + break + + except Exception as e: + idx = random.randint(0, self.length-1) + + pixel_values = self.pixel_transforms(pixel_values) + sample = dict(pixel_values=pixel_values, text=name) + return sample + + +class JointDataset(Dataset): + def __init__( + self, + webvid_config, pexels_config, + sample_size=256, + sample_duration=None, sample_fps=None, sample_stride=None, sample_n_frames=None, + is_image=False, + **kwargs, + ): + assert (sample_duration is None and sample_fps is None) or (sample_duration is not None and sample_fps is not None), "sample_duration and sample_fps should be both None or not None" + if sample_duration is not None and sample_fps is not None: + assert sample_stride is None, "when sample_duration and sample_fps are not None, sample_stride should be None" + if sample_stride is not None: + assert sample_fps is None and sample_duration is None, "when sample_stride is not None, sample_duration and sample_fps should be both None" + + self.dataset = [] + + if pexels_config.enable: + logger.info(f"loading pexels dataset") + logger.info(f"loading captions from {pexels_config.caption_json_path} ...") + with open(pexels_config.caption_json_path, 'rb') as caption_json_file: + caption_json_list = list(caption_json_file) + self.caption_dict = {json.loads(json_str)['id']: json.loads(json_str)['text'] for json_str in caption_json_list} + + logger.info(f"loading annotations from {pexels_config.json_path} ...") + with open(pexels_config.json_path, 'rb') as json_file: + json_list = list(json_file) + dataset = [json.loads(json_str) for json_str in json_list] + + for data in dataset: + data['text'] = self.caption_dict[data['id']] + data['dataset'] = 'pexels' + if data['height'] / data['width'] < 0.625: + self.dataset.append(data) + + if webvid_config.enable: + logger.info(f"loading webvid dataset") + logger.info(f"loading annotations from {webvid_config.json_path} ...") + with open(webvid_config.json_path, 'rb') as json_file: + json_list = list(json_file) + dataset = [json.loads(json_str) for json_str in json_list] + for data in dataset: + data['dataset'] = 'webvid' + self.dataset.extend(dataset) + + self.length = len(self.dataset) + logger.info(f"data scale: {self.length}") + + self.pexels_folder = pexels_config.video_folder + self.webvid_folder = webvid_config.video_folder + self.sample_duration = sample_duration + self.sample_fps = sample_fps + self.sample_n_frames = sample_duration * sample_fps if sample_n_frames is None else sample_n_frames + self.sample_stride = sample_stride if (sample_stride is None) or (sample_stride is not None and isinstance(sample_stride, int)) else tuple(sample_stride) + self.is_image = is_image + + sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) + self.pixel_transforms = transforms.Compose([ + transforms.RandomHorizontalFlip(), + transforms.Resize(sample_size[0], antialias=None), + transforms.CenterCrop(sample_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + + def get_batch(self, idx): + video_dict = self.dataset[idx] + video_relative_path, name = video_dict['file'], video_dict['text'] + + if video_dict['dataset'] == 'pexels': + video_folder = self.pexels_folder + elif video_dict['dataset'] == 'webvid': + video_folder = self.webvid_folder + else: + raise NotImplementedError + + if video_folder is not None: + if video_relative_path[0] == '/': + video_dir = os.path.join(video_folder, os.path.basename(video_relative_path)) + else: + video_dir = os.path.join(video_folder, video_relative_path) + else: + video_dir = video_relative_path + video_reader = VideoReader(video_dir) + video_length = len(video_reader) + + stride = None + if not self.is_image: + if self.sample_duration is not None: + fps = video_dict['fps'] + clip_length = min(video_length, math.ceil(fps * self.sample_duration)) + elif self.sample_stride is not None: + if isinstance(self.sample_stride, int): + stride = self.sample_stride + elif isinstance(self.sample_stride, tuple): + stride = random.randint(self.sample_stride[0], self.sample_stride[1]) + clip_length = min(video_length, (self.sample_n_frames - 1) * stride + 1) + + start_idx = random.randint(0, video_length - clip_length) + batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int) + + else: + frame_difference = random.randint(2, self.sample_n_frames) + if self.sample_duration is not None: + fps = video_dict['fps'] + sample_stride = math.ceil((fps * self.sample_duration) / (self.sample_n_frames - 1) - 1) + elif self.sample_stride is not None: + sample_stride = self.sample_stride + + clip_length = min(video_length, (frame_difference - 1) * sample_stride + 1) + start_idx = random.randint(0, video_length - clip_length) + batch_index = [start_idx, start_idx + clip_length - 1] + + pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() + pixel_values = pixel_values / 255. + del video_reader + + return pixel_values, name, stride + + def __len__(self): + return self.length + + def __getitem__(self, idx): + while True: + try: + pixel_values, name, stride = self.get_batch(idx) + break + + except Exception as e: + idx = random.randint(0, self.length-1) + + pixel_values = self.pixel_transforms(pixel_values) + sample = dict(pixel_values=pixel_values, text=name, stride=stride) + return sample diff --git a/src/videogen_hub/pipelines/consisti2v/consisti2v/models/__init__.py b/src/videogen_hub/pipelines/consisti2v/consisti2v/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/consisti2v/consisti2v/models/rotary_embedding.py b/src/videogen_hub/pipelines/consisti2v/consisti2v/models/rotary_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..95d5d1576f57907724324a0b8bf70c91d1667bf3 --- /dev/null +++ b/src/videogen_hub/pipelines/consisti2v/consisti2v/models/rotary_embedding.py @@ -0,0 +1,280 @@ +from math import pi, log + +import torch +from torch.nn import Module, ModuleList +from torch.cuda.amp import autocast +from torch import nn, einsum, broadcast_tensors, Tensor + +from einops import rearrange, repeat + +from beartype import beartype +from beartype.typing import Literal, Union, Optional + +# helper functions + +def exists(val): + return val is not None + +def default(val, d): + return val if exists(val) else d + +# broadcat, as tortoise-tts was using it + +def broadcat(tensors, dim = -1): + broadcasted_tensors = broadcast_tensors(*tensors) + return torch.cat(broadcasted_tensors, dim = dim) + +# rotary embedding helper functions + +def rotate_half(x): + x = rearrange(x, '... (d r) -> ... d r', r = 2) + x1, x2 = x.unbind(dim = -1) + x = torch.stack((-x2, x1), dim = -1) + return rearrange(x, '... d r -> ... (d r)') + +@autocast(enabled = False) +def apply_rotary_emb(freqs, t, start_index = 0, scale = 1., seq_dim = -2): + if t.ndim == 3: + seq_len = t.shape[seq_dim] + freqs = freqs[-seq_len:].to(t) + + rot_dim = freqs.shape[-1] + end_index = start_index + rot_dim + + assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}' + + t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:] + t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale) + return torch.cat((t_left, t, t_right), dim = -1) + +# learned rotation helpers + +def apply_learned_rotations(rotations, t, start_index = 0, freq_ranges = None): + if exists(freq_ranges): + rotations = einsum('..., f -> ... f', rotations, freq_ranges) + rotations = rearrange(rotations, '... r f -> ... (r f)') + + rotations = repeat(rotations, '... n -> ... (n r)', r = 2) + return apply_rotary_emb(rotations, t, start_index = start_index) + +# classes + +class RotaryEmbedding(Module): + @beartype + def __init__( + self, + dim, + custom_freqs: Optional[Tensor] = None, + freqs_for: Union[ + Literal['lang'], + Literal['pixel'], + Literal['constant'] + ] = 'lang', + theta = 10000, + max_freq = 10, + num_freqs = 1, + learned_freq = False, + use_xpos = False, + xpos_scale_base = 512, + interpolate_factor = 1., + theta_rescale_factor = 1., + seq_before_head_dim = False + ): + super().__init__() + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + + theta *= theta_rescale_factor ** (dim / (dim - 2)) + + self.freqs_for = freqs_for + + if exists(custom_freqs): + freqs = custom_freqs + elif freqs_for == 'lang': + freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) + elif freqs_for == 'pixel': + freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi + elif freqs_for == 'constant': + freqs = torch.ones(num_freqs).float() + + self.tmp_store('cached_freqs', None) + self.tmp_store('cached_scales', None) + + self.freqs = nn.Parameter(freqs, requires_grad = learned_freq) + + self.learned_freq = learned_freq + + # dummy for device + + self.tmp_store('dummy', torch.tensor(0)) + + # default sequence dimension + + self.seq_before_head_dim = seq_before_head_dim + self.default_seq_dim = -3 if seq_before_head_dim else -2 + + # interpolation factors + + assert interpolate_factor >= 1. + self.interpolate_factor = interpolate_factor + + # xpos + + self.use_xpos = use_xpos + if not use_xpos: + self.tmp_store('scale', None) + return + + scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) + self.scale_base = xpos_scale_base + self.tmp_store('scale', scale) + + @property + def device(self): + return self.dummy.device + + def tmp_store(self, key, value): + self.register_buffer(key, value, persistent = False) + + def get_seq_pos(self, seq_len, device, dtype, offset = 0): + return (torch.arange(seq_len, device = device, dtype = dtype) + offset) / self.interpolate_factor + + def rotate_queries_or_keys(self, t, seq_dim = None, offset = 0, freq_seq_len = None, seq_pos = None): + seq_dim = default(seq_dim, self.default_seq_dim) + + assert not self.use_xpos, 'you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings' + + device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim] + + if exists(freq_seq_len): + assert freq_seq_len >= seq_len + seq_len = freq_seq_len + + if seq_pos is None: + seq_pos = self.get_seq_pos(seq_len, device = device, dtype = dtype, offset = offset) + else: + assert seq_pos.shape[0] == seq_len + + freqs = self.forward(seq_pos, seq_len = seq_len, offset = offset) + + if seq_dim == -3: + freqs = rearrange(freqs, 'n d -> n 1 d') + + return apply_rotary_emb(freqs, t, seq_dim = seq_dim) + + def rotate_queries_with_cached_keys(self, q, k, seq_dim = None, offset = 0): + seq_dim = default(seq_dim, self.default_seq_dim) + + q_len, k_len = q.shape[seq_dim], k.shape[seq_dim] + assert q_len <= k_len + rotated_q = self.rotate_queries_or_keys(q, seq_dim = seq_dim, freq_seq_len = k_len) + rotated_k = self.rotate_queries_or_keys(k, seq_dim = seq_dim) + + rotated_q = rotated_q.type(q.dtype) + rotated_k = rotated_k.type(k.dtype) + + return rotated_q, rotated_k + + def rotate_queries_and_keys(self, q, k, seq_dim = None): + seq_dim = default(seq_dim, self.default_seq_dim) + + assert self.use_xpos + device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim] + + seq = self.get_seq_pos(seq_len, dtype = dtype, device = device) + + freqs = self.forward(seq, seq_len = seq_len) + scale = self.get_scale(seq, seq_len = seq_len).to(dtype) + + if seq_dim == -3: + freqs = rearrange(freqs, 'n d -> n 1 d') + scale = rearrange(scale, 'n d -> n 1 d') + + rotated_q = apply_rotary_emb(freqs, q, scale = scale, seq_dim = seq_dim) + rotated_k = apply_rotary_emb(freqs, k, scale = scale ** -1, seq_dim = seq_dim) + + rotated_q = rotated_q.type(q.dtype) + rotated_k = rotated_k.type(k.dtype) + + return rotated_q, rotated_k + + @beartype + def get_scale( + self, + t: Tensor, + seq_len: Optional[int] = None, + offset = 0 + ): + assert self.use_xpos + + should_cache = exists(seq_len) + + if ( + should_cache and \ + exists(self.cached_scales) and \ + (seq_len + offset) <= self.cached_scales.shape[0] + ): + return self.cached_scales[offset:(offset + seq_len)] + + scale = 1. + if self.use_xpos: + power = (t - len(t) // 2) / self.scale_base + scale = self.scale ** rearrange(power, 'n -> n 1') + scale = torch.cat((scale, scale), dim = -1) + + if should_cache: + self.tmp_store('cached_scales', scale) + + return scale + + def get_axial_freqs(self, *dims): + Colon = slice(None) + all_freqs = [] + + for ind, dim in enumerate(dims): + if self.freqs_for == 'pixel': + pos = torch.linspace(-1, 1, steps = dim, device = self.device) + else: + pos = torch.arange(dim, device = self.device) + + freqs = self.forward(pos, seq_len = dim) + + all_axis = [None] * len(dims) + all_axis[ind] = Colon + + new_axis_slice = (Ellipsis, *all_axis, Colon) + all_freqs.append(freqs[new_axis_slice]) + + all_freqs = broadcast_tensors(*all_freqs) + return torch.cat(all_freqs, dim = -1) + + @autocast(enabled = False) + def forward( + self, + t: Tensor, + seq_len = None, + offset = 0 + ): + # should_cache = ( + # not self.learned_freq and \ + # exists(seq_len) and \ + # self.freqs_for != 'pixel' + # ) + + # if ( + # should_cache and \ + # exists(self.cached_freqs) and \ + # (offset + seq_len) <= self.cached_freqs.shape[0] + # ): + # return self.cached_freqs[offset:(offset + seq_len)].detach() + + freqs = self.freqs + + freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs) + freqs = repeat(freqs, '... n -> ... (n r)', r = 2) + + # if should_cache: + # self.tmp_store('cached_freqs', freqs.detach()) + + return freqs diff --git a/src/videogen_hub/pipelines/consisti2v/consisti2v/models/videoldm_attention.py b/src/videogen_hub/pipelines/consisti2v/consisti2v/models/videoldm_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..061375a32050bff7080cdcde3934cca5b015be3a --- /dev/null +++ b/src/videogen_hub/pipelines/consisti2v/consisti2v/models/videoldm_attention.py @@ -0,0 +1,809 @@ +from importlib import import_module +from typing import Callable, Optional, Union +import math + +from einops import rearrange, repeat + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.utils import deprecate, logging +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import maybe_allow_in_graph +from diffusers.models.lora import LoRACompatibleLinear, LoRALinearLayer +from diffusers.models.attention_processor import ( + Attention, + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + AttnProcessor, + AttnProcessor2_0, + SpatialNorm, + LORA_ATTENTION_PROCESSORS, + CustomDiffusionAttnProcessor, + CustomDiffusionXFormersAttnProcessor, + SlicedAttnAddedKVProcessor, + XFormersAttnAddedKVProcessor, + LoRAAttnAddedKVProcessor, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor, + LoRAAttnProcessor2_0, + SlicedAttnProcessor, + AttentionProcessor +) + +from .rotary_embedding import RotaryEmbedding + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + +@maybe_allow_in_graph +class ConditionalAttention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block=False, + processor: Optional["AttnProcessor"] = None, + ): + super().__init__() + self.inner_dim = dim_head * heads + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection + self.dropout = dropout + + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + self.heads = heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + + self.added_kv_proj_dim = added_kv_proj_dim + self.only_cross_attention = only_cross_attention + + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." + ) + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) + else: + self.group_norm = None + + if spatial_norm_dim is not None: + self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim) + else: + self.spatial_norm = None + + if cross_attention_norm is None: + self.norm_cross = None + elif cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(self.cross_attention_dim) + elif cross_attention_norm == "group_norm": + if self.added_kv_proj_dim is not None: + # The given `encoder_hidden_states` are initially of shape + # (batch_size, seq_len, added_kv_proj_dim) before being projected + # to (batch_size, seq_len, cross_attention_dim). The norm is applied + # before the projection, so we need to use `added_kv_proj_dim` as + # the number of channels for the group norm. + norm_cross_num_channels = added_kv_proj_dim + else: + norm_cross_num_channels = self.cross_attention_dim + + self.norm_cross = nn.GroupNorm( + num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True + ) + else: + raise ValueError( + f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" + ) + + self.to_q = LoRACompatibleLinear(query_dim, self.inner_dim, bias=bias) + + if not self.only_cross_attention: + # only relevant for the `AddedKVProcessor` classes + self.to_k = LoRACompatibleLinear(self.cross_attention_dim, self.inner_dim, bias=bias) + self.to_v = LoRACompatibleLinear(self.cross_attention_dim, self.inner_dim, bias=bias) + else: + self.to_k = None + self.to_v = None + + if self.added_kv_proj_dim is not None: + self.add_k_proj = LoRACompatibleLinear(added_kv_proj_dim, self.inner_dim) + self.add_v_proj = LoRACompatibleLinear(added_kv_proj_dim, self.inner_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(LoRACompatibleLinear(self.inner_dim, query_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + if processor is None: + processor = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) + self.set_processor(processor) + + def set_use_memory_efficient_attention_xformers( + self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None + ): + is_lora = hasattr(self, "processor") and isinstance( + self.processor, + LORA_ATTENTION_PROCESSORS, + ) + is_custom_diffusion = hasattr(self, "processor") and isinstance( + self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor) + ) + is_added_kv_processor = hasattr(self, "processor") and isinstance( + self.processor, + ( + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + SlicedAttnAddedKVProcessor, + XFormersAttnAddedKVProcessor, + LoRAAttnAddedKVProcessor, + ), + ) + + if use_memory_efficient_attention_xformers: + if is_added_kv_processor and (is_lora or is_custom_diffusion): + raise NotImplementedError( + f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}" + ) + if not is_xformers_available(): + raise ModuleNotFoundError( + ( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers" + ), + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" + " only available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + + if is_lora: + # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers + # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0? + processor = LoRAXFormersAttnProcessor( + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + rank=self.processor.rank, + attention_op=attention_op, + ) + processor.load_state_dict(self.processor.state_dict()) + processor.to(self.processor.to_q_lora.up.weight.device) + elif is_custom_diffusion: + processor = CustomDiffusionXFormersAttnProcessor( + train_kv=self.processor.train_kv, + train_q_out=self.processor.train_q_out, + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + attention_op=attention_op, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_custom_diffusion"): + processor.to(self.processor.to_k_custom_diffusion.weight.device) + elif is_added_kv_processor: + # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP + # which uses this type of cross attention ONLY because the attention mask of format + # [0, ..., -10.000, ..., 0, ...,] is not supported + # throw warning + logger.info( + "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation." + ) + processor = XFormersAttnAddedKVProcessor(attention_op=attention_op) + else: + processor = XFormersAttnProcessor(attention_op=attention_op) + else: + if is_lora: + attn_processor_class = ( + LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor + ) + processor = attn_processor_class( + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + rank=self.processor.rank, + ) + processor.load_state_dict(self.processor.state_dict()) + processor.to(self.processor.to_q_lora.up.weight.device) + elif is_custom_diffusion: + processor = CustomDiffusionAttnProcessor( + train_kv=self.processor.train_kv, + train_q_out=self.processor.train_q_out, + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_custom_diffusion"): + processor.to(self.processor.to_k_custom_diffusion.weight.device) + else: + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + processor = ( + AttnProcessor2_0() + if hasattr(F, "scaled_dot_product_attention") and self.scale_qk + else AttnProcessor() + ) + + self.set_processor(processor) + + def set_attention_slice(self, slice_size): + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + if slice_size is not None and self.added_kv_proj_dim is not None: + processor = SlicedAttnAddedKVProcessor(slice_size) + elif slice_size is not None: + processor = SlicedAttnProcessor(slice_size) + elif self.added_kv_proj_dim is not None: + processor = AttnAddedKVProcessor() + else: + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + processor = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) + + self.set_processor(processor) + + def set_processor(self, processor: "AttnProcessor"): + if ( + hasattr(self, "processor") + and not isinstance(processor, LORA_ATTENTION_PROCESSORS) + and self.to_q.lora_layer is not None + ): + deprecate( + "set_processor to offload LoRA", + "0.26.0", + "In detail, removing LoRA layers via calling `set_processor` or `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.", + ) + # (Patrick, Sayak) - this can be deprecated once PEFT LoRA integration is complete + # We need to remove all LoRA layers + for module in self.modules(): + if hasattr(module, "set_lora_layer"): + module.set_lora_layer(None) + + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + + def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor": + if not return_deprecated_lora: + return self.processor + + # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible + # serialization format for LoRA Attention Processors. It should be deleted once the integration + # with PEFT is completed. + is_lora_activated = { + name: module.lora_layer is not None + for name, module in self.named_modules() + if hasattr(module, "lora_layer") + } + + # 1. if no layer has a LoRA activated we can return the processor as usual + if not any(is_lora_activated.values()): + return self.processor + + # If doesn't apply LoRA do `add_k_proj` or `add_v_proj` + is_lora_activated.pop("add_k_proj", None) + is_lora_activated.pop("add_v_proj", None) + # 2. else it is not posssible that only some layers have LoRA activated + if not all(is_lora_activated.values()): + raise ValueError( + f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}" + ) + + # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor + non_lora_processor_cls_name = self.processor.__class__.__name__ + lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name) + + hidden_size = self.inner_dim + + # now create a LoRA attention processor from the LoRA layers + if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]: + kwargs = { + "cross_attention_dim": self.cross_attention_dim, + "rank": self.to_q.lora_layer.rank, + "network_alpha": self.to_q.lora_layer.network_alpha, + "q_rank": self.to_q.lora_layer.rank, + "q_hidden_size": self.to_q.lora_layer.out_features, + "k_rank": self.to_k.lora_layer.rank, + "k_hidden_size": self.to_k.lora_layer.out_features, + "v_rank": self.to_v.lora_layer.rank, + "v_hidden_size": self.to_v.lora_layer.out_features, + "out_rank": self.to_out[0].lora_layer.rank, + "out_hidden_size": self.to_out[0].lora_layer.out_features, + } + + if hasattr(self.processor, "attention_op"): + kwargs["attention_op"] = self.prcoessor.attention_op + + lora_processor = lora_processor_cls(hidden_size, **kwargs) + lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict()) + lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict()) + lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict()) + lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict()) + elif lora_processor_cls == LoRAAttnAddedKVProcessor: + lora_processor = lora_processor_cls( + hidden_size, + cross_attention_dim=self.add_k_proj.weight.shape[0], + rank=self.to_q.lora_layer.rank, + network_alpha=self.to_q.lora_layer.network_alpha, + ) + lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict()) + lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict()) + lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict()) + lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict()) + + # only save if used + if self.add_k_proj.lora_layer is not None: + lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict()) + lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict()) + else: + lora_processor.add_k_proj_lora = None + lora_processor.add_v_proj_lora = None + else: + raise ValueError(f"{lora_processor_cls} does not exist.") + + return lora_processor + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs): + # The `Attention` class can call different attention processors / attention functions + # here we simply pass along all tensors to the selected processor class + # For standard processors that are defined here, `**cross_attention_kwargs` is empty + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + def batch_to_head_dim(self, tensor): + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def head_to_batch_dim(self, tensor, out_dim=3): + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3) + + if out_dim == 3: + tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) + + return tensor + + def get_attention_scores(self, query, key, attention_mask=None): + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + if attention_mask is None: + baddbmm_input = torch.empty( + query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device + ) + beta = 0 + else: + baddbmm_input = attention_mask + beta = 1 + + attention_scores = torch.baddbmm( + baddbmm_input, + query, + key.transpose(-1, -2), + beta=beta, + alpha=self.scale, + ) + del baddbmm_input + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + del attention_scores + + attention_probs = attention_probs.to(dtype) + + return attention_probs + + def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3): + if batch_size is None: + deprecate( + "batch_size=None", + "0.22.0", + ( + "Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect" + " attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to" + " `prepare_attention_mask` when preparing the attention_mask." + ), + ) + batch_size = 1 + + head_size = self.heads + if attention_mask is None: + return attention_mask + + current_length: int = attention_mask.shape[-1] + if current_length != target_length: + if attention_mask.device.type == "mps": + # HACK: MPS: Does not support padding by greater than dimension of input tensor. + # Instead, we can manually construct the padding tensor. + padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length) + padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) + attention_mask = torch.cat([attention_mask, padding], dim=2) + else: + # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: + # we want to instead pad by (0, remaining_length), where remaining_length is: + # remaining_length: int = target_length - current_length + # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave(head_size, dim=0) + elif out_dim == 4: + attention_mask = attention_mask.unsqueeze(1) + attention_mask = attention_mask.repeat_interleave(head_size, dim=1) + + return attention_mask + + def norm_encoder_hidden_states(self, encoder_hidden_states): + assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states" + + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + # Group norm norms along the channels dimension and expects + # input to be in the shape of (N, C, *). In this case, we want + # to norm along the hidden dimension, so we need to move + # (batch_size, sequence_length, hidden_size) -> + # (batch_size, hidden_size, sequence_length) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + else: + assert False + + return encoder_hidden_states + + +class TemporalConditionalAttention(Attention): + def __init__(self, n_frames=8, rotary_emb=False, *args, **kwargs): + super().__init__(processor=RotaryEmbAttnProcessor2_0() if rotary_emb else None, *args, **kwargs) + + if not rotary_emb: + self.pos_enc = PositionalEncoding(self.inner_dim) + else: + rotary_bias = RelativePositionBias(heads=kwargs['heads'], max_distance=32) + self.rotary_bias = rotary_bias + self.rotary_emb = RotaryEmbedding(self.inner_dim // 2) + + self.use_rotary_emb = rotary_emb + self.n_frames = n_frames + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + adjacent_slices=None, + **cross_attention_kwargs): + + key_pos_idx = None + + bt, hw, c = hidden_states.shape + hidden_states = rearrange(hidden_states, '(b t) hw c -> b hw t c', t=self.n_frames) + if not self.use_rotary_emb: + pos_embed = self.pos_enc(self.n_frames) + hidden_states = hidden_states + pos_embed + hidden_states = rearrange(hidden_states, 'b hw t c -> (b hw) t c') + + if encoder_hidden_states is not None: + assert adjacent_slices is None + encoder_hidden_states = encoder_hidden_states[::self.n_frames] + encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b hw) n c', hw=hw) + + if adjacent_slices is not None: + assert encoder_hidden_states is None + adjacent_slices = rearrange(adjacent_slices, 'b c h w n -> b (h w) n c') + if not self.use_rotary_emb: + first_frame_pos_embed = pos_embed[0:1, :] + adjacent_slices = adjacent_slices + first_frame_pos_embed + else: + pos_idx = torch.arange(self.n_frames, device=hidden_states.device, dtype=hidden_states.dtype) + first_frame_pos_pad = torch.zeros(adjacent_slices.shape[2], device=hidden_states.device, dtype=hidden_states.dtype) + key_pos_idx = torch.cat([pos_idx, first_frame_pos_pad], dim=0) + adjacent_slices = rearrange(adjacent_slices, 'b hw n c -> (b hw) n c') + encoder_hidden_states = torch.cat([hidden_states, adjacent_slices], dim=1) + + if not self.use_rotary_emb: + out = self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + else: + out = self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + key_pos_idx=key_pos_idx, + **cross_attention_kwargs, + ) + + out = rearrange(out, '(b hw) t c -> (b t) hw c', hw=hw) + + return out + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers, attention_op=None): + if use_memory_efficient_attention_xformers: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + processor = XFormersAttnProcessor(attention_op=attention_op) + else: + processor = ( + AttnProcessor2_0() + if hasattr(F, "scaled_dot_product_attention") and self.scale_qk + else AttnProcessor() + ) + self.set_processor(processor) + + +class PositionalEncoding(nn.Module): + def __init__(self, dim, max_pos=512): + super().__init__() + + pos = torch.arange(max_pos) + + freq = torch.arange(dim//2) / dim + freq = (freq * torch.tensor(10000).log()).exp() + + x = rearrange(pos, 'L -> L 1') / freq + x = rearrange(x, 'L d -> L d 1') + + pe = torch.cat((x.sin(), x.cos()), dim=-1) + self.pe = rearrange(pe, 'L d sc -> L (d sc)') + + self.dummy = nn.Parameter(torch.rand(1)) + + def forward(self, length): + enc = self.pe[:length] + enc = enc.to(self.dummy.device, self.dummy.dtype) + return enc + + +# code taken from https://github.com/Vchitect/LaVie/blob/main/base/models/temporal_attention.py +class RelativePositionBias(nn.Module): + def __init__( + self, + heads=8, + num_buckets=32, + max_distance=128, + ): + super().__init__() + self.num_buckets = num_buckets + self.max_distance = max_distance + self.relative_attention_bias = nn.Embedding(num_buckets, heads) + + @staticmethod + def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128): + ret = 0 + n = -relative_position + + num_buckets //= 2 + ret += (n < 0).long() * num_buckets + n = torch.abs(n) + + max_exact = num_buckets // 2 + is_small = n < max_exact + + val_if_large = max_exact + ( + torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) + ).long() + val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + + ret += torch.where(is_small, n, val_if_large) + return ret + + def forward(self, qlen, klen, device, dtype): + q_pos = torch.arange(qlen, dtype = torch.long, device = device) + k_pos = torch.arange(klen, dtype = torch.long, device = device) + rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1') + rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance) + values = self.relative_attention_bias(rp_bucket) + values = values.to(device, dtype) + return rearrange(values, 'i j h -> h i j') # num_heads, num_frames, num_frames + + +class RotaryEmbAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + Add rotary embedding support + """ + + def __init__(self): + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + scale: float = 1.0, + key_pos_idx: Optional[torch.Tensor] = None, + ): + assert attention_mask is None + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + # if attention_mask is not None: + # attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # # scaled_dot_product_attention expects attention_mask shape to be + # # (batch, heads, source_length, target_length) + # attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states, scale=scale) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + qlen = hidden_states.shape[1] + klen = encoder_hidden_states.shape[1] + # currently only add bias for self attention. Relative distance doesn't make sense for cross attention. + # if qlen == klen: + # time_rel_pos_bias = attn.rotary_bias(qlen, klen, device=hidden_states.device, dtype=hidden_states.dtype) + # attention_mask = repeat(time_rel_pos_bias, "h d1 d2 -> b h d1 d2", b=batch_size) + + key = attn.to_k(encoder_hidden_states, scale=scale) + value = attn.to_v(encoder_hidden_states, scale=scale) + + query = attn.rotary_emb.rotate_queries_or_keys(query) + if qlen == klen: + key = attn.rotary_emb.rotate_queries_or_keys(key) + elif key_pos_idx is not None: + key = attn.rotary_emb.rotate_queries_or_keys(key, seq_pos=key_pos_idx) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states, scale=scale) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states \ No newline at end of file diff --git a/src/videogen_hub/pipelines/consisti2v/consisti2v/models/videoldm_transformer_blocks.py b/src/videogen_hub/pipelines/consisti2v/consisti2v/models/videoldm_transformer_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..2f39109946d4fba8305b150d6e105c852acd785c --- /dev/null +++ b/src/videogen_hub/pipelines/consisti2v/consisti2v/models/videoldm_transformer_blocks.py @@ -0,0 +1,564 @@ +# Modified from https://github.com/huggingface/diffusers/blob/v0.21.0/src/diffusers/models/transformer_2d.py +from dataclasses import dataclass +from typing import Any, Dict, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from einops import rearrange, repeat + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.embeddings import ImagePositionalEmbeddings +from diffusers.utils import BaseOutput, deprecate +from diffusers.models.attention import AdaLayerNorm, AdaLayerNormZero, FeedForward, GatedSelfAttentionDense +from diffusers.models.embeddings import PatchEmbed +from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.transformer_2d import Transformer2DModelOutput +from diffusers.utils.torch_utils import maybe_allow_in_graph +from diffusers.models.attention_processor import Attention +from diffusers.models.lora import LoRACompatibleLinear + +from .videoldm_attention import ConditionalAttention, TemporalConditionalAttention + + +class Transformer2DConditionModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, + attention_type: str = "default", + # additional + n_frames: int = 8, + is_temporal: bool = False, + augment_temporal_attention: bool = False, + rotary_emb=False, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = (in_channels is not None) and (patch_size is None) + self.is_input_vectorized = num_vector_embeds is not None + self.is_input_patches = in_channels is not None and patch_size is not None + + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = ( + f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" + " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." + " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" + " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" + " would be very nice if you could open a Pull request for the `transformer/config.json` file" + ) + deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) + norm_type = "ada_norm" + + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is None." + ) + elif self.is_input_vectorized and self.is_input_patches: + raise ValueError( + f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" + " sure that either `num_vector_embeds` or `num_patches` is None." + ) + elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: + raise ValueError( + f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" + f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." + ) + + # 2. Define input layers + if self.is_input_continuous: + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = LoRACompatibleLinear(in_channels, inner_dim) + else: + self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" + assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" + + self.height = sample_size + self.width = sample_size + self.num_vector_embeds = num_vector_embeds + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width + ) + elif self.is_input_patches: + assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" + + self.height = sample_size + self.width = sample_size + + self.patch_size = patch_size + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + ) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicConditionalTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + double_self_attention=double_self_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + attention_type=attention_type, + # additional + n_frames=n_frames, + is_temporal=is_temporal, + augment_temporal_attention=augment_temporal_attention, + rotary_emb=rotary_emb, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + if self.is_input_continuous: + # TODO: should use out_channels for continuous projections + if use_linear_projection: + self.proj_out = LoRACompatibleLinear(inner_dim, in_channels) + else: + self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + elif self.is_input_patches: + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) + self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + + self.alpha = None + if is_temporal: + self.alpha = nn.Parameter(torch.ones(1)) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + condition_on_first_frame: bool = False, + ): + input_states = hidden_states + input_height, input_width = hidden_states.shape[-2:] + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # Retrieve lora scale. + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + # 1. Input + if self.is_input_continuous: + batch, _, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states, lora_scale) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + hidden_states = self.proj_in(hidden_states, scale=lora_scale) + + elif self.is_input_vectorized: + hidden_states = self.latent_image_embedding(hidden_states) + elif self.is_input_patches: + hidden_states = self.pos_embed(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + if self.training and self.gradient_checkpointing: + hidden_states = torch.utils.checkpoint.checkpoint( + block, + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + cross_attention_kwargs, + class_labels, + use_reentrant=False, + ) + else: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + # additional + condition_on_first_frame=condition_on_first_frame, + input_height=input_height, + input_width=input_width, + ) + + # 3. Output + if self.is_input_continuous: + if not self.use_linear_projection: + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = self.proj_out(hidden_states, scale=lora_scale) + else: + hidden_states = self.proj_out(hidden_states, scale=lora_scale) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + elif self.is_input_vectorized: + hidden_states = self.norm_out(hidden_states) + logits = self.out(hidden_states) + # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) + logits = logits.permute(0, 2, 1) + + # log(p(x_0)) + output = F.log_softmax(logits.double(), dim=1).float() + elif self.is_input_patches: + # TODO: cleanup! + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + hidden_states = self.proj_out_2(hidden_states) + + # unpatchify + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + + if self.alpha is not None: + with torch.no_grad(): + self.alpha.clamp_(0, 1) + + output = self.alpha * input_states + (1 - self.alpha) * output + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + + +@maybe_allow_in_graph +class BasicConditionalTransformerBlock(nn.Module): + """ transformer block with first frame conditioning """ + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + final_dropout: bool = False, + attention_type: str = "default", + # additional + n_frames: int = 8, + is_temporal: bool = False, + augment_temporal_attention: bool = False, + rotary_emb=False, + ): + super().__init__() + self.n_frames = n_frames + self.only_cross_attention = only_cross_attention + self.augment_temporal_attention = augment_temporal_attention + self.is_temporal = is_temporal + + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + + if not is_temporal: + self.attn1 = ConditionalAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + else: + self.attn1 = TemporalConditionalAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + # additional + n_frames=n_frames, + rotary_emb=rotary_emb, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) + if not is_temporal: + self.attn2 = ConditionalAttention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) # is self-attn if encoder_hidden_states is none + else: + self.attn2 = TemporalConditionalAttention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + # additional + n_frames=n_frames, + rotary_emb=rotary_emb, + ) + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + + # 4. Fuser + if attention_type == "gated" or attention_type == "gated-text-image": + self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + condition_on_first_frame: bool = False, + input_height: Optional[int] = None, + input_width: Optional[int] = None, + ): + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + else: + norm_hidden_states = self.norm1(hidden_states) + + # 1. Retrieve lora scale. + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + # 2. Prepare GLIGEN inputs + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + if condition_on_first_frame: + first_frame_hidden_states = rearrange(norm_hidden_states, '(b f) d h -> b f d h', f=self.n_frames)[:, 0, :, :] + first_frame_hidden_states = repeat(first_frame_hidden_states, 'b d h -> b f d h', f=self.n_frames) + first_frame_hidden_states = rearrange(first_frame_hidden_states, 'b f d h -> (b f) d h') + first_frame_concat_hidden_states = torch.cat((norm_hidden_states, first_frame_hidden_states), dim=1) + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else first_frame_concat_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + elif self.is_temporal and self.augment_temporal_attention: + first_frame_hidden_states = rearrange(norm_hidden_states, '(b f) d h -> b f d h', f=self.n_frames)[:, 0, :, :] + first_frame_hidden_states = rearrange(first_frame_hidden_states, 'b (h w) c -> b h w c', h=input_height, w=input_width) + first_frame_hidden_states = first_frame_hidden_states.permute(0, 3, 1, 2) + padded_first_frame = torch.nn.functional.pad(first_frame_hidden_states, (1, 1, 1, 1), "replicate") + first_frame_windows = padded_first_frame.unfold(2, 3, 1).unfold(3, 3, 1) + mask = torch.tensor([[1, 1, 1], [1, 0, 1], [1, 1, 1]], dtype=torch.bool) + adjacent_slices = first_frame_windows[:, :, :, :, mask] + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + adjacent_slices=adjacent_slices, + **cross_attention_kwargs, + ) + else: + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + + # 2.5 GLIGEN Control + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + # 2.5 ends + + # 3. Cross-Attention + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." + ) + + num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size + ff_output = torch.cat( + [ + self.ff(hid_slice, scale=lora_scale) + for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim) + ], + dim=self._chunk_dim, + ) + else: + ff_output = self.ff(norm_hidden_states, scale=lora_scale) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = ff_output + hidden_states + + return hidden_states \ No newline at end of file diff --git a/src/videogen_hub/pipelines/consisti2v/consisti2v/models/videoldm_unet.py b/src/videogen_hub/pipelines/consisti2v/consisti2v/models/videoldm_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..96dae2efa99f5554ae6b4a106703bcae6eb22dc8 --- /dev/null +++ b/src/videogen_hub/pipelines/consisti2v/consisti2v/models/videoldm_unet.py @@ -0,0 +1,1371 @@ +import os +import re +from typing import Optional, Tuple, Union, Dict, List, Any +from einops import rearrange, repeat + +import torch +import torch.nn as nn +from diffusers.loaders import UNet2DConditionLoadersMixin +from diffusers.models import ModelMixin +from diffusers.models.unet_2d_condition import UNet2DConditionOutput +from diffusers.models.unet_2d_blocks import UNetMidBlock2DCrossAttn, UNetMidBlock2DSimpleCrossAttn +from diffusers.models.embeddings import ( + GaussianFourierProjection, + ImageHintTimeEmbedding, + ImageProjection, + ImageTimeEmbedding, + PositionNet, + TextImageProjection, + TextImageTimeEmbedding, + TextTimeEmbedding, + TimestepEmbedding, + Timesteps, +) +from diffusers.models.attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) +from diffusers.models.activations import get_activation +from diffusers.configuration_utils import register_to_config, ConfigMixin +from diffusers.models.modeling_utils import load_state_dict, load_model_dict_into_meta +from diffusers.utils import ( + CONFIG_NAME, + DIFFUSERS_CACHE, + FLAX_WEIGHTS_NAME, + HF_HUB_OFFLINE, + SAFETENSORS_WEIGHTS_NAME, + WEIGHTS_NAME, + _add_variant, + _get_model_file, + deprecate, + is_accelerate_available, + is_torch_version, + logging, +) +from diffusers import __version__ + +if is_torch_version(">=", "1.9.0"): + _LOW_CPU_MEM_USAGE_DEFAULT = True +else: + _LOW_CPU_MEM_USAGE_DEFAULT = False + + +if is_accelerate_available(): + import accelerate + from accelerate.utils import set_module_tensor_to_device + from accelerate.utils.versions import is_torch_version + + + +from .videoldm_unet_blocks import get_down_block, get_up_block, VideoLDMUNetMidBlock2DCrossAttn + +logger = logging.get_logger(__name__) + + +class VideoLDMUNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): + _supports_gradient_checkpointing = True + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", # -> VideoLDMDownBlock + "CrossAttnDownBlock2D", # -> VideoLDMDownBlock + "CrossAttnDownBlock2D", # -> VideoLDMDownBlock + "DownBlock2D", + ), + mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", + up_block_types: Tuple[str] = ( + "UpBlock2D", + "CrossAttnUpBlock2D", # -> VideoLDMUpBlock + "CrossAttnUpBlock2D", # -> VideoLDMUpBlock + "CrossAttnUpBlock2D", # -> VideoLDMUpBlock + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: Union[int, Tuple[int]] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + dropout: float = 0.0, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: Union[int, Tuple[int]] = 1280, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: int = 1.0, + time_embedding_type: str = "positional", + time_embedding_dim: Optional[int] = None, + time_embedding_act_fn: Optional[str] = None, + timestep_post_act: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: Optional[int] = None, + attention_type: str = "default", + class_embeddings_concat: bool = False, + mid_block_only_cross_attention: Optional[bool] = None, + cross_attention_norm: Optional[str] = None, + addition_embed_type_num_heads=64, + # additional + use_temporal: bool = True, + n_frames: int = 8, + n_temp_heads: int = 8, + first_frame_condition_mode: str = "none", + augment_temporal_attention: bool = False, + temp_pos_embedding: str = "sinusoidal", + use_frame_stride_condition: bool = False, + ): + super().__init__() + + rotary_emb = False + if temp_pos_embedding == "rotary": + # from rotary_embedding_torch import RotaryEmbedding + # rotary_emb = RotaryEmbedding(32) + # self.rotary_emb = rotary_emb + rotary_emb = True + self.rotary_emb = rotary_emb + + self.use_temporal = use_temporal + self.augment_temporal_attention = augment_temporal_attention + + assert first_frame_condition_mode in ["none", "concat", "conv2d", "input_only"], f"first_frame_condition_mode: {first_frame_condition_mode} must be one of ['none', 'concat', 'conv2d', 'input_only']" + self.first_frame_condition_mode = first_frame_condition_mode + latent_channels = in_channels + + self.sample_size = sample_size + + if num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + + # input + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." + ) + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + self.use_frame_stride_condition = use_frame_stride_condition + if self.use_frame_stride_condition: + self.frame_stride_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + # zero init + nn.init.zeros_(self.frame_stride_embedding.linear_2.weight) + nn.init.zeros_(self.frame_stride_embedding.linear_2.bias) + + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 + self.encoder_hid_proj = ImageProjection( + image_embed_dim=encoder_hid_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + ) + else: + self.encoder_hid_proj = None + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif addition_embed_type == "image": + # Kandinsky 2.2 + self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type == "image_hint": + # Kandinsky 2.2 ControlNet + self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + + if time_embedding_act_fn is None: + self.time_embed_act = None + else: + self.time_embed_act = get_activation(time_embedding_act_fn) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = only_cross_attention + + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = False + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + dropout=dropout, + # additional + use_temporal=use_temporal, + augment_temporal_attention=augment_temporal_attention, + n_frames=n_frames, + n_temp_heads=n_temp_heads, + first_frame_condition_mode=first_frame_condition_mode, + latent_channels=latent_channels, + rotary_emb=rotary_emb, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock2DCrossAttn": + self.mid_block = VideoLDMUNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + dropout=dropout, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + # additional + use_temporal=use_temporal, + n_frames=n_frames, + first_frame_condition_mode=first_frame_condition_mode, + latent_channels=latent_channels, + ) + elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": + self.mid_block = UNetMidBlock2DSimpleCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + dropout=dropout, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim[-1], + attention_head_dim=attention_head_dim[-1], + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + only_cross_attention=mid_block_only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif mid_block_type is None: + self.mid_block = None + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + dropout=dropout, + # additional + use_temporal=use_temporal, + augment_temporal_attention=augment_temporal_attention, + n_frames=n_frames, + n_temp_heads=n_temp_heads, + first_frame_condition_mode=first_frame_condition_mode, + latent_channels=latent_channels, + rotary_emb=rotary_emb, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + + self.conv_act = get_activation(act_fn) + + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + # additional + first_frame_latents: Optional[torch.Tensor] = None, + frame_stride: Optional[Union[torch.Tensor, float, int]] = None, + ) -> Union[UNet2DConditionOutput, Tuple]: + # reshape video data + assert sample.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={sample.dim()}." + video_length = sample.shape[2] + + if first_frame_latents is not None: + assert self.config.first_frame_condition_mode != "none", "first_frame_latents is not None, but first_frame_condition_mode is 'none'." + + if self.config.first_frame_condition_mode != "none": + sample = torch.cat([first_frame_latents, sample], dim=2) + video_length += 1 + + # copy conditioning embeddings for cross attention + if encoder_hidden_states is not None: + encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length) + + sample = rearrange(sample, "b c f h w -> (b f) c h w") + + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + + if self.use_frame_stride_condition: + if not torch.is_tensor(frame_stride): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + frame_stride = torch.tensor([frame_stride], dtype=dtype, device=sample.device) + elif len(frame_stride.shape) == 0: + frame_stride = frame_stride[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + frame_stride = frame_stride.expand(sample.shape[0]) + + fs_emb = self.time_proj(frame_stride) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + fs_emb = fs_emb.to(dtype=sample.dtype) + + fs_emb = self.frame_stride_embedding(fs_emb, timestep_cond) + emb = emb + fs_emb + + aug_emb = None + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. + class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config.addition_embed_type == "text_image": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) + aug_emb = self.add_embedding(text_embs, image_embs) + elif self.config.addition_embed_type == "text_time": + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == "image": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + aug_emb = self.add_embedding(image_embs) + elif self.config.addition_embed_type == "image_hint": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + hint = added_cond_kwargs.get("hint") + aug_emb, hint = self.add_embedding(image_embs, hint) + sample = torch.cat([sample, hint], dim=1) + + emb = emb + aug_emb if aug_emb is not None else emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": + # Kadinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(image_embeds) + # 2. pre-process + sample = self.conv_in(sample) + + # 2.5 GLIGEN position net + if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: + cross_attention_kwargs = cross_attention_kwargs.copy() + gligen_args = cross_attention_kwargs.pop("gligen") + cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} + + # 3. down + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None + + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + if is_adapter and len(down_block_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0) + + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + first_frame_latents=first_frame_latents, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale, first_frame_latents=first_frame_latents,) + + if is_adapter and len(down_block_additional_residuals) > 0: + sample += down_block_additional_residuals.pop(0) + + down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + # additional + first_frame_latents=first_frame_latents, + ) + # To support T2I-Adapter-XL + if ( + is_adapter + and len(down_block_additional_residuals) > 0 + and sample.shape == down_block_additional_residuals[0].shape + ): + sample += down_block_additional_residuals.pop(0) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + first_frame_latents=first_frame_latents, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + scale=lora_scale, + first_frame_latents=first_frame_latents, + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + sample = rearrange(sample, "(b f) c h w -> b c f h w", f=video_length) + if self.config.first_frame_condition_mode != "none": + sample = sample[:, :, 1:, :, :] + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + + kwargs.pop("low_cpu_mem_usage", False) + kwargs.pop("device_map", None) + + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) + force_download = kwargs.pop("force_download", False) + from_flax = kwargs.pop("from_flax", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + output_loading_info = kwargs.pop("output_loading_info", False) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + subfolder = kwargs.pop("subfolder", None) + device_map = None + max_memory = kwargs.pop("max_memory", None) + offload_folder = kwargs.pop("offload_folder", None) + offload_state_dict = kwargs.pop("offload_state_dict", False) + low_cpu_mem_usage = False + variant = kwargs.pop("variant", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + if low_cpu_mem_usage and not is_accelerate_available(): + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + + if device_map is not None and not is_accelerate_available(): + raise NotImplementedError( + "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set" + " `device_map=None`. You can install accelerate with `pip install accelerate`." + ) + + # Check if we can handle device_map and dispatching the weights + if device_map is not None and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `device_map=None`." + ) + + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `low_cpu_mem_usage=False`." + ) + + if low_cpu_mem_usage is False and device_map is not None: + raise ValueError( + f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and" + " dispatching. Please make sure to set `low_cpu_mem_usage=True`." + ) + + # Load config if we don't provide a configuration + config_path = pretrained_model_name_or_path + + user_agent = { + "diffusers": __version__, + "file_type": "model", + "framework": "pytorch", + } + + # load config + config, unused_kwargs, commit_hash = cls.load_config( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + return_commit_hash=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + device_map=device_map, + max_memory=max_memory, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, + user_agent=user_agent, + **kwargs, + ) + + # load model + model_file = None + if from_flax: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=FLAX_WEIGHTS_NAME, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + model = cls.from_config(config, **unused_kwargs) + + # Convert the weights + from diffusers.models.modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model + + model = load_flax_checkpoint_in_pytorch_model(model, model_file) + else: + if use_safetensors: + try: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + except IOError as e: + if not allow_pickle: + raise e + pass + if model_file is None: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + + if low_cpu_mem_usage: + # Instantiate model with empty weights + with accelerate.init_empty_weights(): + model = cls.from_config(config, **unused_kwargs) + + # if device_map is None, load the state dict and move the params from meta device to the cpu + if device_map is None: + param_device = "cpu" + state_dict = load_state_dict(model_file, variant=variant) + model._convert_deprecated_attention_blocks(state_dict) + # move the params from meta device to cpu + missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) + if len(missing_keys) > 0: + raise ValueError( + f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are" + f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass" + " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize" + " those weights or else make sure your checkpoint file is correct." + ) + + unexpected_keys = load_model_dict_into_meta( + model, + state_dict, + device=param_device, + dtype=torch_dtype, + model_name_or_path=pretrained_model_name_or_path, + ) + + if cls._keys_to_ignore_on_load_unexpected is not None: + for pat in cls._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + + if len(unexpected_keys) > 0: + logger.warn( + f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" + ) + + else: # else let accelerate handle loading and dispatching. + # Load weights and dispatch according to the device_map + # by default the device_map is None and the weights are loaded on the CPU + try: + accelerate.load_checkpoint_and_dispatch( + model, + model_file, + device_map, + max_memory=max_memory, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, + dtype=torch_dtype, + ) + except AttributeError as e: + # When using accelerate loading, we do not have the ability to load the state + # dict and rename the weight names manually. Additionally, accelerate skips + # torch loading conventions and directly writes into `module.{_buffers, _parameters}` + # (which look like they should be private variables?), so we can't use the standard hooks + # to rename parameters on load. We need to mimic the original weight names so the correct + # attributes are available. After we have loaded the weights, we convert the deprecated + # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert + # the weights so we don't have to do this again. + + if "'Attention' object has no attribute" in str(e): + logger.warn( + f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}" + " was saved with deprecated attention block weight names. We will load it with the deprecated attention block" + " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion," + " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint," + " please also re-upload it or open a PR on the original repository." + ) + model._temp_convert_self_to_deprecated_attention_blocks() + accelerate.load_checkpoint_and_dispatch( + model, + model_file, + device_map, + max_memory=max_memory, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, + dtype=torch_dtype, + ) + model._undo_temp_convert_self_to_deprecated_attention_blocks() + else: + raise e + + loading_info = { + "missing_keys": [], + "unexpected_keys": [], + "mismatched_keys": [], + "error_msgs": [], + } + else: + model = cls.from_config(config, **unused_kwargs) + + state_dict = load_state_dict(model_file, variant=variant) + model._convert_deprecated_attention_blocks(state_dict) + + model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( + model, + state_dict, + model_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=ignore_mismatched_sizes, + ) + + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + "error_msgs": error_msgs, + } + + if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): + raise ValueError( + f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." + ) + elif torch_dtype is not None: + model = model.to(torch_dtype) + + model.register_to_config(_name_or_path=pretrained_model_name_or_path) + + m, u = loading_info["missing_keys"], loading_info["unexpected_keys"] + logger.info(f"### missing keys: {len(m)}; unexpected keys: {len(u)};") + # print(f"### missing keys:\n{m}\n### unexpected keys:\n{u}\n") + + spatial_params = [p.numel() if "conv3ds" not in n and "tempo_attns" not in n else 0 for n, p in model.named_parameters()] + tconv_params = [p.numel() if "conv3ds." in n else 0 for n, p in model.named_parameters()] + tattn_params = [p.numel() if "tempo_attns." in n else 0 for n, p in model.named_parameters()] + tffconv_params = [p.numel() if "first_frame_conv." in n else 0 for n, p in model.named_parameters()] + logger.info(f"### First Frame Convolution Layer Parameters: {sum(tffconv_params) / 1e6} M") + logger.info(f"### Spatial UNet Parameters: {sum(spatial_params) / 1e6} M") + logger.info(f"### Temporal Convolution Module Parameters: {sum(tconv_params) / 1e6} M") + logger.info(f"### Temporal Attention Module Parameters: {sum(tattn_params) / 1e6} M") + + # Set model in evaluation mode to deactivate DropOut modules by default + model.eval() + if output_loading_info: + return model, loading_info + + return model + +if __name__ == "__main__": + # test + from diffusers import AutoencoderKL, DDIMScheduler + from transformers import CLIPTextModel, CLIPTokenizer + from consisti2v.pipelines.pipeline_animation import AnimationPipeline + from consisti2v.pipelines.pipeline_conditional_animation import ConditionalAnimationPipeline + from consisti2v.utils.util import save_videos_grid + + pretrained_model_path = "models/StableDiffusion/stable-diffusion-v1-5" + prompt = "apply eye makeup" + first_frame_path = "/ML-A100/home/weiming/datasets/UCF/frames/v_ApplyEyeMakeup_g01_c01_frame_90.jpg" + + tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer", use_safetensors=True) + text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder") + vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae", use_safetensors=True) + unet = VideoLDMUNet3DConditionModel.from_pretrained( + pretrained_model_path, + subfolder="unet", + use_safetensors=True + ) + + noise_scheduler_kwargs = { + "num_train_timesteps": 1000, + "beta_start": 0.00085, + "beta_end": 0.012, + "beta_schedule": "linear", + "steps_offset": 1, + "clip_sample": False, + } + noise_scheduler = DDIMScheduler(**noise_scheduler_kwargs) + # latent = torch.randn(1, 4, 8, 64, 64).to("cuda") + # text_embedding = torch.randn(1, 77, 768).to("cuda") + # timestep = torch.randint(0, 1000, (1,)).to("cuda").squeeze(0) + # output = unet(latent, timestep, text_embedding) + + pipeline = ConditionalAnimationPipeline( + unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler, + ).to("cuda") + sample = pipeline( + prompt, + num_inference_steps = 25, + guidance_scale = 8., + video_length = 8, + height = 256, + width = 256, + first_frame_paths = first_frame_path, + ).videos + print(sample.shape) + save_videos_grid(sample, f"samples/videoldm.gif") \ No newline at end of file diff --git a/src/videogen_hub/pipelines/consisti2v/consisti2v/models/videoldm_unet_blocks.py b/src/videogen_hub/pipelines/consisti2v/consisti2v/models/videoldm_unet_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..63733bb2dda5c3467a5d115ee20bf7fe6b4eab19 --- /dev/null +++ b/src/videogen_hub/pipelines/consisti2v/consisti2v/models/videoldm_unet_blocks.py @@ -0,0 +1,1159 @@ +from typing import Optional, Dict, Tuple, Any + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from einops.layers.torch import Rearrange +from diffusers.utils import logging +from diffusers.models.unet_2d_blocks import ( + DownBlock2D, + UpBlock2D +) +from diffusers.models.resnet import ( + ResnetBlock2D, + Downsample2D, + Upsample2D, +) +from diffusers.models.transformer_2d import Transformer2DModelOutput +from diffusers.models.dual_transformer_2d import DualTransformer2DModel +from diffusers.models.activations import get_activation +from diffusers.utils import logging, is_torch_version +from diffusers.utils.import_utils import is_xformers_available +from .videoldm_transformer_blocks import Transformer2DConditionModel + +logger = logging.get_logger(__name__) + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + transformer_layers_per_block=1, + num_attention_heads=None, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + attention_type="default", + resnet_skip_time_act=False, + resnet_out_scale_factor=1.0, + cross_attention_norm=None, + attention_head_dim=None, + downsample_type=None, + dropout=0.0, + # additional + use_temporal=True, + augment_temporal_attention=False, + n_frames=8, + n_temp_heads=8, + first_frame_condition_mode="none", + latent_channels=4, + rotary_emb=False, +): + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warn( + f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}." + ) + attention_head_dim = num_attention_heads + + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlock2D": + return VideoLDMDownBlock( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + # additional + use_temporal=use_temporal, + n_frames=n_frames, + first_frame_condition_mode=first_frame_condition_mode, + latent_channels=latent_channels + ) + elif down_block_type == "CrossAttnDownBlock2D": + return VideoLDMCrossAttnDownBlock( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + # additional + use_temporal=use_temporal, + augment_temporal_attention=augment_temporal_attention, + n_frames=n_frames, + n_temp_heads=n_temp_heads, + first_frame_condition_mode=first_frame_condition_mode, + latent_channels=latent_channels, + rotary_emb=rotary_emb, + ) + + raise ValueError(f'{down_block_type} does not exist.') + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + transformer_layers_per_block=1, + num_attention_heads=None, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + attention_type="default", + resnet_skip_time_act=False, + resnet_out_scale_factor=1.0, + cross_attention_norm=None, + attention_head_dim=None, + upsample_type=None, + dropout=0.0, + # additional + use_temporal=True, + augment_temporal_attention=False, + n_frames=8, + n_temp_heads=8, + first_frame_condition_mode="none", + latent_channels=4, + rotary_emb=None, +): + if attention_head_dim is None: + logger.warn( + f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}." + ) + attention_head_dim = num_attention_heads + + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlock2D": + return VideoLDMUpBlock( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + # additional + use_temporal=use_temporal, + n_frames=n_frames, + first_frame_condition_mode=first_frame_condition_mode, + latent_channels=latent_channels + ) + elif up_block_type == 'CrossAttnUpBlock2D': + return VideoLDMCrossAttnUpBlock( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + # additional + use_temporal=use_temporal, + augment_temporal_attention=augment_temporal_attention, + n_frames=n_frames, + n_temp_heads=n_temp_heads, + first_frame_condition_mode=first_frame_condition_mode, + latent_channels=latent_channels, + rotary_emb=rotary_emb, + ) + + raise ValueError(f'{up_block_type} does not exist.') + + +class TemporalResnetBlock(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + non_linearity="swish", + time_embedding_norm="default", + output_scale_factor=1.0, + # additional + n_frames=8, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.time_embedding_norm = time_embedding_norm + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + + self.conv1 = Conv3DLayer(in_channels, out_channels, n_frames=n_frames) + + if temb_channels is not None: + if self.time_embedding_norm == "default": + time_emb_proj_out_channels = out_channels + elif self.time_embedding_norm == "scale_shift": + time_emb_proj_out_channels = out_channels * 2 + else: + raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") + + self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) + else: + self.time_emb_proj = None + + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = Conv3DLayer(out_channels, out_channels, n_frames=n_frames) + + self.nonlinearity = get_activation(non_linearity) + + self.alpha = nn.Parameter(torch.ones(1)) + + def forward(self, input_tensor, temb=None): + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.conv1(hidden_states) + + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] + + if temb is not None and self.time_embedding_norm == "default": + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + + if temb is not None and self.time_embedding_norm == "scale_shift": + scale, shift = torch.chunk(temb, 2, dim=1) + hidden_states = hidden_states * (1 + scale) + shift + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + # weighted sum between spatial and temporal features + with torch.no_grad(): + self.alpha.clamp_(0, 1) + + output_tensor = self.alpha * input_tensor + (1 - self.alpha) * output_tensor + + return output_tensor + + +class Conv3DLayer(nn.Conv3d): + def __init__(self, in_dim, out_dim, n_frames): + k, p = (3, 1, 1), (1, 0, 0) + super().__init__(in_channels=in_dim, out_channels=out_dim, kernel_size=k, stride=1, padding=p) + + self.to_3d = Rearrange('(b t) c h w -> b c t h w', t=n_frames) + self.to_2d = Rearrange('b c t h w -> (b t) c h w') + + def forward(self, x): + h = self.to_3d(x) + h = super().forward(h) + out = self.to_2d(h) + return out + + +class IdentityLayer(nn.Identity): + def __init__(self, return_trans2d_output, *args, **kwargs): + super().__init__() + self.return_trans2d_output = return_trans2d_output + + def forward(self, x, *args, **kwargs): + if self.return_trans2d_output: + return Transformer2DModelOutput(sample=x) + else: + return x + + +class VideoLDMCrossAttnDownBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + attention_type="default", + # additional + use_temporal=True, + augment_temporal_attention=False, + n_frames=8, + n_temp_heads=8, + first_frame_condition_mode="none", + latent_channels=4, + rotary_emb=False, + ): + super().__init__() + + self.use_temporal = use_temporal + + self.n_frames = n_frames + self.first_frame_condition_mode = first_frame_condition_mode + if self.first_frame_condition_mode == "conv2d": + self.first_frame_conv = nn.Conv2d(latent_channels, in_channels, kernel_size=1) + + resnets = [] + attentions = [] + + self.n_frames = n_frames + self.n_temp_heads = n_temp_heads + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DConditionModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + # additional + n_frames=n_frames, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + # >>> Temporal Layers >>> + conv3ds = [] + tempo_attns = [] + + for i in range(num_layers): + if self.use_temporal: + conv3ds.append( + TemporalResnetBlock( + in_channels=out_channels, + out_channels=out_channels, + n_frames=n_frames, + ) + ) + + tempo_attns.append( + Transformer2DConditionModel( + n_temp_heads, + out_channels // n_temp_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + # additional + n_frames=n_frames, + is_temporal=True, + augment_temporal_attention=augment_temporal_attention, + rotary_emb=rotary_emb + ) + ) + else: + conv3ds.append(IdentityLayer(return_trans2d_output=False)) + tempo_attns.append(IdentityLayer(return_trans2d_output=True)) + + self.conv3ds = nn.ModuleList(conv3ds) + self.tempo_attns = nn.ModuleList(tempo_attns) + # <<< Temporal Layers <<< + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + # additional + first_frame_latents=None, + ): + condition_on_first_frame = (self.first_frame_condition_mode != "none" and self.first_frame_condition_mode != "input_only") + # input shape: hidden_states = (b f) c h w, first_frame_latents = b c 1 h w + if self.first_frame_condition_mode == "conv2d": + hidden_states = rearrange(hidden_states, '(b t) c h w -> b c t h w', t=self.n_frames) + hidden_height = hidden_states.shape[3] + first_frame_height = first_frame_latents.shape[3] + downsample_ratio = hidden_height / first_frame_height + first_frame_latents = F.interpolate(first_frame_latents.squeeze(2), scale_factor=downsample_ratio, mode="nearest") + first_frame_latents = self.first_frame_conv(first_frame_latents).unsqueeze(2) + hidden_states[:, :, 0:1, :, :] = first_frame_latents + hidden_states = rearrange(hidden_states, 'b c t h w -> (b t) c h w', t=self.n_frames) + + output_states = () + + for resnet, conv3d, attn, tempo_attn in zip(self.resnets, self.conv3ds, self.attentions, self.tempo_attns): + + hidden_states = resnet(hidden_states, temb) + hidden_states = conv3d(hidden_states) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + condition_on_first_frame=condition_on_first_frame, + ).sample + hidden_states = tempo_attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + condition_on_first_frame=False, + ).sample + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class VideoLDMCrossAttnUpBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + attention_type="default", + # additional + use_temporal=True, + augment_temporal_attention=False, + n_frames=8, + n_temp_heads=8, + first_frame_condition_mode="none", + latent_channels=4, + rotary_emb=False, + ): + super().__init__() + + self.use_temporal = use_temporal + + self.n_frames = n_frames + self.first_frame_condition_mode = first_frame_condition_mode + if self.first_frame_condition_mode == "conv2d": + self.first_frame_conv = nn.Conv2d(latent_channels, prev_output_channel, kernel_size=1) + + resnets = [] + attentions = [] + + self.n_frames = n_frames + self.n_temp_heads = n_temp_heads + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DConditionModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + # additional + n_frames=n_frames, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + # >>> Temporal Layers >>> + conv3ds = [] + tempo_attns = [] + + for i in range(num_layers): + if self.use_temporal: + conv3ds.append( + TemporalResnetBlock( + in_channels=out_channels, + out_channels=out_channels, + n_frames=n_frames, + ) + ) + + tempo_attns.append( + Transformer2DConditionModel( + n_temp_heads, + out_channels // n_temp_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + # additional + n_frames=n_frames, + augment_temporal_attention=augment_temporal_attention, + is_temporal=True, + rotary_emb=rotary_emb, + ) + ) + else: + conv3ds.append(IdentityLayer(return_trans2d_output=False)) + tempo_attns.append(IdentityLayer(return_trans2d_output=True)) + + self.conv3ds = nn.ModuleList(conv3ds) + self.tempo_attns = nn.ModuleList(tempo_attns) + # <<< Temporal Layers <<< + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + # additional + first_frame_latents=None, + ): + condition_on_first_frame = (self.first_frame_condition_mode != "none" and self.first_frame_condition_mode != "input_only") + # input shape: hidden_states = (b f) c h w, first_frame_latents = b c 1 h w + if self.first_frame_condition_mode == "conv2d": + hidden_states = rearrange(hidden_states, '(b t) c h w -> b c t h w', t=self.n_frames) + hidden_height = hidden_states.shape[3] + first_frame_height = first_frame_latents.shape[3] + downsample_ratio = hidden_height / first_frame_height + first_frame_latents = F.interpolate(first_frame_latents.squeeze(2), scale_factor=downsample_ratio, mode="nearest") + first_frame_latents = self.first_frame_conv(first_frame_latents).unsqueeze(2) + hidden_states[:, :, 0:1, :, :] = first_frame_latents + hidden_states = rearrange(hidden_states, 'b c t h w -> (b t) c h w', t=self.n_frames) + + for resnet, conv3d, attn, tempo_attn in zip(self.resnets, self.conv3ds, self.attentions, self.tempo_attns): + + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = conv3d(hidden_states) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + condition_on_first_frame=condition_on_first_frame, + ).sample + hidden_states = tempo_attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + condition_on_first_frame=False, + ).sample + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + return hidden_states + + +class VideoLDMUNetMidBlock2DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + attention_type="default", + # additional + use_temporal=True, + n_frames: int = 8, + first_frame_condition_mode="none", + latent_channels=4, + ): + super().__init__() + + self.use_temporal = use_temporal + + self.n_frames = n_frames + self.first_frame_condition_mode = first_frame_condition_mode + if self.first_frame_condition_mode == "conv2d": + self.first_frame_conv = nn.Conv2d(latent_channels, in_channels, kernel_size=1) + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + if self.use_temporal: + conv3ds = [ + TemporalResnetBlock( + in_channels=in_channels, + out_channels=in_channels, + n_frames=n_frames, + ) + ] + else: + conv3ds = [IdentityLayer(return_trans2d_output=False)] + + attentions = [] + + for _ in range(num_layers): + if not dual_cross_attention: + attentions.append( + Transformer2DConditionModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + # additional + n_frames=n_frames, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if self.use_temporal: + conv3ds.append( + TemporalResnetBlock( + in_channels=in_channels, + out_channels=in_channels, + n_frames=n_frames, + ) + ) + else: + conv3ds.append(IdentityLayer(return_trans2d_output=False)) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.conv3ds = nn.ModuleList(conv3ds) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + # additional + first_frame_latents=None, + ) -> torch.FloatTensor: + condition_on_first_frame = (self.first_frame_condition_mode != "none" and self.first_frame_condition_mode != "input_only") + # input shape: hidden_states = (b f) c h w, first_frame_latents = b c 1 h w + if self.first_frame_condition_mode == "conv2d": + hidden_states = rearrange(hidden_states, '(b t) c h w -> b c t h w', t=self.n_frames) + hidden_height = hidden_states.shape[3] + first_frame_height = first_frame_latents.shape[3] + downsample_ratio = hidden_height / first_frame_height + first_frame_latents = F.interpolate(first_frame_latents.squeeze(2), scale_factor=downsample_ratio, mode="nearest") + first_frame_latents = self.first_frame_conv(first_frame_latents).unsqueeze(2) + hidden_states[:, :, 0:1, :, :] = first_frame_latents + hidden_states = rearrange(hidden_states, 'b c t h w -> (b t) c h w', t=self.n_frames) + + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) + hidden_states = self.conv3ds[0](hidden_states) + for attn, resnet, conv3d in zip(self.attentions, self.resnets[1:], self.conv3ds[1:]): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + # additional + condition_on_first_frame=condition_on_first_frame, + )[0] + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = conv3d(hidden_states) + else: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + # additional + condition_on_first_frame=condition_on_first_frame, + )[0] + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = conv3d(hidden_states) + + return hidden_states + + +class VideoLDMDownBlock(DownBlock2D): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + # additional + use_temporal=True, + n_frames: int = 8, + first_frame_condition_mode="none", + latent_channels=4, + ): + super().__init__( + in_channels, + out_channels, + temb_channels, + dropout, + num_layers, + resnet_eps, + resnet_time_scale_shift, + resnet_act_fn, + resnet_groups, + resnet_pre_norm, + output_scale_factor, + add_downsample, + downsample_padding,) + + self.use_temporal = use_temporal + + self.n_frames = n_frames + self.first_frame_condition_mode = first_frame_condition_mode + if self.first_frame_condition_mode == "conv2d": + self.first_frame_conv = nn.Conv2d(latent_channels, in_channels, kernel_size=1) + + # >>> Temporal Layers >>> + conv3ds = [] + for i in range(num_layers): + if self.use_temporal: + conv3ds.append( + TemporalResnetBlock( + in_channels=out_channels, + out_channels=out_channels, + n_frames=n_frames, + ) + ) + else: + conv3ds.append(IdentityLayer(return_trans2d_output=False)) + self.conv3ds = nn.ModuleList(conv3ds) + # <<< Temporal Layers <<< + + def forward(self, hidden_states, temb=None, scale: float = 1, first_frame_latents=None): + # input shape: hidden_states = (b f) c h w, first_frame_latents = b c 1 h w + if self.first_frame_condition_mode == "conv2d": + hidden_states = rearrange(hidden_states, '(b t) c h w -> b c t h w', t=self.n_frames) + hidden_height = hidden_states.shape[3] + first_frame_height = first_frame_latents.shape[3] + downsample_ratio = hidden_height / first_frame_height + first_frame_latents = F.interpolate(first_frame_latents.squeeze(2), scale_factor=downsample_ratio, mode="nearest") + first_frame_latents = self.first_frame_conv(first_frame_latents).unsqueeze(2) + hidden_states[:, :, 0:1, :, :] = first_frame_latents + hidden_states = rearrange(hidden_states, 'b c t h w -> (b t) c h w', t=self.n_frames) + + output_states = () + + for resnet, conv3d in zip(self.resnets, self.conv3ds): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb, scale=scale) + + hidden_states = conv3d(hidden_states) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, scale=scale) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class VideoLDMUpBlock(UpBlock2D): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + # additional + use_temporal=True, + n_frames: int = 8, + first_frame_condition_mode="none", + latent_channels=4, + ): + super().__init__( + in_channels, + prev_output_channel, + out_channels, + temb_channels, + dropout, + num_layers, + resnet_eps, + resnet_time_scale_shift, + resnet_act_fn, + resnet_groups, + resnet_pre_norm, + output_scale_factor, + add_upsample, + ) + + self.use_temporal = use_temporal + + self.n_frames = n_frames + self.first_frame_condition_mode = first_frame_condition_mode + if self.first_frame_condition_mode == "conv2d": + self.first_frame_conv = nn.Conv2d(latent_channels, prev_output_channel, kernel_size=1) + + # >>> Temporal Layers >>> + conv3ds = [] + for i in range(num_layers): + if self.use_temporal: + conv3ds.append( + TemporalResnetBlock( + in_channels=out_channels, + out_channels=out_channels, + n_frames=n_frames, + ) + ) + else: + conv3ds.append(IdentityLayer(return_trans2d_output=False)) + + self.conv3ds = nn.ModuleList(conv3ds) + # <<< Temporal Layers <<< + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1, first_frame_latents=None): + # input shape: hidden_states = (b f) c h w, first_frame_latents = b c 1 h w + if self.first_frame_condition_mode == "conv2d": + hidden_states = rearrange(hidden_states, '(b t) c h w -> b c t h w', t=self.n_frames) + hidden_height = hidden_states.shape[3] + first_frame_height = first_frame_latents.shape[3] + downsample_ratio = hidden_height / first_frame_height + first_frame_latents = F.interpolate(first_frame_latents.squeeze(2), scale_factor=downsample_ratio, mode="nearest") + first_frame_latents = self.first_frame_conv(first_frame_latents).unsqueeze(2) + hidden_states[:, :, 0:1, :, :] = first_frame_latents + hidden_states = rearrange(hidden_states, 'b c t h w -> (b t) c h w', t=self.n_frames) + + for resnet, conv3d in zip(self.resnets, self.conv3ds): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb, scale=scale) + + hidden_states = conv3d(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size, scale=scale) + + return hidden_states \ No newline at end of file diff --git a/src/videogen_hub/pipelines/consisti2v/consisti2v/pipelines/__init__.py b/src/videogen_hub/pipelines/consisti2v/consisti2v/pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/consisti2v/consisti2v/pipelines/pipeline_autoregress_animation.py b/src/videogen_hub/pipelines/consisti2v/consisti2v/pipelines/pipeline_autoregress_animation.py new file mode 100644 index 0000000000000000000000000000000000000000..005a5e1e407a808cd1db7d316585187792eb07a8 --- /dev/null +++ b/src/videogen_hub/pipelines/consisti2v/consisti2v/pipelines/pipeline_autoregress_animation.py @@ -0,0 +1,615 @@ +# Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py + +import inspect +from typing import Callable, List, Optional, Union +from dataclasses import dataclass + +import math +import numpy as np +import torch +from tqdm import tqdm + +from torchvision import transforms as T +from PIL import Image + +from diffusers.utils import is_accelerate_available +from packaging import version +from transformers import CLIPTextModel, CLIPTokenizer + +from diffusers.configuration_utils import FrozenDict +from diffusers.models import AutoencoderKL +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from diffusers.utils import deprecate, logging, BaseOutput + +from einops import rearrange, repeat + +from ..models.unet import UNet3DConditionModel +from ..utils.frameinit_utils import freq_mix_3d, get_freq_filter + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# copied from https://github.com/huggingface/diffusers/blob/v0.23.0/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L59C1-L70C21 +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +@dataclass +class AnimationPipelineOutput(BaseOutput): + videos: Union[torch.Tensor, np.ndarray] + + +class AutoregressiveAnimationPipeline(DiffusionPipeline): + _optional_components = [] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet3DConditionModel, + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + self.freq_filter = None + + @torch.no_grad() + def init_filter(self, video_length, height, width, filter_params): + # initialize frequency filter for noise reinitialization + batch_size = 1 + num_channels_latents = self.unet.config.in_channels + filter_shape = [ + batch_size, + num_channels_latents, + video_length, + height // self.vae_scale_factor, + width // self.vae_scale_factor + ] + # self.freq_filter = get_freq_filter(filter_shape, device=self._execution_device, params=filter_params) + self.freq_filter = get_freq_filter( + filter_shape, + device=self._execution_device, + filter_type=filter_params.method, + n=filter_params.n if filter_params.method=="butterworth" else None, + d_s=filter_params.d_s, + d_t=filter_params.d_t + ) + + def enable_vae_slicing(self): + self.vae.enable_slicing() + + def disable_vae_slicing(self): + self.vae.disable_slicing() + + def enable_sequential_cpu_offload(self, gpu_id=0): + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + + @property + def _execution_device(self): + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = text_embeddings[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance is not None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = uncond_embeddings[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance == "text": + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + elif do_classifier_free_guidance == "both": + text_embeddings = torch.cat([uncond_embeddings, uncond_embeddings, text_embeddings]) + + return text_embeddings + + def decode_latents(self, latents, first_frames=None): + video_length = latents.shape[2] + latents = 1 / self.vae.config.scaling_factor * latents + latents = rearrange(latents, "b c f h w -> (b f) c h w") + # video = self.vae.decode(latents).sample + video = [] + for frame_idx in tqdm(range(latents.shape[0]), **self._progress_bar_config): + video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample) + video = torch.cat(video) + video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) + + if first_frames is not None: + first_frames = first_frames.unsqueeze(2) + video = torch.cat([first_frames, video], dim=2) + + video = (video / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + video = video.cpu().float().numpy() + return video + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs(self, prompt, height, width, callback_steps, first_frame_paths=None): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if first_frame_paths is not None and (not isinstance(prompt, str) and not isinstance(first_frame_paths, list)): + raise ValueError(f"`first_frame_paths` has to be of type `str` or `list` but is {type(first_frame_paths)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None, noise_sampling_method="vanilla", noise_alpha=1.0): + shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + rand_device = "cpu" if device.type == "mps" else device + + if isinstance(generator, list): + # shape = shape + shape = (1,) + shape[1:] + if noise_sampling_method == "vanilla": + latents = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) + for i in range(batch_size) + ] + elif noise_sampling_method == "pyoco_mixed": + base_shape = (batch_size, num_channels_latents, 1, height // self.vae_scale_factor, width // self.vae_scale_factor) + latents = [] + noise_alpha_squared = noise_alpha ** 2 + for i in range(batch_size): + base_latent = torch.randn(base_shape, generator=generator[i], device=rand_device, dtype=dtype) * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared)) + ind_latent = torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared)) + latents.append(base_latent + ind_latent) + elif noise_sampling_method == "pyoco_progressive": + latents = [] + noise_alpha_squared = noise_alpha ** 2 + for i in range(batch_size): + latent = torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) + ind_latent = torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared)) + for j in range(1, video_length): + latent[:, :, j, :, :] = latent[:, :, j - 1, :, :] * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared)) + ind_latent[:, :, j, :, :] + latents.append(latent) + latents = torch.cat(latents, dim=0).to(device) + else: + if noise_sampling_method == "vanilla": + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) + elif noise_sampling_method == "pyoco_mixed": + noise_alpha_squared = noise_alpha ** 2 + base_shape = (batch_size, num_channels_latents, 1, height // self.vae_scale_factor, width // self.vae_scale_factor) + base_latents = torch.randn(base_shape, generator=generator, device=rand_device, dtype=dtype) * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared)) + ind_latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared)) + latents = base_latents + ind_latents + elif noise_sampling_method == "pyoco_progressive": + noise_alpha_squared = noise_alpha ** 2 + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype) + ind_latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared)) + for j in range(1, video_length): + latents[:, :, j, :, :] = latents[:, :, j - 1, :, :] * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared)) + ind_latents[:, :, j, :, :] + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + video_length: Optional[int], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale_txt: float = 7.5, + guidance_scale_img: float = 2.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "tensor", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + # additional + first_frame_paths: Optional[Union[str, List[str]]] = None, + first_frames: Optional[torch.FloatTensor] = None, + noise_sampling_method: str = "vanilla", + noise_alpha: float = 1.0, + guidance_rescale: float = 0.0, + frame_stride: Optional[int] = None, + autoregress_steps: int = 3, + use_frameinit: bool = False, + frameinit_noise_level: int = 999, + **kwargs, + ): + if first_frame_paths is not None and first_frames is not None: + raise ValueError("Only one of `first_frame_paths` and `first_frames` can be passed.") + # Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps, first_frame_paths) + + # Define call parameters + # batch_size = 1 if isinstance(prompt, str) else len(prompt) + batch_size = 1 + if latents is not None: + batch_size = latents.shape[0] + if isinstance(prompt, list): + batch_size = len(prompt) + first_frame_input = first_frame_paths if first_frame_paths is not None else first_frames + if first_frame_input is not None: + assert len(prompt) == len(first_frame_input), "prompt and first_frame_paths should have the same length" + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = None + # two guidance mode: text and text+image + if guidance_scale_txt > 1.0: + do_classifier_free_guidance = "text" + if guidance_scale_img > 1.0: + do_classifier_free_guidance = "both" + + # Encode input prompt + prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size + if negative_prompt is not None: + negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size + text_embeddings = self._encode_prompt( + prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # Encode input first frame + first_frame_latents = None + if first_frame_paths is not None: + first_frame_paths = first_frame_paths if isinstance(first_frame_paths, list) else [first_frame_paths] * batch_size + img_transform = T.Compose([ + T.ToTensor(), + T.Resize(height, antialias=None), + T.CenterCrop((height, width)), + T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + first_frames = [] + for first_frame_path in first_frame_paths: + first_frame = Image.open(first_frame_path).convert('RGB') + first_frame = img_transform(first_frame).unsqueeze(0) + first_frames.append(first_frame) + first_frames = torch.cat(first_frames, dim=0) + if first_frames is not None: + first_frames = first_frames.to(device, dtype=self.vae.dtype) + first_frame_latents = self.vae.encode(first_frames).latent_dist + first_frame_latents = first_frame_latents.sample() + first_frame_latents = first_frame_latents * self.vae.config.scaling_factor # b, c, h, w + first_frame_latents = repeat(first_frame_latents, "b c h w -> (b n) c h w", n=num_videos_per_prompt) + first_frames = repeat(first_frames, "b c h w -> (b n) c h w", n=num_videos_per_prompt) + + full_video_latent = torch.zeros(batch_size * num_videos_per_prompt, self.unet.config.in_channels, video_length * autoregress_steps - autoregress_steps + 1, height // self.vae_scale_factor, width // self.vae_scale_factor, device=device, dtype=self.vae.dtype) + + start_idx = 0 + for ar_step in range(autoregress_steps): + # Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + video_length, + height, + width, + text_embeddings.dtype, + device, + generator, + latents, + noise_sampling_method, + noise_alpha, + ) + latents_dtype = latents.dtype + + if use_frameinit: + current_diffuse_timestep = frameinit_noise_level # diffuse to noise level + diffuse_timesteps = torch.full((batch_size,),int(current_diffuse_timestep)) + diffuse_timesteps = diffuse_timesteps.long() + first_frames_static_vid = repeat(first_frame_latents, "b c h w -> b c t h w", t=video_length) + z_T = self.scheduler.add_noise( + original_samples=first_frames_static_vid.to(device), + noise=latents.to(device), + timesteps=diffuse_timesteps.to(device) + ) + latents = freq_mix_3d(z_T.to(dtype=torch.float32), latents, LPF=self.freq_filter) + latents = latents.to(dtype=latents_dtype) + + if first_frame_latents is not None: + first_frame_noisy_latent = latents[:, :, 0, :, :] + latents = latents[:, :, 1:, :, :] + + # Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + if do_classifier_free_guidance is None: + latent_model_input = latents + elif do_classifier_free_guidance == "text": + latent_model_input = torch.cat([latents] * 2) + elif do_classifier_free_guidance == "both": + latent_model_input = torch.cat([latents] * 3) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + if first_frame_latents is not None: + if do_classifier_free_guidance is None: + first_frame_latents_input = first_frame_latents + elif do_classifier_free_guidance == "text": + first_frame_latents_input = torch.cat([first_frame_latents] * 2) + elif do_classifier_free_guidance == "both": + first_frame_latents_input = torch.cat([first_frame_noisy_latent, first_frame_latents, first_frame_latents]) + + first_frame_latents_input = first_frame_latents_input.unsqueeze(2) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, first_frame_latents=first_frame_latents_input, frame_stride=frame_stride).sample.to(dtype=latents_dtype) + else: + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample.to(dtype=latents_dtype) + # noise_pred = [] + # import pdb + # pdb.set_trace() + # for batch_idx in range(latent_model_input.shape[0]): + # noise_pred_single = self.unet(latent_model_input[batch_idx:batch_idx+1], t, encoder_hidden_states=text_embeddings[batch_idx:batch_idx+1]).sample.to(dtype=latents_dtype) + # noise_pred.append(noise_pred_single) + # noise_pred = torch.cat(noise_pred) + + # perform guidance + if do_classifier_free_guidance: + if do_classifier_free_guidance == "text": + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale_txt * (noise_pred_text - noise_pred_uncond) + elif do_classifier_free_guidance == "both": + noise_pred_uncond, noise_pred_img, noise_pred_both = noise_pred.chunk(3) + noise_pred = noise_pred_uncond + guidance_scale_img * (noise_pred_img - noise_pred_uncond) + guidance_scale_txt * (noise_pred_both - noise_pred_img) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + # currently only support text guidance + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # Post-processing + + latents = torch.cat([first_frame_latents.unsqueeze(2), latents], dim=2) + first_frame_latents = latents[:, :, -1, :, :] + full_video_latent[:, :, start_idx:start_idx + video_length, :, :] = latents + + latents = None + start_idx += (video_length - 1) + + # video = self.decode_latents(latents, first_frames) + video = self.decode_latents(full_video_latent) + + # Convert to tensor + if output_type == "tensor": + video = torch.from_numpy(video) + + if not return_dict: + return video + + return AnimationPipelineOutput(videos=video) diff --git a/src/videogen_hub/pipelines/consisti2v/consisti2v/pipelines/pipeline_conditional_animation.py b/src/videogen_hub/pipelines/consisti2v/consisti2v/pipelines/pipeline_conditional_animation.py new file mode 100644 index 0000000000000000000000000000000000000000..37ec8d5b260df0f79e0b104b8c05342525b4eee1 --- /dev/null +++ b/src/videogen_hub/pipelines/consisti2v/consisti2v/pipelines/pipeline_conditional_animation.py @@ -0,0 +1,695 @@ +# Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py + +import inspect +from typing import Callable, List, Optional, Union +from dataclasses import dataclass + +import math +import numpy as np +import torch +from tqdm import tqdm + +from torchvision import transforms as T +from torchvision.transforms import functional as F +from PIL import Image + +from diffusers.utils import is_accelerate_available +from packaging import version +from transformers import CLIPTextModel, CLIPTokenizer + +from diffusers.configuration_utils import FrozenDict +from diffusers.models import AutoencoderKL +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from diffusers.utils import deprecate, logging, BaseOutput + +from einops import rearrange, repeat + +from ..models.videoldm_unet import VideoLDMUNet3DConditionModel + +from ..utils.frameinit_utils import get_freq_filter, freq_mix_3d + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# copied from https://github.com/huggingface/diffusers/blob/v0.23.0/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L59C1-L70C21 +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + +def pan_right(image, num_frames=16, crop_width=256): + frames = [] + height, width = image.shape[-2:] + + for i in range(num_frames): + # Calculate the start position of the crop + start_x = int((width - crop_width) * (i / num_frames)) + crop = F.crop(image, 0, start_x, height, crop_width) + frames.append(crop.unsqueeze(0)) + + return torch.cat(frames, dim=0) + + +def pan_left(image, num_frames=16, crop_width=256): + frames = [] + height, width = image.shape[-2:] + + for i in range(num_frames): + # Start position moves from right to left + start_x = int((width - crop_width) * (1 - (i / num_frames))) + crop = F.crop(image, 0, start_x, height, crop_width) + frames.append(crop.unsqueeze(0)) + + return torch.cat(frames, dim=0) + + +def zoom_in(image, num_frames=16, crop_width=256, ratio=1.5): + frames = [] + height, width = image.shape[-2:] + max_crop_size = min(width, height) + + for i in range(num_frames): + # Calculate the size of the crop + crop_size = max_crop_size - int((max_crop_size - max_crop_size // ratio) * (i / num_frames)) + start_x = (width - crop_size) // 2 + start_y = (height - crop_size) // 2 + crop = F.crop(image, start_y, start_x, crop_size, crop_size) + resized_crop = F.resize(crop, (crop_width, crop_width), antialias=None) # Resize back to original size + frames.append(resized_crop.unsqueeze(0)) + + return torch.cat(frames, dim=0) + + +def zoom_out(image, num_frames=16, crop_width=256, ratio=1.5): + frames = [] + height, width = image.shape[-2:] + min_crop_size = min(width, height) // ratio # Starting from a quarter of the size + + for i in range(num_frames): + # Calculate the size of the crop + crop_size = min_crop_size + int((min(width, height) - min_crop_size) * (i / num_frames)) + start_x = (width - crop_size) // 2 + start_y = (height - crop_size) // 2 + crop = F.crop(image, start_y, start_x, crop_size, crop_size) + resized_crop = F.resize(crop, (crop_width, crop_width), antialias=None) # Resize back to original size + frames.append(resized_crop.unsqueeze(0)) + + return torch.cat(frames, dim=0) + + +@dataclass +class AnimationPipelineOutput(BaseOutput): + videos: Union[torch.Tensor, np.ndarray] + + +class ConditionalAnimationPipeline(DiffusionPipeline): + _optional_components = [] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: VideoLDMUNet3DConditionModel, + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + self.freq_filter = None + + @torch.no_grad() + def init_filter(self, video_length, height, width, filter_params): + # initialize frequency filter for noise reinitialization + batch_size = 1 + num_channels_latents = self.unet.config.in_channels + filter_shape = [ + batch_size, + num_channels_latents, + video_length, + height // self.vae_scale_factor, + width // self.vae_scale_factor + ] + # self.freq_filter = get_freq_filter(filter_shape, device=self._execution_device, params=filter_params) + self.freq_filter = get_freq_filter( + filter_shape, + device=self._execution_device, + filter_type=filter_params.method, + n=filter_params.n if filter_params.method=="butterworth" else None, + d_s=filter_params.d_s, + d_t=filter_params.d_t + ) + + def enable_vae_slicing(self): + self.vae.enable_slicing() + + def disable_vae_slicing(self): + self.vae.disable_slicing() + + def enable_sequential_cpu_offload(self, gpu_id=0): + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + + @property + def _execution_device(self): + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = text_embeddings[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance is not None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = uncond_embeddings[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance == "text": + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + elif do_classifier_free_guidance == "both": + text_embeddings = torch.cat([uncond_embeddings, uncond_embeddings, text_embeddings]) + + return text_embeddings + + def decode_latents(self, latents, first_frames=None): + video_length = latents.shape[2] + latents = 1 / self.vae.config.scaling_factor * latents + latents = rearrange(latents, "b c f h w -> (b f) c h w") + # video = self.vae.decode(latents).sample + video = [] + for frame_idx in tqdm(range(latents.shape[0]), **self._progress_bar_config): + video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample) + video = torch.cat(video) + video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) + + if first_frames is not None: + first_frames = first_frames.unsqueeze(2) + video = torch.cat([first_frames, video], dim=2) + + video = (video / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + video = video.cpu().float().numpy() + return video + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs(self, prompt, height, width, callback_steps, first_frame_paths=None): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if first_frame_paths is not None and (not isinstance(prompt, str) and not isinstance(first_frame_paths, list)): + raise ValueError(f"`first_frame_paths` has to be of type `str` or `list` but is {type(first_frame_paths)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None, noise_sampling_method="vanilla", noise_alpha=1.0): + shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + rand_device = "cpu" if device.type == "mps" else device + + if isinstance(generator, list): + # shape = shape + shape = (1,) + shape[1:] + if noise_sampling_method == "vanilla": + latents = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) + for i in range(batch_size) + ] + elif noise_sampling_method == "pyoco_mixed": + base_shape = (batch_size, num_channels_latents, 1, height // self.vae_scale_factor, width // self.vae_scale_factor) + latents = [] + noise_alpha_squared = noise_alpha ** 2 + for i in range(batch_size): + base_latent = torch.randn(base_shape, generator=generator[i], device=rand_device, dtype=dtype) * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared)) + ind_latent = torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared)) + latents.append(base_latent + ind_latent) + elif noise_sampling_method == "pyoco_progressive": + latents = [] + noise_alpha_squared = noise_alpha ** 2 + for i in range(batch_size): + latent = torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) + ind_latent = torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared)) + for j in range(1, video_length): + latent[:, :, j, :, :] = latent[:, :, j - 1, :, :] * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared)) + ind_latent[:, :, j, :, :] + latents.append(latent) + latents = torch.cat(latents, dim=0).to(device) + else: + if noise_sampling_method == "vanilla": + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) + elif noise_sampling_method == "pyoco_mixed": + noise_alpha_squared = noise_alpha ** 2 + base_shape = (batch_size, num_channels_latents, 1, height // self.vae_scale_factor, width // self.vae_scale_factor) + base_latents = torch.randn(base_shape, generator=generator, device=rand_device, dtype=dtype) * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared)) + ind_latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared)) + latents = base_latents + ind_latents + elif noise_sampling_method == "pyoco_progressive": + noise_alpha_squared = noise_alpha ** 2 + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype) + ind_latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared)) + for j in range(1, video_length): + latents[:, :, j, :, :] = latents[:, :, j - 1, :, :] * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared)) + ind_latents[:, :, j, :, :] + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + video_length: Optional[int], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale_txt: float = 7.5, + guidance_scale_img: float = 2.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "tensor", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + # additional + first_frame_paths: Optional[Union[str, List[str]]] = None, + first_frames: Optional[torch.FloatTensor] = None, + noise_sampling_method: str = "vanilla", + noise_alpha: float = 1.0, + guidance_rescale: float = 0.0, + frame_stride: Optional[int] = None, + use_frameinit: bool = False, + frameinit_noise_level: int = 999, + camera_motion: str = None, + **kwargs, + ): + if first_frame_paths is not None and first_frames is not None: + raise ValueError("Only one of `first_frame_paths` and `first_frames` can be passed.") + # Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps, first_frame_paths) + + # Define call parameters + # batch_size = 1 if isinstance(prompt, str) else len(prompt) + batch_size = 1 + if latents is not None: + batch_size = latents.shape[0] + if isinstance(prompt, list): + batch_size = len(prompt) + first_frame_input = first_frame_paths if first_frame_paths is not None else first_frames + if first_frame_input is not None: + assert len(prompt) == len(first_frame_input), "prompt and first_frame_paths should have the same length" + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = None + # two guidance mode: text and text+image + if guidance_scale_txt > 1.0: + do_classifier_free_guidance = "text" + if guidance_scale_img > 1.0: + do_classifier_free_guidance = "both" + + # Encode input prompt + prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size + if negative_prompt is not None: + negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size + text_embeddings = self._encode_prompt( + prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # Encode input first frame + first_frame_latents = None + if first_frame_paths is not None: + first_frame_paths = first_frame_paths if isinstance(first_frame_paths, list) else [first_frame_paths] * batch_size + if camera_motion is None: + img_transform = T.Compose([ + T.ToTensor(), + T.Resize(height, antialias=None), + T.CenterCrop((height, width)), + T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + elif camera_motion == "pan_left" or camera_motion == "pan_right": + img_transform = T.Compose([ + T.ToTensor(), + T.Resize(height, antialias=None), + T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + elif camera_motion == "zoom_out" or camera_motion == "zoom_in": + img_transform = T.Compose([ + T.ToTensor(), + T.Resize(height * 2, antialias=None), + T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + + first_frames = [] + for first_frame_path in first_frame_paths: + first_frame = Image.open(first_frame_path).convert('RGB') + first_frame = img_transform(first_frame) + if camera_motion is not None: + if camera_motion == "pan_left": + first_frame = pan_left(first_frame, num_frames=video_length, crop_width=width) + elif camera_motion == "pan_right": + first_frame = pan_right(first_frame, num_frames=video_length, crop_width=width) + elif camera_motion == "zoom_in": + first_frame = zoom_in(first_frame, num_frames=video_length, crop_width=width) + elif camera_motion == "zoom_out": + first_frame = zoom_out(first_frame, num_frames=video_length, crop_width=width) + else: + raise NotImplementedError(f"camera_motion: {camera_motion} is not implemented.") + first_frames.append(first_frame.unsqueeze(0)) + first_frames = torch.cat(first_frames, dim=0) + if first_frames is not None: + first_frames = first_frames.to(device, dtype=self.vae.dtype) + if camera_motion is not None: + first_frames = rearrange(first_frames, "b f c h w -> (b f) c h w") + first_frame_latents = self.vae.encode(first_frames).latent_dist + first_frame_latents = first_frame_latents.sample() + first_frame_latents = first_frame_latents * self.vae.config.scaling_factor # b, c, h, w + first_frame_static_vid = rearrange(first_frame_latents, "(b f) c h w -> b c f h w", f=video_length if camera_motion is not None else 1) + first_frame_latents = first_frame_static_vid[:, :, 0, :, :] + first_frame_latents = repeat(first_frame_latents, "b c h w -> (b n) c h w", n=num_videos_per_prompt) + first_frames = repeat(first_frames, "b c h w -> (b n) c h w", n=num_videos_per_prompt) + + if use_frameinit and camera_motion is None: + first_frame_static_vid = repeat(first_frame_static_vid, "b c 1 h w -> b c t h w", t=video_length) + + # self._progress_bar_config = {} + # vid = self.decode_latents(first_frame_static_vid) + # vid = torch.from_numpy(vid) + # from ..utils.util import save_videos_grid + # save_videos_grid(vid, "samples/debug/camera_motion/first_frame_static_vid.mp4", fps=8) + + # Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + video_length, + height, + width, + text_embeddings.dtype, + device, + generator, + latents, + noise_sampling_method, + noise_alpha, + ) + latents_dtype = latents.dtype + + if use_frameinit: + current_diffuse_timestep = frameinit_noise_level # diffuse to t noise level + diffuse_timesteps = torch.full((batch_size,),int(current_diffuse_timestep)) + diffuse_timesteps = diffuse_timesteps.long() + z_T = self.scheduler.add_noise( + original_samples=first_frame_static_vid.to(device), + noise=latents.to(device), + timesteps=diffuse_timesteps.to(device) + ) + latents = freq_mix_3d(z_T.to(dtype=torch.float32), latents.to(dtype=torch.float32), LPF=self.freq_filter) + latents = latents.to(dtype=latents_dtype) + + if first_frame_latents is not None: + first_frame_noisy_latent = latents[:, :, 0, :, :] + latents = latents[:, :, 1:, :, :] + + # Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + if do_classifier_free_guidance is None: + latent_model_input = latents + elif do_classifier_free_guidance == "text": + latent_model_input = torch.cat([latents] * 2) + elif do_classifier_free_guidance == "both": + latent_model_input = torch.cat([latents] * 3) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + if first_frame_latents is not None: + if do_classifier_free_guidance is None: + first_frame_latents_input = first_frame_latents + elif do_classifier_free_guidance == "text": + first_frame_latents_input = torch.cat([first_frame_latents] * 2) + elif do_classifier_free_guidance == "both": + first_frame_latents_input = torch.cat([first_frame_noisy_latent, first_frame_latents, first_frame_latents]) + + first_frame_latents_input = first_frame_latents_input.unsqueeze(2) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, first_frame_latents=first_frame_latents_input, frame_stride=frame_stride).sample.to(dtype=latents_dtype) + else: + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample.to(dtype=latents_dtype) + + # perform guidance + if do_classifier_free_guidance: + if do_classifier_free_guidance == "text": + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale_txt * (noise_pred_text - noise_pred_uncond) + elif do_classifier_free_guidance == "both": + noise_pred_uncond, noise_pred_img, noise_pred_both = noise_pred.chunk(3) + noise_pred = noise_pred_uncond + guidance_scale_img * (noise_pred_img - noise_pred_uncond) + guidance_scale_txt * (noise_pred_both - noise_pred_img) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + # currently only support text guidance + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # Post-processing + latents = torch.cat([first_frame_latents.unsqueeze(2), latents], dim=2) + # video = self.decode_latents(latents, first_frames) + video = self.decode_latents(latents) + + # Convert to tensor + if output_type == "tensor": + video = torch.from_numpy(video) + + if not return_dict: + return video + + return AnimationPipelineOutput(videos=video) diff --git a/src/videogen_hub/pipelines/consisti2v/consisti2v/utils/__init__.py b/src/videogen_hub/pipelines/consisti2v/consisti2v/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/consisti2v/consisti2v/utils/frameinit_utils.py b/src/videogen_hub/pipelines/consisti2v/consisti2v/utils/frameinit_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..52de1d99b2a4abb3a2e9d276b4b848d48761662c --- /dev/null +++ b/src/videogen_hub/pipelines/consisti2v/consisti2v/utils/frameinit_utils.py @@ -0,0 +1,142 @@ +# modified from https://github.com/TianxingWu/FreeInit/blob/master/freeinit_utils.py +import torch +import torch.fft as fft +import math + + +def freq_mix_3d(x, noise, LPF): + """ + Noise reinitialization. + + Args: + x: diffused latent + noise: randomly sampled noise + LPF: low pass filter + """ + # FFT + x_freq = fft.fftn(x, dim=(-3, -2, -1)) + x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1)) + noise_freq = fft.fftn(noise, dim=(-3, -2, -1)) + noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1)) + + # frequency mix + HPF = 1 - LPF + x_freq_low = x_freq * LPF + noise_freq_high = noise_freq * HPF + x_freq_mixed = x_freq_low + noise_freq_high # mix in freq domain + + # IFFT + x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1)) + x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real + + return x_mixed + + +def get_freq_filter(shape, device, filter_type, n, d_s, d_t): + """ + Form the frequency filter for noise reinitialization. + + Args: + shape: shape of latent (B, C, T, H, W) + filter_type: type of the freq filter + n: (only for butterworth) order of the filter, larger n ~ ideal, smaller n ~ gaussian + d_s: normalized stop frequency for spatial dimensions (0.0-1.0) + d_t: normalized stop frequency for temporal dimension (0.0-1.0) + """ + if filter_type == "gaussian": + return gaussian_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device) + elif filter_type == "ideal": + return ideal_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device) + elif filter_type == "box": + return box_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device) + elif filter_type == "butterworth": + return butterworth_low_pass_filter(shape=shape, n=n, d_s=d_s, d_t=d_t).to(device) + else: + raise NotImplementedError + + +def gaussian_low_pass_filter(shape, d_s=0.25, d_t=0.25): + """ + Compute the gaussian low pass filter mask. + + Args: + shape: shape of the filter (volume) + d_s: normalized stop frequency for spatial dimensions (0.0-1.0) + d_t: normalized stop frequency for temporal dimension (0.0-1.0) + """ + T, H, W = shape[-3], shape[-2], shape[-1] + mask = torch.zeros(shape) + if d_s==0 or d_t==0: + return mask + for t in range(T): + for h in range(H): + for w in range(W): + d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2) + mask[..., t,h,w] = math.exp(-1/(2*d_s**2) * d_square) + return mask + + +def butterworth_low_pass_filter(shape, n=4, d_s=0.25, d_t=0.25): + """ + Compute the butterworth low pass filter mask. + + Args: + shape: shape of the filter (volume) + n: order of the filter, larger n ~ ideal, smaller n ~ gaussian + d_s: normalized stop frequency for spatial dimensions (0.0-1.0) + d_t: normalized stop frequency for temporal dimension (0.0-1.0) + """ + T, H, W = shape[-3], shape[-2], shape[-1] + mask = torch.zeros(shape) + if d_s==0 or d_t==0: + return mask + for t in range(T): + for h in range(H): + for w in range(W): + d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2) + mask[..., t,h,w] = 1 / (1 + (d_square / d_s**2)**n) + return mask + + +def ideal_low_pass_filter(shape, d_s=0.25, d_t=0.25): + """ + Compute the ideal low pass filter mask. + + Args: + shape: shape of the filter (volume) + d_s: normalized stop frequency for spatial dimensions (0.0-1.0) + d_t: normalized stop frequency for temporal dimension (0.0-1.0) + """ + T, H, W = shape[-3], shape[-2], shape[-1] + mask = torch.zeros(shape) + if d_s==0 or d_t==0: + return mask + for t in range(T): + for h in range(H): + for w in range(W): + d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2) + mask[..., t,h,w] = 1 if d_square <= d_s*2 else 0 + return mask + + +def box_low_pass_filter(shape, d_s=0.25, d_t=0.25): + """ + Compute the ideal low pass filter mask (approximated version). + + Args: + shape: shape of the filter (volume) + d_s: normalized stop frequency for spatial dimensions (0.0-1.0) + d_t: normalized stop frequency for temporal dimension (0.0-1.0) + """ + T, H, W = shape[-3], shape[-2], shape[-1] + mask = torch.zeros(shape) + if d_s==0 or d_t==0: + return mask + + threshold_s = round(int(H // 2) * d_s) + threshold_t = round(T // 2 * d_t) + + cframe, crow, ccol = T // 2, H // 2, W //2 + mask[..., cframe - threshold_t:cframe + threshold_t, crow - threshold_s:crow + threshold_s, ccol - threshold_s:ccol + threshold_s] = 1.0 + + return mask \ No newline at end of file diff --git a/src/videogen_hub/pipelines/consisti2v/consisti2v/utils/util.py b/src/videogen_hub/pipelines/consisti2v/consisti2v/utils/util.py new file mode 100644 index 0000000000000000000000000000000000000000..a946a5b5b2f814df409bdbda9c8bb83096a73ae3 --- /dev/null +++ b/src/videogen_hub/pipelines/consisti2v/consisti2v/utils/util.py @@ -0,0 +1,165 @@ +import os +import imageio +import numpy as np +from typing import Union + +import torch +import torchvision +import torch.distributed as dist +import wandb + +from tqdm import tqdm +from einops import rearrange + +from torchmetrics.image.fid import _compute_fid + + +def zero_rank_print(s): + if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0): print("### " + s) + + +def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8, wandb=False, global_step=0, format="gif"): + videos = rearrange(videos, "b c t h w -> t b c h w") + outputs = [] + for x in videos: + x = torchvision.utils.make_grid(x, nrow=n_rows) + x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) + if rescale: + x = (x + 1.0) / 2.0 # -1,1 -> 0,1 + x = (x * 255).numpy().astype(np.uint8) + outputs.append(x) + + if wandb: + wandb_video = wandb.Video(outputs, fps=fps) + wandb.log({"val_videos": wandb_video}, step=global_step) + + os.makedirs(os.path.dirname(path), exist_ok=True) + if format == "gif": + imageio.mimsave(path, outputs, fps=fps) + elif format == "mp4": + torchvision.io.write_video(path, np.array(outputs), fps=fps, video_codec='h264', options={'crf': '10'}) + +# DDIM Inversion +@torch.no_grad() +def init_prompt(prompt, pipeline): + uncond_input = pipeline.tokenizer( + [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length, + return_tensors="pt" + ) + uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0] + text_input = pipeline.tokenizer( + [prompt], + padding="max_length", + max_length=pipeline.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0] + context = torch.cat([uncond_embeddings, text_embeddings]) + + return context + + +def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, + sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler): + timestep, next_timestep = min( + timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep + alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod + alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep] + beta_prod_t = 1 - alpha_prod_t + next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 + next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output + next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction + return next_sample + + +def get_noise_pred_single(latents, t, context, first_frame_latents, frame_stride, unet): + noise_pred = unet(latents, t, encoder_hidden_states=context, first_frame_latents=first_frame_latents, frame_stride=frame_stride).sample + return noise_pred + + +@torch.no_grad() +def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt, first_frame_latents, frame_stride): + context = init_prompt(prompt, pipeline) + uncond_embeddings, cond_embeddings = context.chunk(2) + all_latent = [latent] + latent = latent.clone().detach() + for i in tqdm(range(num_inv_steps)): + t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1] + noise_pred = get_noise_pred_single(latent, t, cond_embeddings, first_frame_latents, frame_stride, pipeline.unet) + latent = next_step(noise_pred, t, latent, ddim_scheduler) + all_latent.append(latent) + return all_latent + + +@torch.no_grad() +def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt="", first_frame_latents=None, frame_stride=3): + ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt, first_frame_latents, frame_stride) + return ddim_latents + + +def compute_fid(real_features, fake_features, num_features, device): + orig_dtype = real_features.dtype + + mx_num_feats = (num_features, num_features) + real_features_sum = torch.zeros(num_features).double().to(device) + real_features_cov_sum = torch.zeros(mx_num_feats).double().to(device) + real_features_num_samples = torch.tensor(0).long().to(device) + + fake_features_sum = torch.zeros(num_features).double().to(device) + fake_features_cov_sum = torch.zeros(mx_num_feats).double().to(device) + fake_features_num_samples = torch.tensor(0).long().to(device) + + real_features = real_features.double() + fake_features = fake_features.double() + + real_features_sum += real_features.sum(dim=0) + real_features_cov_sum += real_features.t().mm(real_features) + real_features_num_samples += real_features.shape[0] + + fake_features_sum += fake_features.sum(dim=0) + fake_features_cov_sum += fake_features.t().mm(fake_features) + fake_features_num_samples += fake_features.shape[0] + + """Calculate FID score based on accumulated extracted features from the two distributions.""" + if real_features_num_samples < 2 or fake_features_num_samples < 2: + raise RuntimeError("More than one sample is required for both the real and fake distributed to compute FID") + mean_real = (real_features_sum / real_features_num_samples).unsqueeze(0) + mean_fake = (fake_features_sum / fake_features_num_samples).unsqueeze(0) + + cov_real_num = real_features_cov_sum - real_features_num_samples * mean_real.t().mm(mean_real) + cov_real = cov_real_num / (real_features_num_samples - 1) + cov_fake_num = fake_features_cov_sum - fake_features_num_samples * mean_fake.t().mm(mean_fake) + cov_fake = cov_fake_num / (fake_features_num_samples - 1) + return _compute_fid(mean_real.squeeze(0), cov_real, mean_fake.squeeze(0), cov_fake).to(orig_dtype) + + +def compute_inception_score(gen_probs, num_splits=10): + num_gen = gen_probs.shape[0] + gen_probs = gen_probs.detach().cpu().numpy() + scores = [] + np.random.RandomState(42).shuffle(gen_probs) + for i in range(num_splits): + part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits] + kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True))) + kl = np.mean(np.sum(kl, axis=1)) + scores.append(np.exp(kl)) + return float(np.mean(scores)), float(np.std(scores)) + # idx = torch.randperm(features.shape[0]) + # features = features[idx] + # # calculate probs and logits + # prob = features.softmax(dim=1) + # log_prob = features.log_softmax(dim=1) + + # # split into groups + # prob = prob.chunk(splits, dim=0) + # log_prob = log_prob.chunk(splits, dim=0) + + # # calculate score per split + # mean_prob = [p.mean(dim=0, keepdim=True) for p in prob] + # kl_ = [p * (log_p - m_p.log()) for p, log_p, m_p in zip(prob, log_prob, mean_prob)] + # kl_ = [k.sum(dim=1).mean().exp() for k in kl_] + # kl = torch.stack(kl_) + + # return mean and std + # return kl.mean(), kl.std() \ No newline at end of file diff --git a/src/videogen_hub/pipelines/consisti2v/scripts/__init__.py b/src/videogen_hub/pipelines/consisti2v/scripts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/consisti2v/scripts/animate.py b/src/videogen_hub/pipelines/consisti2v/scripts/animate.py new file mode 100644 index 0000000000000000000000000000000000000000..50e36e0db245ec93823b27d58ab958adfd24b641 --- /dev/null +++ b/src/videogen_hub/pipelines/consisti2v/scripts/animate.py @@ -0,0 +1,247 @@ +import argparse +import datetime +import random +import os +import logging +from omegaconf import OmegaConf + +import torch + +import diffusers +from diffusers import AutoencoderKL, DDIMScheduler + +from transformers import CLIPTextModel, CLIPTokenizer + +from consisti2v.models.videoldm_unet import VideoLDMUNet3DConditionModel +from consisti2v.pipelines.pipeline_conditional_animation import ( + ConditionalAnimationPipeline, +) +from consisti2v.utils.util import save_videos_grid +from diffusers.utils.import_utils import is_xformers_available + + +def main(args, config): + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + diffusers.utils.logging.set_verbosity_info() + + time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + savedir = f"{config.output_dir}/{config.output_name}-{time_str}" + os.makedirs(savedir) + + samples = [] + sample_idx = 0 + + ### >>> create validation pipeline >>> ### + if config.pipeline_pretrained_path is None: + noise_scheduler = DDIMScheduler( + **OmegaConf.to_container(config.noise_scheduler_kwargs) + ) + tokenizer = CLIPTokenizer.from_pretrained( + config.pretrained_model_path, subfolder="tokenizer", use_safetensors=True + ) + text_encoder = CLIPTextModel.from_pretrained( + config.pretrained_model_path, subfolder="text_encoder" + ) + vae = AutoencoderKL.from_pretrained( + config.pretrained_model_path, subfolder="vae", use_safetensors=True + ) + unet = VideoLDMUNet3DConditionModel.from_pretrained( + config.pretrained_model_path, + subfolder="unet", + variant=config.unet_additional_kwargs["variant"], + temp_pos_embedding=config.unet_additional_kwargs["temp_pos_embedding"], + augment_temporal_attention=config.unet_additional_kwargs[ + "augment_temporal_attention" + ], + use_temporal=True, + n_frames=config.sampling_kwargs["n_frames"], + n_temp_heads=config.unet_additional_kwargs["n_temp_heads"], + first_frame_condition_mode=config.unet_additional_kwargs[ + "first_frame_condition_mode" + ], + use_frame_stride_condition=config.unet_additional_kwargs[ + "use_frame_stride_condition" + ], + use_safetensors=True, + ) + + # 1. unet ckpt + if config.unet_path is not None: + if os.path.isdir(config.unet_path): + unet_dict = VideoLDMUNet3DConditionModel.from_pretrained( + config.unet_path + ) + m, u = unet.load_state_dict(unet_dict.state_dict(), strict=False) + assert len(u) == 0 + del unet_dict + else: + checkpoint_dict = torch.load(config.unet_path, map_location="cpu") + state_dict = ( + checkpoint_dict["state_dict"] + if "state_dict" in checkpoint_dict + else checkpoint_dict + ) + if config.unet_ckpt_prefix is not None: + state_dict = { + k.replace(config.unet_ckpt_prefix, ""): v + for k, v in state_dict.items() + } + m, u = unet.load_state_dict(state_dict, strict=False) + assert len(u) == 0 + + if is_xformers_available() and int(torch.__version__.split(".")[0]) < 2: + unet.enable_xformers_memory_efficient_attention() + + pipeline = ConditionalAnimationPipeline( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=noise_scheduler, + ) + + else: + pipeline = ConditionalAnimationPipeline.from_pretrained( + config.pipeline_pretrained_path + ) + + pipeline.to("cuda") + + # (frameinit) initialize frequency filter for noise reinitialization ------------- + if config.frameinit_kwargs.enable: + pipeline.init_filter( + width=config.sampling_kwargs.width, + height=config.sampling_kwargs.height, + video_length=config.sampling_kwargs.n_frames, + filter_params=config.frameinit_kwargs.filter_params, + ) + # ------------------------------------------------------------------------------- + ### <<< create validation pipeline <<< ### + + if args.prompt is not None: + prompts = [args.prompt] + n_prompts = [args.n_prompt] + first_frame_paths = [args.path_to_first_frame] + random_seeds = [int(args.seed)] if args.seed != "random" else "random" + else: + prompt_config = OmegaConf.load(args.prompt_config) + prompts = prompt_config.prompts + n_prompts = ( + list(prompt_config.n_prompts) * len(prompts) + if len(prompt_config.n_prompts) == 1 + else prompt_config.n_prompts + ) + first_frame_paths = prompt_config.path_to_first_frames + random_seeds = prompt_config.seeds + + if random_seeds == "random": + random_seeds = [random.randint(0, 1e5) for _ in range(len(prompts))] + else: + random_seeds = ( + [random_seeds] if isinstance(random_seeds, int) else list(random_seeds) + ) + random_seeds = ( + random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds + ) + + config.prompt_kwargs = OmegaConf.create( + { + "random_seeds": [], + "prompts": prompts, + "n_prompts": n_prompts, + "first_frame_paths": first_frame_paths, + } + ) + for prompt_idx, (prompt, n_prompt, first_frame_path, random_seed) in enumerate( + zip(prompts, n_prompts, first_frame_paths, random_seeds) + ): + # manually set random seed for reproduction + if random_seed != -1: + torch.manual_seed(random_seed) + else: + torch.seed() + config.prompt_kwargs.random_seeds.append(torch.initial_seed()) + + print(f"current seed: {torch.initial_seed()}") + print(f"sampling {prompt} ...") + sample = pipeline( + prompt, + negative_prompt=n_prompt, + first_frame_paths=first_frame_path, + num_inference_steps=config.sampling_kwargs.steps, + guidance_scale_txt=config.sampling_kwargs.guidance_scale_txt, + guidance_scale_img=config.sampling_kwargs.guidance_scale_img, + width=config.sampling_kwargs.width, + height=config.sampling_kwargs.height, + video_length=config.sampling_kwargs.n_frames, + noise_sampling_method=config.unet_additional_kwargs[ + "noise_sampling_method" + ], + noise_alpha=float(config.unet_additional_kwargs["noise_alpha"]), + eta=config.sampling_kwargs.ddim_eta, + frame_stride=config.sampling_kwargs.frame_stride, + guidance_rescale=config.sampling_kwargs.guidance_rescale, + num_videos_per_prompt=config.sampling_kwargs.num_videos_per_prompt, + use_frameinit=config.frameinit_kwargs.enable, + frameinit_noise_level=config.frameinit_kwargs.noise_level, + camera_motion=config.frameinit_kwargs.camera_motion, + ).videos + samples.append(sample) + + prompt = "-".join((prompt.replace("/", "").split(" ")[:10])).replace(":", "") + if sample.shape[0] > 1: + for cnt, samp in enumerate(sample): + save_videos_grid( + samp.unsqueeze(0), + f"{savedir}/sample/{sample_idx}-{cnt + 1}-{prompt}.{args.format}", + format=args.format, + ) + else: + save_videos_grid( + sample, + f"{savedir}/sample/{sample_idx}-{prompt}.{args.format}", + format=args.format, + ) + print(f"save to {savedir}/sample/{prompt}.{args.format}") + + sample_idx += 1 + + samples = torch.concat(samples) + # save_videos_grid(samples, f"{savedir}/sample.{args.format}", n_rows=4, format=args.format) + + # OmegaConf.save(config, f"{savedir}/config.yaml") + + # if args.save_model: + # pipeline.save_pretrained(f"{savedir}/model") + + return samples + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--inference_config", type=str, default="configs/inference/inference.yaml" + ) + parser.add_argument("--prompt", "-p", type=str, default=None) + parser.add_argument("--n_prompt", "-n", type=str, default="") + parser.add_argument("--seed", type=str, default="random") + parser.add_argument("--path_to_first_frame", "-f", type=str, default=None) + parser.add_argument( + "--prompt_config", type=str, default="configs/prompts/default.yaml" + ) + parser.add_argument("--format", type=str, default="mp4", choices=["gif", "mp4"]) + parser.add_argument("--save_model", action="store_true") + parser.add_argument("optional_args", nargs="*", default=[]) + args = parser.parse_args() + + config = OmegaConf.load(args.inference_config) + + if args.optional_args: + modified_config = OmegaConf.from_dotlist(args.optional_args) + config = OmegaConf.merge(config, modified_config) + + main(args, config) diff --git a/src/videogen_hub/pipelines/consisti2v/scripts/animate_autoregress.py b/src/videogen_hub/pipelines/consisti2v/scripts/animate_autoregress.py new file mode 100644 index 0000000000000000000000000000000000000000..c7937359cc3f602aa564e218d8d66c984338680b --- /dev/null +++ b/src/videogen_hub/pipelines/consisti2v/scripts/animate_autoregress.py @@ -0,0 +1,185 @@ +import argparse +import datetime +import random +import os +import logging +from omegaconf import OmegaConf + +import torch + +import diffusers +from diffusers import AutoencoderKL, DDIMScheduler + +from transformers import CLIPTextModel, CLIPTokenizer + +from consisti2v.models.videoldm_unet import VideoLDMUNet3DConditionModel +from consisti2v.pipelines.pipeline_autoregress_animation import AutoregressiveAnimationPipeline +from consisti2v.utils.util import save_videos_grid +from diffusers.utils.import_utils import is_xformers_available + +def main(args, config): + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + diffusers.utils.logging.set_verbosity_info() + + time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + savedir = f"{config.output_dir}/{config.output_name}-{time_str}" + os.makedirs(savedir) + + samples = [] + sample_idx = 0 + + ### >>> create validation pipeline >>> ### + if config.pipeline_pretrained_path is None: + noise_scheduler = DDIMScheduler(**OmegaConf.to_container(config.noise_scheduler_kwargs)) + tokenizer = CLIPTokenizer.from_pretrained(config.pretrained_model_path, subfolder="tokenizer", use_safetensors=True) + text_encoder = CLIPTextModel.from_pretrained(config.pretrained_model_path, subfolder="text_encoder") + vae = AutoencoderKL.from_pretrained(config.pretrained_model_path, subfolder="vae", use_safetensors=True) + unet = VideoLDMUNet3DConditionModel.from_pretrained( + config.pretrained_model_path, + subfolder="unet", + variant=config.unet_additional_kwargs['variant'], + temp_pos_embedding=config.unet_additional_kwargs['temp_pos_embedding'], + augment_temporal_attention=config.unet_additional_kwargs['augment_temporal_attention'], + use_temporal=True, + n_frames=config.sampling_kwargs['n_frames'], + n_temp_heads=config.unet_additional_kwargs['n_temp_heads'], + first_frame_condition_mode=config.unet_additional_kwargs['first_frame_condition_mode'], + use_frame_stride_condition=config.unet_additional_kwargs['use_frame_stride_condition'], + use_safetensors=True + ) + + params_unet = [p.numel() for n, p in unet.named_parameters()] + params_vae = [p.numel() for n, p in vae.named_parameters()] + params_text_encoder = [p.numel() for n, p in text_encoder.named_parameters()] + params = params_unet + params_vae + params_text_encoder + print(f"### UNet Parameters: {sum(params) / 1e6} M") + + # 1. unet ckpt + if config.unet_path is not None: + if os.path.isdir(config.unet_path): + unet_dict = VideoLDMUNet3DConditionModel.from_pretrained(config.unet_path) + m, u = unet.load_state_dict(unet_dict.state_dict(), strict=False) + assert len(u) == 0 + del unet_dict + else: + checkpoint_dict = torch.load(config.unet_path, map_location="cpu") + state_dict = checkpoint_dict["state_dict"] if "state_dict" in checkpoint_dict else checkpoint_dict + if config.unet_ckpt_prefix is not None: + state_dict = {k.replace(config.unet_ckpt_prefix, ''): v for k, v in state_dict.items()} + m, u = unet.load_state_dict(state_dict, strict=False) + assert len(u) == 0 + + if is_xformers_available() and int(torch.__version__.split(".")[0]) < 2: + unet.enable_xformers_memory_efficient_attention() + + pipeline = AutoregressiveAnimationPipeline( + vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=noise_scheduler) + + else: + pipeline = AutoregressiveAnimationPipeline.from_pretrained(config.pipeline_pretrained_path) + + pipeline.to("cuda") + + # (frameinit) initialize frequency filter for noise reinitialization ------------- + if config.frameinit_kwargs.enable: + pipeline.init_filter( + width = config.sampling_kwargs.width, + height = config.sampling_kwargs.height, + video_length = config.sampling_kwargs.n_frames, + filter_params = config.frameinit_kwargs.filter_params, + ) + # ------------------------------------------------------------------------------- + ### <<< create validation pipeline <<< ### + + if args.prompt is not None: + prompts = [args.prompt] + n_prompts = [args.n_prompt] + first_frame_paths = [args.path_to_first_frame] + random_seeds = [int(args.seed)] if args.seed != "random" else "random" + else: + prompt_config = OmegaConf.load(args.prompt_config) + prompts = prompt_config.prompts + n_prompts = list(prompt_config.n_prompts) * len(prompts) if len(prompt_config.n_prompts) == 1 else prompt_config.n_prompts + first_frame_paths = prompt_config.path_to_first_frames + random_seeds = prompt_config.seeds + + if random_seeds == "random": + random_seeds = [random.randint(0, 1e5) for _ in range(len(prompts))] + else: + random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds) + random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds + + config.prompt_kwargs = OmegaConf.create({"random_seeds": [], "prompts": prompts, "n_prompts": n_prompts, "first_frame_paths": first_frame_paths}) + for prompt_idx, (prompt, n_prompt, first_frame_path, random_seed) in enumerate(zip(prompts, n_prompts, first_frame_paths, random_seeds)): + # manually set random seed for reproduction + if random_seed != -1: torch.manual_seed(random_seed) + else: torch.seed() + config.prompt_kwargs.random_seeds.append(torch.initial_seed()) + + print(f"current seed: {torch.initial_seed()}") + print(f"sampling {prompt} ...") + sample = pipeline( + prompt, + negative_prompt = n_prompt, + first_frame_paths = first_frame_path, + num_inference_steps = config.sampling_kwargs.steps, + guidance_scale_txt = config.sampling_kwargs.guidance_scale_txt, + guidance_scale_img = config.sampling_kwargs.guidance_scale_img, + width = config.sampling_kwargs.width, + height = config.sampling_kwargs.height, + video_length = config.sampling_kwargs.n_frames, + noise_sampling_method = config.unet_additional_kwargs['noise_sampling_method'], + noise_alpha = float(config.unet_additional_kwargs['noise_alpha']), + eta = config.sampling_kwargs.ddim_eta, + frame_stride = config.sampling_kwargs.frame_stride, + guidance_rescale = config.sampling_kwargs.guidance_rescale, + num_videos_per_prompt = config.sampling_kwargs.num_videos_per_prompt, + autoregress_steps = config.sampling_kwargs.autoregress_steps, + use_frameinit = config.frameinit_kwargs.enable, + frameinit_noise_level = config.frameinit_kwargs.noise_level, + ).videos + samples.append(sample) + + prompt = "-".join((prompt.replace("/", "").split(" ")[:10])).replace(":", "") + if sample.shape[0] > 1: + for cnt, samp in enumerate(sample): + save_videos_grid(samp.unsqueeze(0), f"{savedir}/sample/{sample_idx}-{cnt + 1}-{prompt}.{args.format}", format=args.format) + else: + save_videos_grid(sample, f"{savedir}/sample/{sample_idx}-{prompt}.{args.format}", format=args.format) + print(f"save to {savedir}/sample/{prompt}.{args.format}") + + sample_idx += 1 + + samples = torch.concat(samples) + save_videos_grid(samples, f"{savedir}/sample.{args.format}", n_rows=4, format=args.format) + + OmegaConf.save(config, f"{savedir}/config.yaml") + + if args.save_model: + pipeline.save_pretrained(f"{savedir}/model") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--inference_config", type=str, default="configs/inference/inference_autoregress.yaml") + parser.add_argument("--prompt", "-p", type=str, default=None) + parser.add_argument("--n_prompt", "-n", type=str, default="") + parser.add_argument("--seed", type=str, default="random") + parser.add_argument("--path_to_first_frame", "-f", type=str, default=None) + parser.add_argument("--prompt_config", type=str, default="configs/prompts/default.yaml") + parser.add_argument("--format", type=str, default="gif", choices=["gif", "mp4"]) + parser.add_argument("--save_model", action="store_true") + parser.add_argument("optional_args", nargs='*', default=[]) + args = parser.parse_args() + + config = OmegaConf.load(args.inference_config) + + if args.optional_args: + modified_config = OmegaConf.from_dotlist(args.optional_args) + config = OmegaConf.merge(config, modified_config) + + main(args, config) diff --git a/src/videogen_hub/pipelines/dynamicrafter/__init__.py b/src/videogen_hub/pipelines/dynamicrafter/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0a57ec6f72217ea836e4bf7d6bdf9c12ecdb890e --- /dev/null +++ b/src/videogen_hub/pipelines/dynamicrafter/__init__.py @@ -0,0 +1,3 @@ +import sys +sys.path.insert(0, './src/videogen_hub/pipelines/dynamicrafter/') +sys.path.insert(0, './src/videogen_hub/pipelines/dynamicrafter/lvdm') \ No newline at end of file diff --git a/src/videogen_hub/pipelines/dynamicrafter/configs/__init__.py b/src/videogen_hub/pipelines/dynamicrafter/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/dynamicrafter/configs/inference_1024_v1.0.yaml b/src/videogen_hub/pipelines/dynamicrafter/configs/inference_1024_v1.0.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e830ebbc00a2b8b768aa9de22f8b6d1a01130526 --- /dev/null +++ b/src/videogen_hub/pipelines/dynamicrafter/configs/inference_1024_v1.0.yaml @@ -0,0 +1,103 @@ +model: + target: lvdm.models.ddpm3d.LatentVisualDiffusion + params: + rescale_betas_zero_snr: True + parameterization: "v" + linear_start: 0.00085 + linear_end: 0.012 + num_timesteps_cond: 1 + timesteps: 1000 + first_stage_key: video + cond_stage_key: caption + cond_stage_trainable: False + conditioning_key: hybrid + image_size: [72, 128] + channels: 4 + scale_by_std: False + scale_factor: 0.18215 + use_ema: False + uncond_type: 'empty_seq' + use_dynamic_rescale: true + base_scale: 0.3 + fps_condition_type: 'fps' + perframe_ae: True + unet_config: + target: lvdm.modules.networks.openaimodel3d.UNetModel + params: + in_channels: 8 + out_channels: 4 + model_channels: 320 + attention_resolutions: + - 4 + - 2 + - 1 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 4 + - 4 + dropout: 0.1 + num_head_channels: 64 + transformer_depth: 1 + context_dim: 1024 + use_linear: true + use_checkpoint: True + temporal_conv: True + temporal_attention: True + temporal_selfatt_only: true + use_relative_position: false + use_causal_attention: False + temporal_length: 16 + addition_attention: true + image_cross_attention: true + default_fs: 10 + fs_condition: true + + first_stage_config: + target: lvdm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: True + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder + params: + freeze: true + layer: "penultimate" + + img_cond_stage_config: + target: lvdm.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2 + params: + freeze: true + + image_proj_stage_config: + target: lvdm.modules.encoders.resampler.Resampler + params: + dim: 1024 + depth: 4 + dim_head: 64 + heads: 12 + num_queries: 16 + embedding_dim: 1280 + output_dim: 1024 + ff_mult: 4 + video_length: 16 + diff --git a/src/videogen_hub/pipelines/dynamicrafter/configs/inference_256_v1.0.yaml b/src/videogen_hub/pipelines/dynamicrafter/configs/inference_256_v1.0.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3f098796cf2e7487669b93ae3fa8e4ac2087e385 --- /dev/null +++ b/src/videogen_hub/pipelines/dynamicrafter/configs/inference_256_v1.0.yaml @@ -0,0 +1,98 @@ +model: + target: lvdm.models.ddpm3d.LatentVisualDiffusion + params: + linear_start: 0.00085 + linear_end: 0.012 + num_timesteps_cond: 1 + timesteps: 1000 + first_stage_key: video + cond_stage_key: caption + cond_stage_trainable: False + conditioning_key: hybrid + image_size: [32, 32] + channels: 4 + scale_by_std: False + scale_factor: 0.18215 + use_ema: False + uncond_type: 'empty_seq' + unet_config: + target: lvdm.modules.networks.openaimodel3d.UNetModel + params: + in_channels: 8 + out_channels: 4 + model_channels: 320 + attention_resolutions: + - 4 + - 2 + - 1 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 4 + - 4 + dropout: 0.1 + num_head_channels: 64 + transformer_depth: 1 + context_dim: 1024 + use_linear: true + use_checkpoint: True + temporal_conv: True + temporal_attention: True + temporal_selfatt_only: true + use_relative_position: false + use_causal_attention: False + temporal_length: 16 + addition_attention: true + image_cross_attention: true + image_cross_attention_scale_learnable: true + default_fs: 3 + fs_condition: true + + first_stage_config: + target: lvdm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: True + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder + params: + freeze: true + layer: "penultimate" + + img_cond_stage_config: + target: lvdm.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2 + params: + freeze: true + + image_proj_stage_config: + target: lvdm.modules.encoders.resampler.Resampler + params: + dim: 1024 + depth: 4 + dim_head: 64 + heads: 12 + num_queries: 16 + embedding_dim: 1280 + output_dim: 1024 + ff_mult: 4 + video_length: 16 + diff --git a/src/videogen_hub/pipelines/dynamicrafter/configs/inference_512_v1.0.yaml b/src/videogen_hub/pipelines/dynamicrafter/configs/inference_512_v1.0.yaml new file mode 100644 index 0000000000000000000000000000000000000000..104c50bc593879a432d52e91ec086d73279f0bfb --- /dev/null +++ b/src/videogen_hub/pipelines/dynamicrafter/configs/inference_512_v1.0.yaml @@ -0,0 +1,103 @@ +model: + target: lvdm.models.ddpm3d.LatentVisualDiffusion + params: + rescale_betas_zero_snr: True + parameterization: "v" + linear_start: 0.00085 + linear_end: 0.012 + num_timesteps_cond: 1 + timesteps: 1000 + first_stage_key: video + cond_stage_key: caption + cond_stage_trainable: False + conditioning_key: hybrid + image_size: [40, 64] + channels: 4 + scale_by_std: False + scale_factor: 0.18215 + use_ema: False + uncond_type: 'empty_seq' + use_dynamic_rescale: true + base_scale: 0.7 + fps_condition_type: 'fps' + perframe_ae: True + unet_config: + target: lvdm.modules.networks.openaimodel3d.UNetModel + params: + in_channels: 8 + out_channels: 4 + model_channels: 320 + attention_resolutions: + - 4 + - 2 + - 1 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 4 + - 4 + dropout: 0.1 + num_head_channels: 64 + transformer_depth: 1 + context_dim: 1024 + use_linear: true + use_checkpoint: True + temporal_conv: True + temporal_attention: True + temporal_selfatt_only: true + use_relative_position: false + use_causal_attention: False + temporal_length: 16 + addition_attention: true + image_cross_attention: true + default_fs: 24 + fs_condition: true + + first_stage_config: + target: lvdm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: True + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder + params: + freeze: true + layer: "penultimate" + + img_cond_stage_config: + target: lvdm.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2 + params: + freeze: true + + image_proj_stage_config: + target: lvdm.modules.encoders.resampler.Resampler + params: + dim: 1024 + depth: 4 + dim_head: 64 + heads: 12 + num_queries: 16 + embedding_dim: 1280 + output_dim: 1024 + ff_mult: 4 + video_length: 16 + diff --git a/src/videogen_hub/pipelines/dynamicrafter/inference.py b/src/videogen_hub/pipelines/dynamicrafter/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..d37d8eff41b8129add6ca7192f0942153d2d33da --- /dev/null +++ b/src/videogen_hub/pipelines/dynamicrafter/inference.py @@ -0,0 +1,408 @@ +import argparse, os, sys, glob +import datetime, time +from omegaconf import OmegaConf +from tqdm import tqdm +from einops import rearrange, repeat +from collections import OrderedDict + +import torch +import torchvision +import torchvision.transforms as transforms +from pytorch_lightning import seed_everything +from PIL import Image + +sys.path.insert(1, os.path.join(sys.path[0], '..', '..')) +from .lvdm.models.samplers.ddim import DDIMSampler +from .lvdm.models.samplers.ddim_multiplecond import DDIMSampler as DDIMSampler_multicond +from .utils import instantiate_from_config + + +def get_filelist(data_dir, postfixes): + patterns = [os.path.join(data_dir, f"*.{postfix}") for postfix in postfixes] + file_list = [] + for pattern in patterns: + file_list.extend(glob.glob(pattern)) + file_list.sort() + return file_list + + +def load_model_checkpoint(model, ckpt): + state_dict = torch.load(ckpt, map_location="cpu") + if "state_dict" in list(state_dict.keys()): + state_dict = state_dict["state_dict"] + try: + model.load_state_dict(state_dict, strict=True) + except: + ## rename the keys for 256x256 model + new_pl_sd = OrderedDict() + for k, v in state_dict.items(): + new_pl_sd[k] = v + + for k in list(new_pl_sd.keys()): + if "framestride_embed" in k: + new_key = k.replace("framestride_embed", "fps_embedding") + new_pl_sd[new_key] = new_pl_sd[k] + del new_pl_sd[k] + model.load_state_dict(new_pl_sd, strict=True) + else: + # deepspeed + new_pl_sd = OrderedDict() + for key in state_dict['module'].keys(): + new_pl_sd[key[16:]] = state_dict['module'][key] + model.load_state_dict(new_pl_sd) + print('>>> model checkpoint loaded.') + return model + + +def load_prompts(prompt_file): + f = open(prompt_file, 'r') + prompt_list = [] + for idx, line in enumerate(f.readlines()): + l = line.strip() + if len(l) != 0: + prompt_list.append(l) + f.close() + return prompt_list + + +def load_data_prompts(data_dir, video_size=(256, 256), video_frames=16, interp=False): + ## load prompts + prompt_file = get_filelist(data_dir, ['txt']) + assert len(prompt_file) > 0, "Error: found NO prompt file!" + ###### default prompt + default_idx = 0 + default_idx = min(default_idx, len(prompt_file) - 1) + if len(prompt_file) > 1: + print(f"Warning: multiple prompt files exist. The one {os.path.split(prompt_file[default_idx])[1]} is used.") + ## only use the first one (sorted by name) if multiple exist + + ## load video + file_list = get_filelist(data_dir, ['jpg', 'png', 'jpeg', 'JPEG', 'PNG']) + # assert len(file_list) == n_samples, "Error: data and prompts are NOT paired!" + data_list = [] + filename_list = [] + prompt_list = load_prompts(prompt_file[default_idx]) + n_samples = len(prompt_list) + for idx in range(n_samples): + if interp: + image1 = Image.open(file_list[2 * idx]).convert('RGB') + image2 = Image.open(file_list[2 * idx + 1]).convert('RGB') + frame_tensor = processing_image((image1, image2), video_size, video_frames, True) + _, filename = os.path.split(file_list[idx * 2]) + else: + image = Image.open(file_list[idx]).convert('RGB') + frame_tensor = processing_image(image, video_size, video_frames, False) + _, filename = os.path.split(file_list[idx]) + + data_list.append(frame_tensor) + filename_list.append(filename) + + return filename_list, data_list, prompt_list + + +def processing_image(image, video_size=(256, 256), video_frames=16, interp=False): + transform = transforms.Compose([ + transforms.Resize(min(video_size)), + transforms.CenterCrop(video_size), + transforms.ToTensor(), + transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]) + if interp: + image1, image2 = image + image_tensor1 = transform(image1).unsqueeze(1) # [c,1,h,w] + image_tensor2 = transform(image2).unsqueeze(1) # [c,1,h,w] + frame_tensor1 = repeat(image_tensor1, 'c t h w -> c (repeat t) h w', repeat=video_frames // 2) + frame_tensor2 = repeat(image_tensor2, 'c t h w -> c (repeat t) h w', repeat=video_frames // 2) + frame_tensor = torch.cat([frame_tensor1, frame_tensor2], dim=1) + else: + image_tensor = transform(image).unsqueeze(1) # [c,1,h,w] + frame_tensor = repeat(image_tensor, 'c t h w -> c (repeat t) h w', repeat=video_frames) + return frame_tensor + + +def save_results(prompt, samples, filename, fakedir, fps=8, loop=False): + filename = filename.split('.')[0] + '.mp4' + prompt = prompt[0] if isinstance(prompt, list) else prompt + + ## save video + videos = [samples] + savedirs = [fakedir] + for idx, video in enumerate(videos): + if video is None: + continue + # b,c,t,h,w + video = video.detach().cpu() + video = torch.clamp(video.float(), -1., 1.) + n = video.shape[0] + video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w + if loop: + video = video[:-1, ...] + + frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0) for framesheet in + video] #[3, 1*h, n*w] + grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [t, 3, h, n*w] + grid = (grid + 1.0) / 2.0 + grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1) + path = os.path.join(savedirs[idx], filename) + torchvision.io.write_video(path, grid, fps=fps, video_codec='h264', + options={'crf': '10'}) ## crf indicates the quality + + +def save_results_seperate(prompt, samples, filename, fakedir, fps=10, loop=False): + prompt = prompt[0] if isinstance(prompt, list) else prompt + + ## save video + videos = [samples] + savedirs = [fakedir] + for idx, video in enumerate(videos): + if video is None: + continue + # b,c,t,h,w + video = video.detach().cpu() + if loop: # remove the last frame + video = video[:, :, :-1, ...] + video = torch.clamp(video.float(), -1., 1.) + n = video.shape[0] + for i in range(n): + grid = video[i, ...] + grid = (grid + 1.0) / 2.0 + grid = (grid * 255).to(torch.uint8).permute(1, 2, 3, 0) #thwc + path = os.path.join(savedirs[idx].replace('samples', 'samples_separate'), + f'{filename.split(".")[0]}_sample{i}.mp4') + torchvision.io.write_video(path, grid, fps=fps, video_codec='h264', options={'crf': '10'}) + + +def get_latent_z(model, videos): + b, c, t, h, w = videos.shape + x = rearrange(videos, 'b c t h w -> (b t) c h w') + z = model.encode_first_stage(x) + z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t) + return z + + +def image_guided_synthesis(model, prompts, videos, noise_shape, n_samples=1, ddim_steps=50, ddim_eta=1., \ + unconditional_guidance_scale=1.0, cfg_img=None, fs=None, text_input=False, + multiple_cond_cfg=False, loop=False, interp=False, timestep_spacing='uniform', + guidance_rescale=0.0, **kwargs): + ddim_sampler = DDIMSampler(model) if not multiple_cond_cfg else DDIMSampler_multicond(model) + batch_size = noise_shape[0] + fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device) + + if not text_input: + prompts = [""] * batch_size + + img = videos[:, :, 0] #bchw + img_emb = model.embedder(img) ## blc + img_emb = model.image_proj_model(img_emb) + + cond_emb = model.get_learned_conditioning(prompts) + cond = {"c_crossattn": [torch.cat([cond_emb, img_emb], dim=1)]} + if model.model.conditioning_key == 'hybrid': + z = get_latent_z(model, videos) # b c t h w + if loop or interp: + img_cat_cond = torch.zeros_like(z) + img_cat_cond[:, :, 0, :, :] = z[:, :, 0, :, :] + img_cat_cond[:, :, -1, :, :] = z[:, :, -1, :, :] + else: + img_cat_cond = z[:, :, :1, :, :] + img_cat_cond = repeat(img_cat_cond, 'b c t h w -> b c (repeat t) h w', repeat=z.shape[2]) + cond["c_concat"] = [img_cat_cond] # b c 1 h w + + if unconditional_guidance_scale != 1.0: + if model.uncond_type == "empty_seq": + prompts = batch_size * [""] + uc_emb = model.get_learned_conditioning(prompts) + elif model.uncond_type == "zero_embed": + uc_emb = torch.zeros_like(cond_emb) + uc_img_emb = model.embedder(torch.zeros_like(img)) ## b l c + uc_img_emb = model.image_proj_model(uc_img_emb) + uc = {"c_crossattn": [torch.cat([uc_emb, uc_img_emb], dim=1)]} + if model.model.conditioning_key == 'hybrid': + uc["c_concat"] = [img_cat_cond] + else: + uc = None + + ## we need one more unconditioning image=yes, text="" + if multiple_cond_cfg and cfg_img != 1.0: + uc_2 = {"c_crossattn": [torch.cat([uc_emb, img_emb], dim=1)]} + if model.model.conditioning_key == 'hybrid': + uc_2["c_concat"] = [img_cat_cond] + kwargs.update({"unconditional_conditioning_img_nonetext": uc_2}) + else: + kwargs.update({"unconditional_conditioning_img_nonetext": None}) + + z0 = None + cond_mask = None + + batch_variants = [] + for _ in range(n_samples): + + if z0 is not None: + cond_z0 = z0.clone() + kwargs.update({"clean_cond": True}) + else: + cond_z0 = None + if ddim_sampler is not None: + samples, _ = ddim_sampler.sample(S=ddim_steps, + conditioning=cond, + batch_size=batch_size, + shape=noise_shape[1:], + verbose=False, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + eta=ddim_eta, + cfg_img=cfg_img, + mask=cond_mask, + x0=cond_z0, + fs=fs, + timestep_spacing=timestep_spacing, + guidance_rescale=guidance_rescale, + **kwargs + ) + + ## reconstruct from latent to pixel space + batch_images = model.decode_first_stage(samples) + batch_variants.append(batch_images) + ## variants, batch, c, t, h, w + batch_variants = torch.stack(batch_variants) + return batch_variants.permute(1, 0, 2, 3, 4, 5) + + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("--savedir", type=str, default=None, help="results saving path") + parser.add_argument("--ckpt_path", type=str, default=None, help="checkpoint path") + parser.add_argument("--config", type=str, help="config (yaml) path") + parser.add_argument("--prompt_dir", type=str, default=None, help="a data dir containing videos and prompts") + parser.add_argument("--n_samples", type=int, default=1, help="num of samples per prompt", ) + parser.add_argument("--ddim_steps", type=int, default=50, help="steps of ddim if positive, otherwise use DDPM", ) + parser.add_argument("--ddim_eta", type=float, default=1.0, + help="eta for ddim sampling (0.0 yields deterministic sampling)", ) + parser.add_argument("--bs", type=int, default=1, help="batch size for inference, should be one") + parser.add_argument("--height", type=int, default=512, help="image height, in pixel space") + parser.add_argument("--width", type=int, default=512, help="image width, in pixel space") + parser.add_argument("--frame_stride", type=int, default=3, + help="frame stride control for 256 model (larger->larger motion), FPS control for 512 or 1024 model (smaller->larger motion)") + parser.add_argument("--unconditional_guidance_scale", type=float, default=1.0, + help="prompt classifier-free guidance") + parser.add_argument("--seed", type=int, default=123, help="seed for seed_everything") + parser.add_argument("--video_length", type=int, default=16, help="inference video length") + parser.add_argument("--negative_prompt", action='store_true', default=False, help="negative prompt") + parser.add_argument("--text_input", action='store_true', default=False, help="input text to I2V model or not") + parser.add_argument("--multiple_cond_cfg", action='store_true', default=False, + help="use multi-condition cfg or not") + parser.add_argument("--cfg_img", type=float, default=None, help="guidance scale for image conditioning") + parser.add_argument("--timestep_spacing", type=str, default="uniform", + help="The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.") + parser.add_argument("--guidance_rescale", type=float, default=0.0, + help="guidance rescale in [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891)") + parser.add_argument("--perframe_ae", action='store_true', default=False, + help="if we use per-frame AE decoding, set it to True to save GPU memory, especially for the model of 576x1024") + + ## currently not support looping video and generative frame interpolation + parser.add_argument("--loop", action='store_true', default=False, help="generate looping videos or not") + parser.add_argument("--interp", action='store_true', default=False, + help="generate generative frame interpolation or not") + return parser + + +class DynamiCrafterPipeline(): + def __init__(self, args): + """ + Initialize the parameters from args + Args: + args: is a list consisting of arguments needed for parser. + e.g. ["--ckpt_path", , ......] + """ + parser = get_parser() + self.args = parser.parse_args(args) + + def run_inference(self, input_image): + """ + Run inference from the input_image. + This input image can either be a tensor or a string as the path of the image file. + Args: + input_image: tensor or string. + + Returns: a tensor representing the generated video of shape (num_frames, channels, height, width) + + """ + args = self.args + seed_everything(args.seed) + ## model config + config = OmegaConf.load(self.args.config) + model_config = config.pop("model", OmegaConf.create()) + + ## set use_checkpoint as False as when using deepspeed, it encounters an error "deepspeed backend not set" + model_config['params']['unet_config']['params']['use_checkpoint'] = False + model = instantiate_from_config(model_config) + model = model.cuda() + model.perframe_ae = args.perframe_ae + assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!" + model = load_model_checkpoint(model, args.ckpt_path) + model.eval() + + ## run over data + assert (args.height % 16 == 0) and (args.width % 16 == 0), "Error: image size [h,w] should be multiples of 16!" + assert args.bs == 1, "Current implementation only support [batch size = 1]!" + ## latent noise shape + h, w = args.height // 8, args.width // 8 + channels = model.model.diffusion_model.out_channels + n_frames = args.video_length + print(f'Inference with {n_frames} frames') + noise_shape = [args.bs, channels, n_frames, h, w] + + # fakedir = os.path.join(args.savedir, "samples") + # fakedir_separate = os.path.join(args.savedir, "samples_separate") + + # os.makedirs(fakedir, exist_ok=True) + # os.makedirs(fakedir_separate, exist_ok=True) + + ## prompt file setting + + if type(input_image) == str: + args.prompt_dir = input_image + assert os.path.exists(args.prompt_dir), "Error: prompt file Not Found!" + filename_list, data_list, prompt_list = load_data_prompts(args.prompt_dir, + video_size=(args.height, args.width), + video_frames=n_frames, interp=args.interp) + else: + input_pil = (transforms.ToPILImage())(input_image) + frame_tensor = processing_image(input_pil, (args.height, args.width), n_frames, args.interp) + data_list, prompt_list = [frame_tensor], [args.text_input] + + num_samples = len(prompt_list) + # print('Prompts testing [rank:%d] %d/%d samples loaded.'%(gpu_no, samples_split, num_samples)) + # indices = random.choices(list(range(0, num_samples)), k=samples_per_device) + # indices = list(range(0, num_samples)) + # prompt_list_rank = [prompt_list[i] for i in indices] + # data_list_rank = [data_list[i] for i in indices] + # filename_list_rank = [filename_list[i] for i in indices] + + # start = time.time() + with torch.no_grad(), torch.cuda.amp.autocast(): + # for idx, indice in tqdm(enumerate(range(0, len(prompt_list), args.bs)), desc='Sample Batch'): + prompts = prompt_list[0] + videos = data_list[0] + # filenames = filename_list[0] + if isinstance(videos, list): + videos = torch.stack(videos, dim=0).to("cuda") + else: + videos = videos.unsqueeze(0).to("cuda") + + batch_samples = image_guided_synthesis(model, prompts, videos, noise_shape, args.n_samples, args.ddim_steps, + args.ddim_eta, \ + args.unconditional_guidance_scale, args.cfg_img, args.frame_stride, + args.text_input, args.multiple_cond_cfg, args.loop, args.interp, + args.timestep_spacing, args.guidance_rescale) + + output = batch_samples.squeeze().permute(1, 0, 2, 3) + return output + # save each example individually + # for nn, samples in enumerate(batch_samples): + # ## samples : [n_samples,c,t,h,w] + # prompt = prompts[nn] + # filename = filenames[nn] + # # save_results(prompt, samples, filename, fakedir, fps=8, loop=args.loop) + # save_results_seperate(prompt, samples, filename, fakedir, fps=8, loop=args.loop) + + # print(f"Saved in {args.savedir}. Time used: {(time.time() - start):.2f} seconds") diff --git a/src/videogen_hub/pipelines/dynamicrafter/lvdm/__init__.py b/src/videogen_hub/pipelines/dynamicrafter/lvdm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/dynamicrafter/lvdm/basics.py b/src/videogen_hub/pipelines/dynamicrafter/lvdm/basics.py new file mode 100644 index 0000000000000000000000000000000000000000..7363f677d60ada026090174d9650d044b1768812 --- /dev/null +++ b/src/videogen_hub/pipelines/dynamicrafter/lvdm/basics.py @@ -0,0 +1,101 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + +import torch.nn as nn +from videogen_hub.pipelines.dynamicrafter.utils import instantiate_from_config + + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def nonlinearity(type='silu'): + if type == 'silu': + return nn.SiLU() + elif type == 'leaky_relu': + return nn.LeakyReLU() + + +class GroupNormSpecific(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def normalization(channels, num_groups=32): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNormSpecific(num_groups, channels) + + +class HybridConditioner(nn.Module): + + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} \ No newline at end of file diff --git a/src/videogen_hub/pipelines/dynamicrafter/lvdm/common.py b/src/videogen_hub/pipelines/dynamicrafter/lvdm/common.py new file mode 100644 index 0000000000000000000000000000000000000000..55a150b618e275f01d3a59ad9c7579176c4ea1b8 --- /dev/null +++ b/src/videogen_hub/pipelines/dynamicrafter/lvdm/common.py @@ -0,0 +1,94 @@ +import math +from inspect import isfunction +import torch +from torch import nn +import torch.distributed as dist + + +def gather_data(data, return_np=True): + ''' gather data from multiple processes to one list ''' + data_list = [torch.zeros_like(data) for _ in range(dist.get_world_size())] + dist.all_gather(data_list, data) # gather not supported with NCCL + if return_np: + data_list = [data.cpu().numpy() for data in data_list] + return data_list + +def autocast(f): + def do_autocast(*args, **kwargs): + with torch.cuda.amp.autocast(enabled=True, + dtype=torch.get_autocast_gpu_dtype(), + cache_enabled=torch.is_autocast_cache_enabled()): + return f(*args, **kwargs) + return do_autocast + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + +def exists(val): + return val is not None + +def identity(*args, **kwargs): + return nn.Identity() + +def uniq(arr): + return{el: True for el in arr}.keys() + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + +def isimage(x): + if not isinstance(x,torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + +def shape_to_str(x): + shape_str = "x".join([str(x) for x in x.shape]) + return shape_str + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + +ckpt = torch.utils.checkpoint.checkpoint +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + return ckpt(func, *inputs, use_reentrant=False) + else: + return func(*inputs) \ No newline at end of file diff --git a/src/videogen_hub/pipelines/dynamicrafter/lvdm/distributions.py b/src/videogen_hub/pipelines/dynamicrafter/lvdm/distributions.py new file mode 100644 index 0000000000000000000000000000000000000000..9a2a82ecace3ce27fb7816ddaf088e179c2d5ffd --- /dev/null +++ b/src/videogen_hub/pipelines/dynamicrafter/lvdm/distributions.py @@ -0,0 +1,95 @@ +import torch +import numpy as np + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self, noise=None): + if noise is None: + noise = torch.randn(self.mean.shape) + + x = self.mean + self.std * noise.to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1,2,3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) \ No newline at end of file diff --git a/src/videogen_hub/pipelines/dynamicrafter/lvdm/ema.py b/src/videogen_hub/pipelines/dynamicrafter/lvdm/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..cd2f8e3115f816b4cac674397238cd8c22de9bc2 --- /dev/null +++ b/src/videogen_hub/pipelines/dynamicrafter/lvdm/ema.py @@ -0,0 +1,76 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError('Decay must be between 0 and 1') + + self.m_name2s_name = {} + self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) + self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates + else torch.tensor(-1,dtype=torch.int)) + + for name, p in model.named_parameters(): + if p.requires_grad: + #remove as '.'-character is not allowed in buffers + s_name = name.replace('.','') + self.m_name2s_name.update({name:s_name}) + self.register_buffer(s_name,p.clone().detach().data) + + self.collected_params = [] + + def forward(self,model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) \ No newline at end of file diff --git a/src/videogen_hub/pipelines/dynamicrafter/lvdm/models/__init__.py b/src/videogen_hub/pipelines/dynamicrafter/lvdm/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/dynamicrafter/lvdm/models/autoencoder.py b/src/videogen_hub/pipelines/dynamicrafter/lvdm/models/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..033e8de8268d44c2f60957e162704fb998e4640a --- /dev/null +++ b/src/videogen_hub/pipelines/dynamicrafter/lvdm/models/autoencoder.py @@ -0,0 +1,219 @@ +import os +from contextlib import contextmanager +import torch +import numpy as np +from einops import rearrange +import torch.nn.functional as F +import pytorch_lightning as pl +from videogen_hub.pipelines.dynamicrafter.lvdm.modules.networks.ae_modules import Encoder, Decoder +from videogen_hub.pipelines.dynamicrafter.lvdm.distributions import DiagonalGaussianDistribution +from videogen_hub.pipelines.dynamicrafter.utils import instantiate_from_config + + +class AutoencoderKL(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + test=False, + logdir=None, + input_dim=4, + test_args=None, + ): + super().__init__() + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + self.input_dim = input_dim + self.test = test + self.test_args = test_args + self.logdir = logdir + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + if self.test: + self.init_test() + + def init_test(self,): + self.test = True + save_dir = os.path.join(self.logdir, "test") + if 'ckpt' in self.test_args: + ckpt_name = os.path.basename(self.test_args.ckpt).split('.ckpt')[0] + f'_epoch{self._cur_epoch}' + self.root = os.path.join(save_dir, ckpt_name) + else: + self.root = save_dir + if 'test_subdir' in self.test_args: + self.root = os.path.join(save_dir, self.test_args.test_subdir) + + self.root_zs = os.path.join(self.root, "zs") + self.root_dec = os.path.join(self.root, "reconstructions") + self.root_inputs = os.path.join(self.root, "inputs") + os.makedirs(self.root, exist_ok=True) + + if self.test_args.save_z: + os.makedirs(self.root_zs, exist_ok=True) + if self.test_args.save_reconstruction: + os.makedirs(self.root_dec, exist_ok=True) + if self.test_args.save_input: + os.makedirs(self.root_inputs, exist_ok=True) + assert(self.test_args is not None) + self.test_maximum = getattr(self.test_args, 'test_maximum', None) + self.count = 0 + self.eval_metrics = {} + self.decodes = [] + self.save_decode_samples = 2048 + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu") + try: + self._cur_epoch = sd['epoch'] + sd = sd["state_dict"] + except: + self._cur_epoch = 'null' + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + # self.load_state_dict(sd, strict=True) + print(f"Restored from {path}") + + def encode(self, x, **kwargs): + + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z, **kwargs): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if x.dim() == 5 and self.input_dim == 4: + b,c,t,h,w = x.shape + self.b = b + self.t = t + x = rearrange(x, 'b c t h w -> (b t) c h w') + + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + + if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + if optimizer_idx == 1: + # train the discriminator + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return discloss + + def validation_step(self, batch, batch_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, + last_layer=self.get_last_layer(), split="val") + + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, + last_layer=self.get_last_layer(), split="val") + + self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x \ No newline at end of file diff --git a/src/videogen_hub/pipelines/dynamicrafter/lvdm/models/ddpm3d.py b/src/videogen_hub/pipelines/dynamicrafter/lvdm/models/ddpm3d.py new file mode 100644 index 0000000000000000000000000000000000000000..3e45cc11e803113492c8b9b51ad744b81524c812 --- /dev/null +++ b/src/videogen_hub/pipelines/dynamicrafter/lvdm/models/ddpm3d.py @@ -0,0 +1,762 @@ +""" +wild mixture of +https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +https://github.com/CompVis/taming-transformers +-- merci +""" + +from functools import partial +from contextlib import contextmanager +import numpy as np +from tqdm import tqdm +from einops import rearrange, repeat +import logging +mainlogger = logging.getLogger('mainlogger') +import torch +import torch.nn as nn +from torchvision.utils import make_grid +import pytorch_lightning as pl +from videogen_hub.pipelines.dynamicrafter.utils import instantiate_from_config +from videogen_hub.pipelines.dynamicrafter.lvdm.ema import LitEma +from videogen_hub.pipelines.dynamicrafter.lvdm.distributions import DiagonalGaussianDistribution +from videogen_hub.pipelines.dynamicrafter.lvdm.models.utils_diffusion import make_beta_schedule, rescale_zero_terminal_snr +from videogen_hub.pipelines.dynamicrafter.lvdm.basics import disabled_train +from videogen_hub.pipelines.dynamicrafter.lvdm.common import ( + extract_into_tensor, + noise_like, + exists, + default +) + +__conditioning_keys__ = {'concat': 'c_concat', + 'crossattn': 'c_crossattn', + 'adm': 'y'} + +class DDPM(pl.LightningModule): + # classic DDPM with Gaussian diffusion, in image space + def __init__(self, + unet_config, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor=None, + use_ema=True, + first_stage_key="image", + image_size=256, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0., + v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1., + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0., + rescale_betas_zero_snr=False, + ): + super().__init__() + assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"' + self.parameterization = parameterization + mainlogger.info(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + self.channels = channels + self.temporal_length = unet_config.params.temporal_length + self.image_size = image_size # try conv? + if isinstance(self.image_size, int): + self.image_size = [self.image_size, self.image_size] + self.use_positional_encodings = use_positional_encodings + self.model = DiffusionWrapper(unet_config, conditioning_key) + #count_params(self.model, verbose=True) + self.use_ema = use_ema + self.rescale_betas_zero_snr = rescale_betas_zero_snr + if self.use_ema: + self.model_ema = LitEma(self.model) + mainlogger.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.use_scheduler = scheduler_config is not None + if self.use_scheduler: + self.scheduler_config = scheduler_config + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) + + self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, + linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) + + self.loss_type = loss_type + + self.learn_logvar = learn_logvar + self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + + def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + if self.rescale_betas_zero_snr: + betas = rescale_zero_terminal_snr(betas) + + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + + if self.parameterization != 'v': + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + else: + self.register_buffer('sqrt_recip_alphas_cumprod', torch.zeros_like(to_torch(alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.zeros_like(to_torch(alphas_cumprod))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( + 1. - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer('posterior_mean_coef1', to_torch( + betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', to_torch( + (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + + if self.parameterization == "eps": + lvlb_weights = self.betas ** 2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + elif self.parameterization == "x0": + lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod)) + elif self.parameterization == "v": + lvlb_weights = torch.ones_like(self.betas ** 2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))) + else: + raise NotImplementedError("mu not supported") + # TODO how to choose this term + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + mainlogger.info(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + mainlogger.info(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + mainlogger.info("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + mainlogger.info(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + mainlogger.info(f"Missing Keys: {missing}") + if len(unexpected) > 0: + mainlogger.info(f"Unexpected Keys: {unexpected}") + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start) + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def predict_start_from_z_and_v(self, x_t, t, v): + # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v + ) + + def predict_eps_from_z_and_v(self, x_t, t, v): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + model_out = self.model(x, t) + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + if clip_denoised: + x_recon.clamp_(-1., 1.) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, shape, return_intermediates=False): + device = self.betas.device + b = shape[0] + img = torch.randn(shape, device=device) + intermediates = [img] + for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps): + img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), + clip_denoised=self.clip_denoised) + if i % self.log_every_t == 0 or i == self.num_timesteps - 1: + intermediates.append(img) + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, batch_size=16, return_intermediates=False): + image_size = self.image_size + channels = self.channels + return self.p_sample_loop((batch_size, channels, image_size, image_size), + return_intermediates=return_intermediates) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def get_v(self, x, noise, t): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x + ) + + def get_input(self, batch, k): + x = batch[k] + x = x.to(memory_format=torch.contiguous_format).float() + return x + + def _get_rows_from_list(self, samples): + n_imgs_per_row = len(samples) + denoise_grid = rearrange(samples, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): + log = dict() + x = self.get_input(batch, self.first_stage_key) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + x = x.to(self.device)[:N] + log["inputs"] = x + + # get diffusion row + diffusion_row = list() + x_start = x[:n_row] + + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(x_start) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + diffusion_row.append(x_noisy) + + log["diffusion_row"] = self._get_rows_from_list(diffusion_row) + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, denoise_row = self.sample(batch_size=N, return_intermediates=True) + + log["samples"] = samples + log["denoise_row"] = self._get_rows_from_list(denoise_row) + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + +class LatentDiffusion(DDPM): + """main class""" + def __init__(self, + first_stage_config, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_key="caption", + cond_stage_trainable=False, + cond_stage_forward=None, + conditioning_key=None, + uncond_prob=0.2, + uncond_type="empty_seq", + scale_factor=1.0, + scale_by_std=False, + encoder_type="2d", + only_model=False, + noise_strength=0, + use_dynamic_rescale=False, + base_scale=0.7, + turning_step=400, + loop_video=False, + fps_condition_type='fs', + perframe_ae=False, + *args, **kwargs): + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs['timesteps'] + # for backwards compatibility after implementation of DiffusionWrapper + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", []) + conditioning_key = default(conditioning_key, 'crossattn') + super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + self.noise_strength = noise_strength + self.use_dynamic_rescale = use_dynamic_rescale + self.loop_video = loop_video + self.fps_condition_type = fps_condition_type + self.perframe_ae = perframe_ae + try: + self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 + except: + self.num_downs = 0 + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer('scale_factor', torch.tensor(scale_factor)) + + if use_dynamic_rescale: + scale_arr1 = np.linspace(1.0, base_scale, turning_step) + scale_arr2 = np.full(self.num_timesteps, base_scale) + scale_arr = np.concatenate((scale_arr1, scale_arr2)) + to_torch = partial(torch.tensor, dtype=torch.float32) + self.register_buffer('scale_arr', to_torch(scale_arr)) + + self.instantiate_first_stage(first_stage_config) + self.instantiate_cond_stage(cond_stage_config) + self.first_stage_config = first_stage_config + self.cond_stage_config = cond_stage_config + self.clip_denoised = False + + self.cond_stage_forward = cond_stage_forward + self.encoder_type = encoder_type + assert(encoder_type in ["2d", "3d"]) + self.uncond_prob = uncond_prob + self.classifier_free_guidance = True if uncond_prob > 0 else False + assert(uncond_type in ["zero_embed", "empty_seq"]) + self.uncond_type = uncond_type + + self.restarted_from_ckpt = False + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys, only_model=only_model) + self.restarted_from_ckpt = True + + + def make_cond_schedule(self, ): + self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) + ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() + self.cond_ids[:self.num_timesteps_cond] = ids + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def instantiate_cond_stage(self, config): + if not self.cond_stage_trainable: + model = instantiate_from_config(config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + else: + model = instantiate_from_config(config) + self.cond_stage_model = model + + def get_learned_conditioning(self, c): + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + c = self.cond_stage_model(c) + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + return c + + def get_first_stage_encoding(self, encoder_posterior, noise=None): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample(noise=noise) + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") + return self.scale_factor * z + + @torch.no_grad() + def encode_first_stage(self, x): + if self.encoder_type == "2d" and x.dim() == 5: + b, _, t, _, _ = x.shape + x = rearrange(x, 'b c t h w -> (b t) c h w') + reshape_back = True + else: + reshape_back = False + + ## consume more GPU memory but faster + if not self.perframe_ae: + encoder_posterior = self.first_stage_model.encode(x) + results = self.get_first_stage_encoding(encoder_posterior).detach() + else: ## consume less GPU memory but slower + results = [] + for index in range(x.shape[0]): + frame_batch = self.first_stage_model.encode(x[index:index+1,:,:,:]) + frame_result = self.get_first_stage_encoding(frame_batch).detach() + results.append(frame_result) + results = torch.cat(results, dim=0) + + if reshape_back: + results = rearrange(results, '(b t) c h w -> b c t h w', b=b,t=t) + + return results + + def decode_core(self, z, **kwargs): + if self.encoder_type == "2d" and z.dim() == 5: + b, _, t, _, _ = z.shape + z = rearrange(z, 'b c t h w -> (b t) c h w') + reshape_back = True + else: + reshape_back = False + + if not self.perframe_ae: + z = 1. / self.scale_factor * z + results = self.first_stage_model.decode(z, **kwargs) + else: + results = [] + for index in range(z.shape[0]): + frame_z = 1. / self.scale_factor * z[index:index+1,:,:,:] + frame_result = self.first_stage_model.decode(frame_z, **kwargs) + results.append(frame_result) + results = torch.cat(results, dim=0) + + if reshape_back: + results = rearrange(results, '(b t) c h w -> b c t h w', b=b,t=t) + return results + + @torch.no_grad() + def decode_first_stage(self, z, **kwargs): + return self.decode_core(z, **kwargs) + + # same as above but without decorator + def differentiable_decode_first_stage(self, z, **kwargs): + return self.decode_core(z, **kwargs) + + def forward(self, x, c, **kwargs): + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + if self.use_dynamic_rescale: + x = x * extract_into_tensor(self.scale_arr, t, x.shape) + return self.p_losses(x, c, t, **kwargs) + + def apply_model(self, x_noisy, t, cond, **kwargs): + if isinstance(cond, dict): + # hybrid case, cond is exptected to be a dict + pass + else: + if not isinstance(cond, list): + cond = [cond] + key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' + cond = {key: cond} + + x_recon = self.model(x_noisy, t, **cond, **kwargs) + + if isinstance(x_recon, tuple): + return x_recon[0] + else: + return x_recon + + def _get_denoise_row_from_list(self, samples, desc=''): + denoise_row = [] + for zd in tqdm(samples, desc=desc): + denoise_row.append(self.decode_first_stage(zd.to(self.device))) + n_log_timesteps = len(denoise_row) + + denoise_row = torch.stack(denoise_row) # n_log_timesteps, b, C, H, W + + if denoise_row.dim() == 5: + denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_log_timesteps) + elif denoise_row.dim() == 6: + # video, grid_size=[n_log_timesteps*bs, t] + video_length = denoise_row.shape[3] + denoise_grid = rearrange(denoise_row, 'n b c t h w -> b n c t h w') + denoise_grid = rearrange(denoise_grid, 'b n c t h w -> (b n) c t h w') + denoise_grid = rearrange(denoise_grid, 'n c t h w -> (n t) c h w') + denoise_grid = make_grid(denoise_grid, nrow=video_length) + else: + raise ValueError + + return denoise_grid + + + def p_mean_variance(self, x, c, t, clip_denoised: bool, return_x0=False, score_corrector=None, corrector_kwargs=None, **kwargs): + t_in = t + model_out = self.apply_model(x, t_in, c, **kwargs) + + if score_corrector is not None: + assert self.parameterization == "eps" + model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs) + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + else: + raise NotImplementedError() + + if clip_denoised: + x_recon.clamp_(-1., 1.) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + + if return_x0: + return model_mean, posterior_variance, posterior_log_variance, x_recon + else: + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False, return_x0=False, \ + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, **kwargs): + b, *_, device = *x.shape, x.device + outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised, return_x0=return_x0, \ + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs, **kwargs) + if return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + + if return_x0: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0 + else: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, cond, shape, return_intermediates=False, x_T=None, verbose=True, callback=None, \ + timesteps=None, mask=None, x0=None, img_callback=None, start_T=None, log_every_t=None, **kwargs): + + if not log_every_t: + log_every_t = self.log_every_t + device = self.betas.device + b = shape[0] + # sample an initial noise + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + if start_T is not None: + timesteps = min(timesteps, start_T) + + iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(range(0, timesteps)) + + if mask is not None: + assert x0 is not None + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + + for i in iterator: + ts = torch.full((b,), i, device=device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img = self.p_sample(img, cond, ts, clip_denoised=self.clip_denoised, **kwargs) + if mask is not None: + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + if callback: callback(i) + if img_callback: img_callback(img, i) + + if return_intermediates: + return img, intermediates + return img + + +class LatentVisualDiffusion(LatentDiffusion): + def __init__(self, img_cond_stage_config, image_proj_stage_config, freeze_embedder=True, *args, **kwargs): + super().__init__(*args, **kwargs) + self._init_embedder(img_cond_stage_config, freeze_embedder) + self.image_proj_model = instantiate_from_config(image_proj_stage_config) + + def _init_embedder(self, config, freeze=True): + embedder = instantiate_from_config(config) + if freeze: + self.embedder = embedder.eval() + self.embedder.train = disabled_train + for param in self.embedder.parameters(): + param.requires_grad = False + + +class DiffusionWrapper(pl.LightningModule): + def __init__(self, diff_model_config, conditioning_key): + super().__init__() + self.diffusion_model = instantiate_from_config(diff_model_config) + self.conditioning_key = conditioning_key + + def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, + c_adm=None, s=None, mask=None, **kwargs): + # temporal_context = fps is foNone + if self.conditioning_key is None: + out = self.diffusion_model(x, t) + elif self.conditioning_key == 'concat': + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t, **kwargs) + elif self.conditioning_key == 'crossattn': + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(x, t, context=cc, **kwargs) + elif self.conditioning_key == 'hybrid': + ## it is just right [b,c,t,h,w]: concatenate in channel dim + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc, **kwargs) + elif self.conditioning_key == 'resblockcond': + cc = c_crossattn[0] + out = self.diffusion_model(x, t, context=cc) + elif self.conditioning_key == 'adm': + cc = c_crossattn[0] + out = self.diffusion_model(x, t, y=cc) + elif self.conditioning_key == 'hybrid-adm': + assert c_adm is not None + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc, y=c_adm, **kwargs) + elif self.conditioning_key == 'hybrid-time': + assert s is not None + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc, s=s) + elif self.conditioning_key == 'concat-time-mask': + # assert s is not None + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t, context=None, s=s, mask=mask) + elif self.conditioning_key == 'concat-adm-mask': + # assert s is not None + if c_concat is not None: + xc = torch.cat([x] + c_concat, dim=1) + else: + xc = x + out = self.diffusion_model(xc, t, context=None, y=s, mask=mask) + elif self.conditioning_key == 'hybrid-adm-mask': + cc = torch.cat(c_crossattn, 1) + if c_concat is not None: + xc = torch.cat([x] + c_concat, dim=1) + else: + xc = x + out = self.diffusion_model(xc, t, context=cc, y=s, mask=mask) + elif self.conditioning_key == 'hybrid-time-adm': # adm means y, e.g., class index + # assert s is not None + assert c_adm is not None + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc, s=s, y=c_adm) + elif self.conditioning_key == 'crossattn-adm': + assert c_adm is not None + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(x, t, context=cc, y=c_adm) + else: + raise NotImplementedError() + + return out \ No newline at end of file diff --git a/src/videogen_hub/pipelines/dynamicrafter/lvdm/models/samplers/__init__.py b/src/videogen_hub/pipelines/dynamicrafter/lvdm/models/samplers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/dynamicrafter/lvdm/models/samplers/ddim.py b/src/videogen_hub/pipelines/dynamicrafter/lvdm/models/samplers/ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..707112ff7cd32d54adb7ea07db7209fd7b21099b --- /dev/null +++ b/src/videogen_hub/pipelines/dynamicrafter/lvdm/models/samplers/ddim.py @@ -0,0 +1,317 @@ +import numpy as np +from tqdm import tqdm +import torch +from videogen_hub.pipelines.dynamicrafter.lvdm.models.utils_diffusion import make_ddim_sampling_parameters, make_ddim_timesteps, rescale_noise_cfg +from videogen_hub.pipelines.dynamicrafter.lvdm.common import noise_like +from videogen_hub.pipelines.dynamicrafter.lvdm.common import extract_into_tensor +import copy + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + self.counter = 0 + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + if self.model.use_dynamic_rescale: + self.ddim_scale_arr = self.model.scale_arr[self.ddim_timesteps] + self.ddim_scale_arr_prev = torch.cat([self.ddim_scale_arr[0:1], self.ddim_scale_arr[:-1]]) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + schedule_verbose=False, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + precision=None, + fs=None, + timestep_spacing='uniform', #uniform_trailing for starting from last timestep + guidance_rescale=0.0, + **kwargs + ): + + # check condition bs + if conditioning is not None: + if isinstance(conditioning, dict): + try: + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + except: + cbs = conditioning[list(conditioning.keys())[0]][0].shape[0] + + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_discretize=timestep_spacing, ddim_eta=eta, verbose=schedule_verbose) + + # make shape + if len(shape) == 3: + C, H, W = shape + size = (batch_size, C, H, W) + elif len(shape) == 4: + C, T, H, W = shape + size = (batch_size, C, T, H, W) + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + verbose=verbose, + precision=precision, + fs=fs, + guidance_rescale=guidance_rescale, + **kwargs) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, verbose=True,precision=None,fs=None,guidance_rescale=0.0, + **kwargs): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + if precision is not None: + if precision == 16: + img = img.to(dtype=torch.float16) + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + if verbose: + iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + else: + iterator = time_range + + clean_cond = kwargs.pop("clean_cond", False) + + # cond_copy, unconditional_conditioning_copy = copy.deepcopy(cond), copy.deepcopy(unconditional_conditioning) + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + ## use mask to blend noised original latent (img_orig) & new sampled latent (img) + if mask is not None: + assert x0 is not None + if clean_cond: + img_orig = x0 + else: + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img # keep original & modify use img + + + + + outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + mask=mask,x0=x0,fs=fs,guidance_rescale=guidance_rescale, + **kwargs) + + + img, pred_x0 = outs + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, + uc_type=None, conditional_guidance_scale_temporal=None,mask=None,x0=None,guidance_rescale=0.0,**kwargs): + b, *_, device = *x.shape, x.device + if x.dim() == 5: + is_video = True + else: + is_video = False + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + model_output = self.model.apply_model(x, t, c, **kwargs) # unet denoiser + else: + ### do_classifier_free_guidance + if isinstance(c, torch.Tensor) or isinstance(c, dict): + e_t_cond = self.model.apply_model(x, t, c, **kwargs) + e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning, **kwargs) + else: + raise NotImplementedError + + model_output = e_t_uncond + unconditional_guidance_scale * (e_t_cond - e_t_uncond) + + if guidance_rescale > 0.0: + model_output = rescale_noise_cfg(model_output, e_t_cond, guidance_rescale=guidance_rescale) + + if self.model.parameterization == "v": + e_t = self.model.predict_eps_from_z_and_v(x, t, model_output) + else: + e_t = model_output + + if score_corrector is not None: + assert self.model.parameterization == "eps", 'not implemented' + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + # sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + + if is_video: + size = (b, 1, 1, 1, 1) + else: + size = (b, 1, 1, 1) + a_t = torch.full(size, alphas[index], device=device) + a_prev = torch.full(size, alphas_prev[index], device=device) + sigma_t = torch.full(size, sigmas[index], device=device) + sqrt_one_minus_at = torch.full(size, sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + if self.model.parameterization != "v": + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + else: + pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) + + if self.model.use_dynamic_rescale: + scale_t = torch.full(size, self.ddim_scale_arr[index], device=device) + prev_scale_t = torch.full(size, self.ddim_scale_arr_prev[index], device=device) + rescale = (prev_scale_t / scale_t) + pred_x0 *= rescale + + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + + return x_prev, pred_x0 + + @torch.no_grad() + def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, + use_original_steps=False, callback=None): + + timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='Decoding image', total=total_steps) + x_dec = x_latent + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) + x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + if callback: callback(i) + return x_dec + + @torch.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + # fast, but does not allow for exact reconstruction + # t serves as an index to gather the correct alphas + if use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod + else: + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + + if noise is None: + noise = torch.randn_like(x0) + return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) diff --git a/src/videogen_hub/pipelines/dynamicrafter/lvdm/models/samplers/ddim_multiplecond.py b/src/videogen_hub/pipelines/dynamicrafter/lvdm/models/samplers/ddim_multiplecond.py new file mode 100644 index 0000000000000000000000000000000000000000..a5503f68f61ccd5047900c316698d21faa140111 --- /dev/null +++ b/src/videogen_hub/pipelines/dynamicrafter/lvdm/models/samplers/ddim_multiplecond.py @@ -0,0 +1,323 @@ +import numpy as np +from tqdm import tqdm +import torch +from videogen_hub.pipelines.dynamicrafter.lvdm.models.utils_diffusion import make_ddim_sampling_parameters, make_ddim_timesteps, rescale_noise_cfg +from videogen_hub.pipelines.dynamicrafter.lvdm.common import noise_like +from videogen_hub.pipelines.dynamicrafter.lvdm.common import extract_into_tensor +import copy + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + self.counter = 0 + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + if self.model.use_dynamic_rescale: + self.ddim_scale_arr = self.model.scale_arr[self.ddim_timesteps] + self.ddim_scale_arr_prev = torch.cat([self.ddim_scale_arr[0:1], self.ddim_scale_arr[:-1]]) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + schedule_verbose=False, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + precision=None, + fs=None, + timestep_spacing='uniform', #uniform_trailing for starting from last timestep + guidance_rescale=0.0, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + + # check condition bs + if conditioning is not None: + if isinstance(conditioning, dict): + try: + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + except: + cbs = conditioning[list(conditioning.keys())[0]][0].shape[0] + + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + # print('==> timestep_spacing: ', timestep_spacing, guidance_rescale) + self.make_schedule(ddim_num_steps=S, ddim_discretize=timestep_spacing, ddim_eta=eta, verbose=schedule_verbose) + + # make shape + if len(shape) == 3: + C, H, W = shape + size = (batch_size, C, H, W) + elif len(shape) == 4: + C, T, H, W = shape + size = (batch_size, C, T, H, W) + # print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + verbose=verbose, + precision=precision, + fs=fs, + guidance_rescale=guidance_rescale, + **kwargs) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, verbose=True,precision=None,fs=None,guidance_rescale=0.0, + **kwargs): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + if precision is not None: + if precision == 16: + img = img.to(dtype=torch.float16) + + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + if verbose: + iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + else: + iterator = time_range + + clean_cond = kwargs.pop("clean_cond", False) + + # cond_copy, unconditional_conditioning_copy = copy.deepcopy(cond), copy.deepcopy(unconditional_conditioning) + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + ## use mask to blend noised original latent (img_orig) & new sampled latent (img) + if mask is not None: + assert x0 is not None + if clean_cond: + img_orig = x0 + else: + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img # keep original & modify use img + + + + + outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + mask=mask,x0=x0,fs=fs,guidance_rescale=guidance_rescale, + **kwargs) + + + + img, pred_x0 = outs + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, + uc_type=None, cfg_img=None,mask=None,x0=None,guidance_rescale=0.0, **kwargs): + b, *_, device = *x.shape, x.device + if x.dim() == 5: + is_video = True + else: + is_video = False + if cfg_img is None: + cfg_img = unconditional_guidance_scale + + unconditional_conditioning_img_nonetext = kwargs['unconditional_conditioning_img_nonetext'] + + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + model_output = self.model.apply_model(x, t, c, **kwargs) # unet denoiser + else: + ### with unconditional condition + e_t_cond = self.model.apply_model(x, t, c, **kwargs) + e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning, **kwargs) + e_t_uncond_img = self.model.apply_model(x, t, unconditional_conditioning_img_nonetext, **kwargs) + # text cfg + model_output = e_t_uncond + cfg_img * (e_t_uncond_img - e_t_uncond) + unconditional_guidance_scale * (e_t_cond - e_t_uncond_img) + if guidance_rescale > 0.0: + model_output = rescale_noise_cfg(model_output, e_t_cond, guidance_rescale=guidance_rescale) + + if self.model.parameterization == "v": + e_t = self.model.predict_eps_from_z_and_v(x, t, model_output) + else: + e_t = model_output + + if score_corrector is not None: + assert self.model.parameterization == "eps", 'not implemented' + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + + if is_video: + size = (b, 1, 1, 1, 1) + else: + size = (b, 1, 1, 1) + a_t = torch.full(size, alphas[index], device=device) + a_prev = torch.full(size, alphas_prev[index], device=device) + sigma_t = torch.full(size, sigmas[index], device=device) + sqrt_one_minus_at = torch.full(size, sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + if self.model.parameterization != "v": + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + else: + pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) + + if self.model.use_dynamic_rescale: + scale_t = torch.full(size, self.ddim_scale_arr[index], device=device) + prev_scale_t = torch.full(size, self.ddim_scale_arr_prev[index], device=device) + rescale = (prev_scale_t / scale_t) + pred_x0 *= rescale + + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + + return x_prev, pred_x0 + + @torch.no_grad() + def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, + use_original_steps=False, callback=None): + + timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='Decoding image', total=total_steps) + x_dec = x_latent + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) + x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + if callback: callback(i) + return x_dec + + @torch.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + # fast, but does not allow for exact reconstruction + # t serves as an index to gather the correct alphas + if use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod + else: + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + + if noise is None: + noise = torch.randn_like(x0) + return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) \ No newline at end of file diff --git a/src/videogen_hub/pipelines/dynamicrafter/lvdm/models/utils_diffusion.py b/src/videogen_hub/pipelines/dynamicrafter/lvdm/models/utils_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..30043d3cbf91ce08f9206c00e2df6955221e8cb5 --- /dev/null +++ b/src/videogen_hub/pipelines/dynamicrafter/lvdm/models/utils_diffusion.py @@ -0,0 +1,158 @@ +import math +import numpy as np +import torch +import torch.nn.functional as F +from einops import repeat + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + steps_out = ddim_timesteps + 1 + elif ddim_discr_method == 'uniform_trailing': + c = num_ddpm_timesteps / num_ddim_timesteps + ddim_timesteps = np.flip(np.round(np.arange(num_ddpm_timesteps, 0, -c))).astype(np.int64) + steps_out = ddim_timesteps - 1 + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + steps_out = ddim_timesteps + 1 + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + # steps_out = ddim_timesteps + 1 + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + # print(f'ddim_timesteps={ddim_timesteps}, len_alphacums={len(alphacums)}') + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + Args: + betas (`numpy.ndarray`): + the betas that the scheduler is being initialized with. + + Returns: + `numpy.ndarray`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_bar_sqrt = np.sqrt(alphas_cumprod) + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].copy() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].copy() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = np.concatenate([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg \ No newline at end of file diff --git a/src/videogen_hub/pipelines/dynamicrafter/lvdm/modules/__init__.py b/src/videogen_hub/pipelines/dynamicrafter/lvdm/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/dynamicrafter/lvdm/modules/attention.py b/src/videogen_hub/pipelines/dynamicrafter/lvdm/modules/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..176885d9ff6f5675c413523a38b78845ce04bd97 --- /dev/null +++ b/src/videogen_hub/pipelines/dynamicrafter/lvdm/modules/attention.py @@ -0,0 +1,514 @@ +import torch +from torch import nn, einsum +import torch.nn.functional as F +from einops import rearrange, repeat +from functools import partial +try: + import xformers + import xformers.ops + XFORMERS_IS_AVAILBLE = True +except: + XFORMERS_IS_AVAILBLE = False +from lvdm.common import ( + checkpoint, + exists, + default, +) +from lvdm.basics import zero_module + + +class RelativePosition(nn.Module): + """ https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py """ + + def __init__(self, num_units, max_relative_position): + super().__init__() + self.num_units = num_units + self.max_relative_position = max_relative_position + self.embeddings_table = nn.Parameter(torch.Tensor(max_relative_position * 2 + 1, num_units)) + nn.init.xavier_uniform_(self.embeddings_table) + + def forward(self, length_q, length_k): + device = self.embeddings_table.device + range_vec_q = torch.arange(length_q, device=device) + range_vec_k = torch.arange(length_k, device=device) + distance_mat = range_vec_k[None, :] - range_vec_q[:, None] + distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position) + final_mat = distance_mat_clipped + self.max_relative_position + final_mat = final_mat.long() + embeddings = self.embeddings_table[final_mat] + return embeddings + + +class CrossAttention(nn.Module): + + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., + relative_position=False, temporal_length=None, video_length=None, image_cross_attention=False, image_cross_attention_scale=1.0, image_cross_attention_scale_learnable=False, text_context_len=77): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head**-0.5 + self.heads = heads + self.dim_head = dim_head + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + + self.relative_position = relative_position + if self.relative_position: + assert(temporal_length is not None) + self.relative_position_k = RelativePosition(num_units=dim_head, max_relative_position=temporal_length) + self.relative_position_v = RelativePosition(num_units=dim_head, max_relative_position=temporal_length) + else: + ## only used for spatial attention, while NOT for temporal attention + if XFORMERS_IS_AVAILBLE and temporal_length is None: + self.forward = self.efficient_forward + + self.video_length = video_length + self.image_cross_attention = image_cross_attention + self.image_cross_attention_scale = image_cross_attention_scale + self.text_context_len = text_context_len + self.image_cross_attention_scale_learnable = image_cross_attention_scale_learnable + if self.image_cross_attention: + self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False) + if image_cross_attention_scale_learnable: + self.register_parameter('alpha', nn.Parameter(torch.tensor(0.)) ) + + + def forward(self, x, context=None, mask=None): + spatial_self_attn = (context is None) + k_ip, v_ip, out_ip = None, None, None + + h = self.heads + q = self.to_q(x) + context = default(context, x) + + if self.image_cross_attention and not spatial_self_attn: + context, context_image = context[:,:self.text_context_len,:], context[:,self.text_context_len:,:] + k = self.to_k(context) + v = self.to_v(context) + k_ip = self.to_k_ip(context_image) + v_ip = self.to_v_ip(context_image) + else: + if not spatial_self_attn: + context = context[:,:self.text_context_len,:] + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale + if self.relative_position: + len_q, len_k, len_v = q.shape[1], k.shape[1], v.shape[1] + k2 = self.relative_position_k(len_q, len_k) + sim2 = einsum('b t d, t s d -> b t s', q, k2) * self.scale # TODO check + sim += sim2 + del k + + if exists(mask): + ## feasible for causal attention mask only + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b i j -> (b h) i j', h=h) + sim.masked_fill_(~(mask>0.5), max_neg_value) + + # attention, what we cannot get enough of + sim = sim.softmax(dim=-1) + + out = torch.einsum('b i j, b j d -> b i d', sim, v) + if self.relative_position: + v2 = self.relative_position_v(len_q, len_v) + out2 = einsum('b t s, t s d -> b t d', sim, v2) # TODO check + out += out2 + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + + + ## for image cross-attention + if k_ip is not None: + k_ip, v_ip = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (k_ip, v_ip)) + sim_ip = torch.einsum('b i d, b j d -> b i j', q, k_ip) * self.scale + del k_ip + sim_ip = sim_ip.softmax(dim=-1) + out_ip = torch.einsum('b i j, b j d -> b i d', sim_ip, v_ip) + out_ip = rearrange(out_ip, '(b h) n d -> b n (h d)', h=h) + + + if out_ip is not None: + if self.image_cross_attention_scale_learnable: + out = out + self.image_cross_attention_scale * out_ip * (torch.tanh(self.alpha)+1) + else: + out = out + self.image_cross_attention_scale * out_ip + + return self.to_out(out) + + def efficient_forward(self, x, context=None, mask=None): + spatial_self_attn = (context is None) + k_ip, v_ip, out_ip = None, None, None + + q = self.to_q(x) + context = default(context, x) + + if self.image_cross_attention and not spatial_self_attn: + context, context_image = context[:,:self.text_context_len,:], context[:,self.text_context_len:,:] + k = self.to_k(context) + v = self.to_v(context) + k_ip = self.to_k_ip(context_image) + v_ip = self.to_v_ip(context_image) + else: + if not spatial_self_attn: + context = context[:,:self.text_context_len,:] + k = self.to_k(context) + v = self.to_v(context) + + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + # actually compute the attention, what we cannot get enough of + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=None) + + ## for image cross-attention + if k_ip is not None: + k_ip, v_ip = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (k_ip, v_ip), + ) + out_ip = xformers.ops.memory_efficient_attention(q, k_ip, v_ip, attn_bias=None, op=None) + out_ip = ( + out_ip.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + + if exists(mask): + raise NotImplementedError + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + if out_ip is not None: + if self.image_cross_attention_scale_learnable: + out = out + self.image_cross_attention_scale * out_ip * (torch.tanh(self.alpha)+1) + else: + out = out + self.image_cross_attention_scale * out_ip + + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, + disable_self_attn=False, attention_cls=None, video_length=None, image_cross_attention=False, image_cross_attention_scale=1.0, image_cross_attention_scale_learnable=False, text_context_len=77): + super().__init__() + attn_cls = CrossAttention if attention_cls is None else attention_cls + self.disable_self_attn = disable_self_attn + self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None) + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout, video_length=video_length, image_cross_attention=image_cross_attention, image_cross_attention_scale=image_cross_attention_scale, image_cross_attention_scale_learnable=image_cross_attention_scale_learnable,text_context_len=text_context_len) + self.image_cross_attention = image_cross_attention + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + + def forward(self, x, context=None, mask=None, **kwargs): + ## implementation tricks: because checkpointing doesn't support non-tensor (e.g. None or scalar) arguments + input_tuple = (x,) ## should not be (x), otherwise *input_tuple will decouple x into multiple arguments + if context is not None: + input_tuple = (x, context) + if mask is not None: + forward_mask = partial(self._forward, mask=mask) + return checkpoint(forward_mask, (x,), self.parameters(), self.checkpoint) + return checkpoint(self._forward, input_tuple, self.parameters(), self.checkpoint) + + + def _forward(self, x, context=None, mask=None): + x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None, mask=mask) + x + x = self.attn2(self.norm2(x), context=context, mask=mask) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data in spatial axis. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + NEW: use_linear for more efficiency instead of the 1x1 convs + """ + + def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None, + use_checkpoint=True, disable_self_attn=False, use_linear=False, video_length=None, + image_cross_attention=False, image_cross_attention_scale_learnable=False): + super().__init__() + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + if not use_linear: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + attention_cls = None + self.transformer_blocks = nn.ModuleList([ + BasicTransformerBlock( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim, + disable_self_attn=disable_self_attn, + checkpoint=use_checkpoint, + attention_cls=attention_cls, + video_length=video_length, + image_cross_attention=image_cross_attention, + image_cross_attention_scale_learnable=image_cross_attention_scale_learnable, + ) for d in range(depth) + ]) + if not use_linear: + self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)) + else: + self.proj_out = zero_module(nn.Linear(inner_dim, in_channels)) + self.use_linear = use_linear + + + def forward(self, x, context=None, **kwargs): + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c').contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context, **kwargs) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) + return x + x_in + + +class TemporalTransformer(nn.Module): + """ + Transformer block for image-like data in temporal axis. + First, reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None, + use_checkpoint=True, use_linear=False, only_self_att=True, causal_attention=False, causal_block_size=1, + relative_position=False, temporal_length=None): + super().__init__() + self.only_self_att = only_self_att + self.relative_position = relative_position + self.causal_attention = causal_attention + self.causal_block_size = causal_block_size + + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.proj_in = nn.Conv1d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + if not use_linear: + self.proj_in = nn.Conv1d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + if relative_position: + assert(temporal_length is not None) + attention_cls = partial(CrossAttention, relative_position=True, temporal_length=temporal_length) + else: + attention_cls = partial(CrossAttention, temporal_length=temporal_length) + if self.causal_attention: + assert(temporal_length is not None) + self.mask = torch.tril(torch.ones([1, temporal_length, temporal_length])) + + if self.only_self_att: + context_dim = None + self.transformer_blocks = nn.ModuleList([ + BasicTransformerBlock( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim, + attention_cls=attention_cls, + checkpoint=use_checkpoint) for d in range(depth) + ]) + if not use_linear: + self.proj_out = zero_module(nn.Conv1d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)) + else: + self.proj_out = zero_module(nn.Linear(inner_dim, in_channels)) + self.use_linear = use_linear + + def forward(self, x, context=None): + b, c, t, h, w = x.shape + x_in = x + x = self.norm(x) + x = rearrange(x, 'b c t h w -> (b h w) c t').contiguous() + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, 'bhw c t -> bhw t c').contiguous() + if self.use_linear: + x = self.proj_in(x) + + temp_mask = None + if self.causal_attention: + # slice the from mask map + temp_mask = self.mask[:,:t,:t].to(x.device) + + if temp_mask is not None: + mask = temp_mask.to(x.device) + mask = repeat(mask, 'l i j -> (l bhw) i j', bhw=b*h*w) + else: + mask = None + + if self.only_self_att: + ## note: if no context is given, cross-attention defaults to self-attention + for i, block in enumerate(self.transformer_blocks): + x = block(x, mask=mask) + x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous() + else: + x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous() + context = rearrange(context, '(b t) l con -> b t l con', t=t).contiguous() + for i, block in enumerate(self.transformer_blocks): + # calculate each batch one by one (since number in shape could not greater then 65,535 for some package) + for j in range(b): + context_j = repeat( + context[j], + 't l con -> (t r) l con', r=(h * w) // t, t=t).contiguous() + ## note: causal mask will not applied in cross-attention case + x[j] = block(x[j], context=context_j) + + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) t c -> b c t h w', h=h, w=w).contiguous() + if not self.use_linear: + x = rearrange(x, 'b hw t c -> (b hw) c t').contiguous() + x = self.proj_out(x) + x = rearrange(x, '(b h w) c t -> b c t h w', b=b, h=h, w=w).contiguous() + + return x + x_in + + +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum('bhdn,bhen->bhde', k, v) + out = torch.einsum('bhde,bhdn->bhen', context, q) + out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = rearrange(q, 'b c h w -> b (h w) c') + k = rearrange(k, 'b c h w -> b c (h w)') + w_ = torch.einsum('bij,bjk->bik', q, k) + + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, 'b c h w -> b c (h w)') + w_ = rearrange(w_, 'b i j -> b j i') + h_ = torch.einsum('bij,bjk->bik', v, w_) + h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + h_ = self.proj_out(h_) + + return x+h_ diff --git a/src/videogen_hub/pipelines/dynamicrafter/lvdm/modules/encoders/__init__.py b/src/videogen_hub/pipelines/dynamicrafter/lvdm/modules/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/dynamicrafter/lvdm/modules/encoders/condition.py b/src/videogen_hub/pipelines/dynamicrafter/lvdm/modules/encoders/condition.py new file mode 100644 index 0000000000000000000000000000000000000000..14b41e12b738dfa33392bd322e19da7f28f3d028 --- /dev/null +++ b/src/videogen_hub/pipelines/dynamicrafter/lvdm/modules/encoders/condition.py @@ -0,0 +1,389 @@ +import torch +import torch.nn as nn +import kornia +import open_clip +from torch.utils.checkpoint import checkpoint +from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel +from videogen_hub.pipelines.dynamicrafter.lvdm.common import autocast +from videogen_hub.pipelines.dynamicrafter.utils import count_params + + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class IdentityEncoder(AbstractEncoder): + def encode(self, x): + return x + + +class ClassEmbedder(nn.Module): + def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1): + super().__init__() + self.key = key + self.embedding = nn.Embedding(n_classes, embed_dim) + self.n_classes = n_classes + self.ucg_rate = ucg_rate + + def forward(self, batch, key=None, disable_dropout=False): + if key is None: + key = self.key + # this is for use in crossattn + c = batch[key][:, None] + if self.ucg_rate > 0. and not disable_dropout: + mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) + c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1) + c = c.long() + c = self.embedding(c) + return c + + def get_unconditional_conditioning(self, bs, device="cuda"): + uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) + uc = torch.ones((bs,), device=device) * uc_class + uc = {self.key: uc} + return uc + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class FrozenT5Embedder(AbstractEncoder): + """Uses the T5 transformer encoder for text""" + + def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, + freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl + super().__init__() + self.tokenizer = T5Tokenizer.from_pretrained(version) + self.transformer = T5EncoderModel.from_pretrained(version) + self.device = device + self.max_length = max_length # TODO: typical value? + if freeze: + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + # self.train = disabled_train + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens) + + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + + +class FrozenCLIPEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from huggingface)""" + LAYERS = [ + "last", + "pooled", + "hidden" + ] + + def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, + freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 + super().__init__() + assert layer in self.LAYERS + self.tokenizer = CLIPTokenizer.from_pretrained(version) + self.transformer = CLIPTextModel.from_pretrained(version) + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + self.layer_idx = layer_idx + if layer == "hidden": + assert layer_idx is not None + assert 0 <= abs(layer_idx) <= 12 + + def freeze(self): + self.transformer = self.transformer.eval() + # self.train = disabled_train + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden") + if self.layer == "last": + z = outputs.last_hidden_state + elif self.layer == "pooled": + z = outputs.pooler_output[:, None, :] + else: + z = outputs.hidden_states[self.layer_idx] + return z + + def encode(self, text): + return self(text) + + +class ClipImageEmbedder(nn.Module): + def __init__( + self, + model, + jit=False, + device='cuda' if torch.cuda.is_available() else 'cpu', + antialias=True, + ucg_rate=0. + ): + super().__init__() + from clip import load as load_clip + self.model, _ = load_clip(name=model, device=device, jit=jit) + + self.antialias = antialias + + self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) + self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) + self.ucg_rate = ucg_rate + + def preprocess(self, x): + # normalize to [0,1] + x = kornia.geometry.resize(x, (224, 224), + interpolation='bicubic', align_corners=True, + antialias=self.antialias) + x = (x + 1.) / 2. + # re-normalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def forward(self, x, no_dropout=False): + # x is assumed to be in range [-1,1] + out = self.model.encode_image(self.preprocess(x)) + out = out.to(x.dtype) + if self.ucg_rate > 0. and not no_dropout: + out = torch.bernoulli((1. - self.ucg_rate) * torch.ones(out.shape[0], device=out.device))[:, None] * out + return out + + +class FrozenOpenCLIPEmbedder(AbstractEncoder): + """ + Uses the OpenCLIP transformer encoder for text + """ + LAYERS = [ + # "pooled", + "last", + "penultimate" + ] + + def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, + freeze=True, layer="last"): + super().__init__() + assert layer in self.LAYERS + model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version) + del model.visual + self.model = model + + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + if self.layer == "last": + self.layer_idx = 0 + elif self.layer == "penultimate": + self.layer_idx = 1 + else: + raise NotImplementedError() + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + tokens = open_clip.tokenize(text) ## all clip models use 77 as context length + z = self.encode_with_transformer(tokens.to(self.device)) + return z + + def encode_with_transformer(self, text): + x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.model.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.model.ln_final(x) + return x + + def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): + for i, r in enumerate(self.model.transformer.resblocks): + if i == len(self.model.transformer.resblocks) - self.layer_idx: + break + if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + def encode(self, text): + return self(text) + + +class FrozenOpenCLIPImageEmbedder(AbstractEncoder): + """ + Uses the OpenCLIP vision transformer encoder for images + """ + + def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, + freeze=True, layer="pooled", antialias=True, ucg_rate=0.): + super().__init__() + model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), + pretrained=version, ) + del model.transformer + self.model = model + # self.mapper = torch.nn.Linear(1280, 1024) + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + if self.layer == "penultimate": + raise NotImplementedError() + self.layer_idx = 1 + + self.antialias = antialias + + self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) + self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) + self.ucg_rate = ucg_rate + + def preprocess(self, x): + # normalize to [0,1] + x = kornia.geometry.resize(x, (224, 224), + interpolation='bicubic', align_corners=True, + antialias=self.antialias) + x = (x + 1.) / 2. + # renormalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def freeze(self): + self.model = self.model.eval() + for param in self.model.parameters(): + param.requires_grad = False + + @autocast + def forward(self, image, no_dropout=False): + z = self.encode_with_vision_transformer(image) + if self.ucg_rate > 0. and not no_dropout: + z = torch.bernoulli((1. - self.ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None] * z + return z + + def encode_with_vision_transformer(self, img): + img = self.preprocess(img) + x = self.model.visual(img) + return x + + def encode(self, text): + return self(text) + +class FrozenOpenCLIPImageEmbedderV2(AbstractEncoder): + """ + Uses the OpenCLIP vision transformer encoder for images + """ + + def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", + freeze=True, layer="pooled", antialias=True): + super().__init__() + model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), + pretrained=version, ) + del model.transformer + self.model = model + self.device = device + + if freeze: + self.freeze() + self.layer = layer + if self.layer == "penultimate": + raise NotImplementedError() + self.layer_idx = 1 + + self.antialias = antialias + + self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) + self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) + + + def preprocess(self, x): + # normalize to [0,1] + x = kornia.geometry.resize(x, (224, 224), + interpolation='bicubic', align_corners=True, + antialias=self.antialias) + x = (x + 1.) / 2. + # renormalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def freeze(self): + self.model = self.model.eval() + for param in self.model.parameters(): + param.requires_grad = False + + def forward(self, image, no_dropout=False): + ## image: b c h w + z = self.encode_with_vision_transformer(image) + return z + + def encode_with_vision_transformer(self, x): + x = self.preprocess(x) + + # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1 + if self.model.visual.input_patchnorm: + # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)') + x = x.reshape(x.shape[0], x.shape[1], self.model.visual.grid_size[0], self.model.visual.patch_size[0], self.model.visual.grid_size[1], self.model.visual.patch_size[1]) + x = x.permute(0, 2, 4, 1, 3, 5) + x = x.reshape(x.shape[0], self.model.visual.grid_size[0] * self.model.visual.grid_size[1], -1) + x = self.model.visual.patchnorm_pre_ln(x) + x = self.model.visual.conv1(x) + else: + x = self.model.visual.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + # class embeddings and positional embeddings + x = torch.cat( + [self.model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), + x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.model.visual.positional_embedding.to(x.dtype) + + # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in + x = self.model.visual.patch_dropout(x) + x = self.model.visual.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.model.visual.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + return x + +class FrozenCLIPT5Encoder(AbstractEncoder): + def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", + clip_max_length=77, t5_max_length=77): + super().__init__() + self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length) + self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) + print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, " + f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params.") + + def encode(self, text): + return self(text) + + def forward(self, text): + clip_z = self.clip_encoder.encode(text) + t5_z = self.t5_encoder.encode(text) + return [clip_z, t5_z] diff --git a/src/videogen_hub/pipelines/dynamicrafter/lvdm/modules/encoders/resampler.py b/src/videogen_hub/pipelines/dynamicrafter/lvdm/modules/encoders/resampler.py new file mode 100644 index 0000000000000000000000000000000000000000..0c30c58a9a4530f82bf245355fde564553cc7893 --- /dev/null +++ b/src/videogen_hub/pipelines/dynamicrafter/lvdm/modules/encoders/resampler.py @@ -0,0 +1,145 @@ +# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py +# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py +# and https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/resampler.py +import math +import torch +import torch.nn as nn + + +class ImageProjModel(nn.Module): + """Projection Model""" + def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): + super().__init__() + self.cross_attention_dim = cross_attention_dim + self.clip_extra_context_tokens = clip_extra_context_tokens + self.proj = nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) + self.norm = nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds): + #embeds = image_embeds + embeds = image_embeds.type(list(self.proj.parameters())[0].dtype) + clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim) + clip_extra_context_tokens = self.norm(clip_extra_context_tokens) + return clip_extra_context_tokens + + +# FFN +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +def reshape_tensor(x, heads): + bs, length, width = x.shape + #(bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, length, heads, -1) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) + x = x.reshape(bs, heads, length, -1) + return x + + +class PerceiverAttention(nn.Module): + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head**-0.5 + self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head * heads + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D) + """ + x = self.norm1(x) + latents = self.norm2(latents) + + b, l, _ = latents.shape + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + q = reshape_tensor(q, self.heads) + k = reshape_tensor(k, self.heads) + v = reshape_tensor(v, self.heads) + + # attention + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + out = weight @ v + + out = out.permute(0, 2, 1, 3).reshape(b, l, -1) + + return self.to_out(out) + + +class Resampler(nn.Module): + def __init__( + self, + dim=1024, + depth=8, + dim_head=64, + heads=16, + num_queries=8, + embedding_dim=768, + output_dim=1024, + ff_mult=4, + video_length=None, # using frame-wise version or not + ): + super().__init__() + ## queries for a single frame / image + self.num_queries = num_queries + self.video_length = video_length + + ## queries for each frame + if video_length is not None: + num_queries = num_queries * video_length + + self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) + self.proj_in = nn.Linear(embedding_dim, dim) + self.proj_out = nn.Linear(dim, output_dim) + self.norm_out = nn.LayerNorm(output_dim) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + def forward(self, x): + latents = self.latents.repeat(x.size(0), 1, 1) ## B (T L) C + x = self.proj_in(x) + + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + + latents = self.proj_out(latents) + latents = self.norm_out(latents) # B L C or B (T L) C + + return latents \ No newline at end of file diff --git a/src/videogen_hub/pipelines/dynamicrafter/lvdm/modules/networks/__init__.py b/src/videogen_hub/pipelines/dynamicrafter/lvdm/modules/networks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/dynamicrafter/lvdm/modules/networks/ae_modules.py b/src/videogen_hub/pipelines/dynamicrafter/lvdm/modules/networks/ae_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..6cdb28e92b908989d4fcdf5cec8209f81d6a245b --- /dev/null +++ b/src/videogen_hub/pipelines/dynamicrafter/lvdm/modules/networks/ae_modules.py @@ -0,0 +1,844 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import numpy as np +import torch.nn as nn +from einops import rearrange +from videogen_hub.pipelines.dynamicrafter.utils import instantiate_from_config +from videogen_hub.pipelines.dynamicrafter.lvdm.modules.attention import LinearAttention + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) # bcl + q = q.permute(0,2,1) # bcl -> blc l=hw + k = k.reshape(b,c,h*w) # bcl + + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' + #print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + self.in_channels = in_channels + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + self.in_channels = in_channels + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + +class Model(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x, t=None, context=None): + #assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", + **ignore_kwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # timestep embedding + temb = None + + # print(f'encoder-input={x.shape}') + # downsampling + hs = [self.conv_in(x)] + # print(f'encoder-conv in feat={hs[0].shape}') + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + # print(f'encoder-down feat={h.shape}') + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + # print(f'encoder-downsample (input)={hs[-1].shape}') + hs.append(self.down[i_level].downsample(hs[-1])) + # print(f'encoder-downsample (output)={hs[-1].shape}') + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + # print(f'encoder-mid1 feat={h.shape}') + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + # print(f'encoder-mid2 feat={h.shape}') + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + # print(f'end feat={h.shape}') + return h + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + attn_type="vanilla", **ignorekwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + print("AE working on z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # print(f'decoder-input={z.shape}') + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + # print(f'decoder-conv in feat={h.shape}') + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + # print(f'decoder-mid feat={h.shape}') + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + # print(f'decoder-up feat={h.shape}') + if i_level != 0: + h = self.up[i_level].upsample(h) + # print(f'decoder-upsample feat={h.shape}') + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + # print(f'decoder-conv_out feat={h.shape}') + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + nn.Conv2d(2*in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True)]) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1,2,3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, + ch_mult=(2,2), dropout=0.0): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d(in_channels, + mid_channels, + kernel_size=3, + stride=1, + padding=1) + self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + + self.conv_out = nn.Conv2d(mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor)))) + x = self.attn(x) + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, + ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult, + z_channels=intermediate_chn, double_z=False, resolution=resolution, + attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, + out_ch=None) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn, + mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8), + dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + tmp_chn = z_channels*ch_mult[-1] + self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout, + resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks, + ch_mult=ch_mult, resolution=resolution, ch=ch) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn, + out_channels=tmp_chn, depth=rescale_module_depth) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size//in_size))+1 + factor_up = 1.+ (out_size % in_size) + print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}") + self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels, + out_channels=in_channels) + self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2, + attn_resolutions=[], in_channels=None, ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)]) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=4, + stride=2, + padding=1) + + def forward(self, x, scale_factor=1.0): + if scale_factor==1.0: + return x + else: + x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor) + return x + +class FirstStagePostProcessor(nn.Module): + + def __init__(self, ch_mult:list, in_channels, + pretrained_model:nn.Module=None, + reshape=False, + n_channels=None, + dropout=0., + pretrained_config=None): + super().__init__() + if pretrained_config is None: + assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + self.pretrained_model = pretrained_model + else: + assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + self.instantiate_pretrained(pretrained_config) + + self.do_reshape = reshape + + if n_channels is None: + n_channels = self.pretrained_model.encoder.ch + + self.proj_norm = Normalize(in_channels,num_groups=in_channels//2) + self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3, + stride=1,padding=1) + + blocks = [] + downs = [] + ch_in = n_channels + for m in ch_mult: + blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout)) + ch_in = m * n_channels + downs.append(Downsample(ch_in, with_conv=False)) + + self.model = nn.ModuleList(blocks) + self.downsampler = nn.ModuleList(downs) + + + def instantiate_pretrained(self, config): + model = instantiate_from_config(config) + self.pretrained_model = model.eval() + # self.pretrained_model.train = False + for param in self.pretrained_model.parameters(): + param.requires_grad = False + + + @torch.no_grad() + def encode_with_pretrained(self,x): + c = self.pretrained_model.encode(x) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + return c + + def forward(self,x): + z_fs = self.encode_with_pretrained(x) + z = self.proj_norm(z_fs) + z = self.proj(z) + z = nonlinearity(z) + + for submodel, downmodel in zip(self.model,self.downsampler): + z = submodel(z,temb=None) + z = downmodel(z) + + if self.do_reshape: + z = rearrange(z,'b c h w -> b (h w) c') + return z \ No newline at end of file diff --git a/src/videogen_hub/pipelines/dynamicrafter/lvdm/modules/networks/openaimodel3d.py b/src/videogen_hub/pipelines/dynamicrafter/lvdm/modules/networks/openaimodel3d.py new file mode 100644 index 0000000000000000000000000000000000000000..49245da8ff896d938cf13c6cf6cb23548383c6dc --- /dev/null +++ b/src/videogen_hub/pipelines/dynamicrafter/lvdm/modules/networks/openaimodel3d.py @@ -0,0 +1,603 @@ +from functools import partial +from abc import abstractmethod +import torch +import torch.nn as nn +from einops import rearrange +import torch.nn.functional as F +from lvdm.models.utils_diffusion import timestep_embedding +from lvdm.common import checkpoint +from lvdm.basics import ( + zero_module, + conv_nd, + linear, + avg_pool_nd, + normalization +) +from lvdm.modules.attention import SpatialTransformer, TemporalTransformer + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None, batch_size=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb, batch_size=batch_size) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + elif isinstance(layer, TemporalTransformer): + x = rearrange(x, '(b f) c h w -> b c f h w', b=batch_size) + x = layer(x, context) + x = rearrange(x, 'b c f h w -> (b f) c h w') + else: + x = layer(x) + return x + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode='nearest') + else: + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.use_conv: + x = self.conv(x) + return x + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + :param use_temporal_conv: if True, use the temporal convolution. + :param use_image_dataset: if True, the temporal parameters will not be optimized. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + use_conv=False, + up=False, + down=False, + use_temporal_conv=False, + tempspatial_aware=False + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + self.use_temporal_conv = use_temporal_conv + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + nn.Linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module(nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + if self.use_temporal_conv: + self.temopral_conv = TemporalConvBlock( + self.out_channels, + self.out_channels, + dropout=0.1, + spatial_aware=tempspatial_aware + ) + + def forward(self, x, emb, batch_size=None): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + input_tuple = (x, emb) + if batch_size: + forward_batchsize = partial(self._forward, batch_size=batch_size) + return checkpoint(forward_batchsize, input_tuple, self.parameters(), self.use_checkpoint) + return checkpoint(self._forward, input_tuple, self.parameters(), self.use_checkpoint) + + def _forward(self, x, emb, batch_size=None): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = torch.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + h = self.skip_connection(x) + h + + if self.use_temporal_conv and batch_size: + h = rearrange(h, '(b t) c h w -> b c t h w', b=batch_size) + h = self.temopral_conv(h) + h = rearrange(h, 'b c t h w -> (b t) c h w') + return h + + +class TemporalConvBlock(nn.Module): + """ + Adapted from modelscope: https://github.com/modelscope/modelscope/blob/master/modelscope/models/multi_modal/video_synthesis/unet_sd.py + """ + def __init__(self, in_channels, out_channels=None, dropout=0.0, spatial_aware=False): + super(TemporalConvBlock, self).__init__() + if out_channels is None: + out_channels = in_channels + self.in_channels = in_channels + self.out_channels = out_channels + th_kernel_shape = (3, 1, 1) if not spatial_aware else (3, 3, 1) + th_padding_shape = (1, 0, 0) if not spatial_aware else (1, 1, 0) + tw_kernel_shape = (3, 1, 1) if not spatial_aware else (3, 1, 3) + tw_padding_shape = (1, 0, 0) if not spatial_aware else (1, 0, 1) + + # conv layers + self.conv1 = nn.Sequential( + nn.GroupNorm(32, in_channels), nn.SiLU(), + nn.Conv3d(in_channels, out_channels, th_kernel_shape, padding=th_padding_shape)) + self.conv2 = nn.Sequential( + nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout), + nn.Conv3d(out_channels, in_channels, tw_kernel_shape, padding=tw_padding_shape)) + self.conv3 = nn.Sequential( + nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout), + nn.Conv3d(out_channels, in_channels, th_kernel_shape, padding=th_padding_shape)) + self.conv4 = nn.Sequential( + nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout), + nn.Conv3d(out_channels, in_channels, tw_kernel_shape, padding=tw_padding_shape)) + + # zero out the last layer params,so the conv block is identity + nn.init.zeros_(self.conv4[-1].weight) + nn.init.zeros_(self.conv4[-1].bias) + + def forward(self, x): + identity = x + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + + return identity + x + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: in_channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__(self, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0.0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + context_dim=None, + use_scale_shift_norm=False, + resblock_updown=False, + num_heads=-1, + num_head_channels=-1, + transformer_depth=1, + use_linear=False, + use_checkpoint=False, + temporal_conv=False, + tempspatial_aware=False, + temporal_attention=True, + use_relative_position=True, + use_causal_attention=False, + temporal_length=None, + use_fp16=False, + addition_attention=False, + temporal_selfatt_only=True, + image_cross_attention=False, + image_cross_attention_scale_learnable=False, + default_fs=4, + fs_condition=False, + ): + super(UNetModel, self).__init__() + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.temporal_attention = temporal_attention + time_embed_dim = model_channels * 4 + self.use_checkpoint = use_checkpoint + self.dtype = torch.float16 if use_fp16 else torch.float32 + temporal_self_att_only = True + self.addition_attention = addition_attention + self.temporal_length = temporal_length + self.image_cross_attention = image_cross_attention + self.image_cross_attention_scale_learnable = image_cross_attention_scale_learnable + self.default_fs = default_fs + self.fs_condition = fs_condition + + ## Time embedding blocks + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + if fs_condition: + self.fps_embedding = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + nn.init.zeros_(self.fps_embedding[-1].weight) + nn.init.zeros_(self.fps_embedding[-1].bias) + ## Input Block + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1)) + ] + ) + if self.addition_attention: + self.init_attn=TimestepEmbedSequential( + TemporalTransformer( + model_channels, + n_heads=8, + d_head=num_head_channels, + depth=transformer_depth, + context_dim=context_dim, + use_checkpoint=use_checkpoint, only_self_att=temporal_selfatt_only, + causal_attention=False, relative_position=use_relative_position, + temporal_length=temporal_length)) + + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock(ch, time_embed_dim, dropout, + out_channels=mult * model_channels, dims=dims, use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware, + use_temporal_conv=temporal_conv + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + layers.append( + SpatialTransformer(ch, num_heads, dim_head, + depth=transformer_depth, context_dim=context_dim, use_linear=use_linear, + use_checkpoint=use_checkpoint, disable_self_attn=False, + video_length=temporal_length, image_cross_attention=self.image_cross_attention, + image_cross_attention_scale_learnable=self.image_cross_attention_scale_learnable, + ) + ) + if self.temporal_attention: + layers.append( + TemporalTransformer(ch, num_heads, dim_head, + depth=transformer_depth, context_dim=context_dim, use_linear=use_linear, + use_checkpoint=use_checkpoint, only_self_att=temporal_self_att_only, + causal_attention=use_causal_attention, relative_position=use_relative_position, + temporal_length=temporal_length + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock(ch, time_embed_dim, dropout, + out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True + ) + if resblock_updown + else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + layers = [ + ResBlock(ch, time_embed_dim, dropout, + dims=dims, use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware, + use_temporal_conv=temporal_conv + ), + SpatialTransformer(ch, num_heads, dim_head, + depth=transformer_depth, context_dim=context_dim, use_linear=use_linear, + use_checkpoint=use_checkpoint, disable_self_attn=False, video_length=temporal_length, + image_cross_attention=self.image_cross_attention,image_cross_attention_scale_learnable=self.image_cross_attention_scale_learnable + ) + ] + if self.temporal_attention: + layers.append( + TemporalTransformer(ch, num_heads, dim_head, + depth=transformer_depth, context_dim=context_dim, use_linear=use_linear, + use_checkpoint=use_checkpoint, only_self_att=temporal_self_att_only, + causal_attention=use_causal_attention, relative_position=use_relative_position, + temporal_length=temporal_length + ) + ) + layers.append( + ResBlock(ch, time_embed_dim, dropout, + dims=dims, use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware, + use_temporal_conv=temporal_conv + ) + ) + + ## Middle Block + self.middle_block = TimestepEmbedSequential(*layers) + + ## Output Block + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock(ch + ich, time_embed_dim, dropout, + out_channels=mult * model_channels, dims=dims, use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware, + use_temporal_conv=temporal_conv + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + layers.append( + SpatialTransformer(ch, num_heads, dim_head, + depth=transformer_depth, context_dim=context_dim, use_linear=use_linear, + use_checkpoint=use_checkpoint, disable_self_attn=False, video_length=temporal_length, + image_cross_attention=self.image_cross_attention,image_cross_attention_scale_learnable=self.image_cross_attention_scale_learnable + ) + ) + if self.temporal_attention: + layers.append( + TemporalTransformer(ch, num_heads, dim_head, + depth=transformer_depth, context_dim=context_dim, use_linear=use_linear, + use_checkpoint=use_checkpoint, only_self_att=temporal_self_att_only, + causal_attention=use_causal_attention, relative_position=use_relative_position, + temporal_length=temporal_length + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock(ch, time_embed_dim, dropout, + out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + + def forward(self, x, timesteps, context=None, features_adapter=None, fs=None, **kwargs): + b,_,t,_,_ = x.shape + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).type(x.dtype) + emb = self.time_embed(t_emb) + + ## repeat t times for context [(b t) 77 768] & time embedding + ## check if we use per-frame image conditioning + _, l_context, _ = context.shape + if l_context == 77 + t*16: ## !!! HARD CODE here + context_text, context_img = context[:,:77,:], context[:,77:,:] + context_text = context_text.repeat_interleave(repeats=t, dim=0) + context_img = rearrange(context_img, 'b (t l) c -> (b t) l c', t=t) + context = torch.cat([context_text, context_img], dim=1) + else: + context = context.repeat_interleave(repeats=t, dim=0) + emb = emb.repeat_interleave(repeats=t, dim=0) + + ## always in shape (b t) c h w, except for temporal layer + x = rearrange(x, 'b c t h w -> (b t) c h w') + + ## combine emb + if self.fs_condition: + if fs is None: + fs = torch.tensor( + [self.default_fs] * b, dtype=torch.long, device=x.device) + fs_emb = timestep_embedding(fs, self.model_channels, repeat_only=False).type(x.dtype) + + fs_embed = self.fps_embedding(fs_emb) + fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0) + emb = emb + fs_embed + + h = x.type(self.dtype) + adapter_idx = 0 + hs = [] + for id, module in enumerate(self.input_blocks): + h = module(h, emb, context=context, batch_size=b) + if id ==0 and self.addition_attention: + h = self.init_attn(h, emb, context=context, batch_size=b) + ## plug-in adapter features + if ((id+1)%3 == 0) and features_adapter is not None: + h = h + features_adapter[adapter_idx] + adapter_idx += 1 + hs.append(h) + if features_adapter is not None: + assert len(features_adapter)==adapter_idx, 'Wrong features_adapter' + + h = self.middle_block(h, emb, context=context, batch_size=b) + for module in self.output_blocks: + h = torch.cat([h, hs.pop()], dim=1) + h = module(h, emb, context=context, batch_size=b) + h = h.type(x.dtype) + y = self.out(h) + + # reshape back to (b c t h w) + y = rearrange(y, '(b t) c h w -> b c t h w', b=b) + return y \ No newline at end of file diff --git a/src/videogen_hub/pipelines/dynamicrafter/lvdm/modules/x_transformer.py b/src/videogen_hub/pipelines/dynamicrafter/lvdm/modules/x_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..5321012f860a8fb06850c1ddf495db934addecae --- /dev/null +++ b/src/videogen_hub/pipelines/dynamicrafter/lvdm/modules/x_transformer.py @@ -0,0 +1,639 @@ +"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers""" +from functools import partial +from inspect import isfunction +from collections import namedtuple +from einops import rearrange, repeat +import torch +from torch import nn, einsum +import torch.nn.functional as F + +# constants +DEFAULT_DIM_HEAD = 64 + +Intermediates = namedtuple('Intermediates', [ + 'pre_softmax_attn', + 'post_softmax_attn' +]) + +LayerIntermediates = namedtuple('Intermediates', [ + 'hiddens', + 'attn_intermediates' +]) + + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + super().__init__() + self.emb = nn.Embedding(max_seq_len, dim) + self.init_() + + def init_(self): + nn.init.normal_(self.emb.weight, std=0.02) + + def forward(self, x): + n = torch.arange(x.shape[1], device=x.device) + return self.emb(n)[None, :, :] + + +class FixedPositionalEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + + def forward(self, x, seq_dim=1, offset=0): + t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset + sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) + emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) + return emb[None, :, :] + + +# helpers + +def exists(val): + return val is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def always(val): + def inner(*args, **kwargs): + return val + return inner + + +def not_equals(val): + def inner(x): + return x != val + return inner + + +def equals(val): + def inner(x): + return x == val + return inner + + +def max_neg_value(tensor): + return -torch.finfo(tensor.dtype).max + + +# keyword argument helpers + +def pick_and_pop(keys, d): + values = list(map(lambda key: d.pop(key), keys)) + return dict(zip(keys, values)) + + +def group_dict_by_key(cond, d): + return_val = [dict(), dict()] + for key in d.keys(): + match = bool(cond(key)) + ind = int(not match) + return_val[ind][key] = d[key] + return (*return_val,) + + +def string_begins_with(prefix, str): + return str.startswith(prefix) + + +def group_by_key_prefix(prefix, d): + return group_dict_by_key(partial(string_begins_with, prefix), d) + + +def groupby_prefix_and_trim(prefix, d): + kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) + kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) + return kwargs_without_prefix, kwargs + + +# classes +class Scale(nn.Module): + def __init__(self, value, fn): + super().__init__() + self.value = value + self.fn = fn + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.value, *rest) + + +class Rezero(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + self.g = nn.Parameter(torch.zeros(1)) + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.g, *rest) + + +class ScaleNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(1)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-8): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class Residual(nn.Module): + def forward(self, x, residual): + return x + residual + + +class GRUGating(nn.Module): + def __init__(self, dim): + super().__init__() + self.gru = nn.GRUCell(dim, dim) + + def forward(self, x, residual): + gated_output = self.gru( + rearrange(x, 'b n d -> (b n) d'), + rearrange(residual, 'b n d -> (b n) d') + ) + + return gated_output.reshape_as(x) + + +# feedforward + +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +# attention. +class Attention(nn.Module): + def __init__( + self, + dim, + dim_head=DEFAULT_DIM_HEAD, + heads=8, + causal=False, + mask=None, + talking_heads=False, + sparse_topk=None, + use_entmax15=False, + num_mem_kv=0, + dropout=0., + on_attn=False + ): + super().__init__() + if use_entmax15: + raise NotImplementedError("Check out entmax activation instead of softmax activation!") + self.scale = dim_head ** -0.5 + self.heads = heads + self.causal = causal + self.mask = mask + + inner_dim = dim_head * heads + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_k = nn.Linear(dim, inner_dim, bias=False) + self.to_v = nn.Linear(dim, inner_dim, bias=False) + self.dropout = nn.Dropout(dropout) + + # talking heads + self.talking_heads = talking_heads + if talking_heads: + self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + + # explicit topk sparse attention + self.sparse_topk = sparse_topk + + # entmax + #self.attn_fn = entmax15 if use_entmax15 else F.softmax + self.attn_fn = F.softmax + + # add memory key / values + self.num_mem_kv = num_mem_kv + if num_mem_kv > 0: + self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + + # attention on attention + self.attn_on_attn = on_attn + self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + rel_pos=None, + sinusoidal_emb=None, + prev_attn=None, + mem=None + ): + b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device + kv_input = default(context, x) + + q_input = x + k_input = kv_input + v_input = kv_input + + if exists(mem): + k_input = torch.cat((mem, k_input), dim=-2) + v_input = torch.cat((mem, v_input), dim=-2) + + if exists(sinusoidal_emb): + # in shortformer, the query would start at a position offset depending on the past cached memory + offset = k_input.shape[-2] - q_input.shape[-2] + q_input = q_input + sinusoidal_emb(q_input, offset=offset) + k_input = k_input + sinusoidal_emb(k_input) + + q = self.to_q(q_input) + k = self.to_k(k_input) + v = self.to_v(v_input) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) + + input_mask = None + if any(map(exists, (mask, context_mask))): + q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) + k_mask = q_mask if not exists(context) else context_mask + k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) + q_mask = rearrange(q_mask, 'b i -> b () i ()') + k_mask = rearrange(k_mask, 'b j -> b () () j') + input_mask = q_mask * k_mask + + if self.num_mem_kv > 0: + mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) + k = torch.cat((mem_k, k), dim=-2) + v = torch.cat((mem_v, v), dim=-2) + if exists(input_mask): + input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) + + dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + mask_value = max_neg_value(dots) + + if exists(prev_attn): + dots = dots + prev_attn + + pre_softmax_attn = dots + + if talking_heads: + dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous() + + if exists(rel_pos): + dots = rel_pos(dots) + + if exists(input_mask): + dots.masked_fill_(~input_mask, mask_value) + del input_mask + + if self.causal: + i, j = dots.shape[-2:] + r = torch.arange(i, device=device) + mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j') + mask = F.pad(mask, (j - i, 0), value=False) + dots.masked_fill_(mask, mask_value) + del mask + + if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: + top, _ = dots.topk(self.sparse_topk, dim=-1) + vk = top[..., -1].unsqueeze(-1).expand_as(dots) + mask = dots < vk + dots.masked_fill_(mask, mask_value) + del mask + + attn = self.attn_fn(dots, dim=-1) + post_softmax_attn = attn + + attn = self.dropout(attn) + + if talking_heads: + attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous() + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + + intermediates = Intermediates( + pre_softmax_attn=pre_softmax_attn, + post_softmax_attn=post_softmax_attn + ) + + return self.to_out(out), intermediates + + +class AttentionLayers(nn.Module): + def __init__( + self, + dim, + depth, + heads=8, + causal=False, + cross_attend=False, + only_cross=False, + use_scalenorm=False, + use_rmsnorm=False, + use_rezero=False, + rel_pos_num_buckets=32, + rel_pos_max_distance=128, + position_infused_attn=False, + custom_layers=None, + sandwich_coef=None, + par_ratio=None, + residual_attn=False, + cross_residual_attn=False, + macaron=False, + pre_norm=True, + gate_residual=False, + **kwargs + ): + super().__init__() + ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) + attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs) + + dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD) + + self.dim = dim + self.depth = depth + self.layers = nn.ModuleList([]) + + self.has_pos_emb = position_infused_attn + self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None + self.rotary_pos_emb = always(None) + + assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' + self.rel_pos = None + + self.pre_norm = pre_norm + + self.residual_attn = residual_attn + self.cross_residual_attn = cross_residual_attn + + norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm + norm_class = RMSNorm if use_rmsnorm else norm_class + norm_fn = partial(norm_class, dim) + + norm_fn = nn.Identity if use_rezero else norm_fn + branch_fn = Rezero if use_rezero else None + + if cross_attend and not only_cross: + default_block = ('a', 'c', 'f') + elif cross_attend and only_cross: + default_block = ('c', 'f') + else: + default_block = ('a', 'f') + + if macaron: + default_block = ('f',) + default_block + + if exists(custom_layers): + layer_types = custom_layers + elif exists(par_ratio): + par_depth = depth * len(default_block) + assert 1 < par_ratio <= par_depth, 'par ratio out of range' + default_block = tuple(filter(not_equals('f'), default_block)) + par_attn = par_depth // par_ratio + depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper + par_width = (depth_cut + depth_cut // par_attn) // par_attn + assert len(default_block) <= par_width, 'default block is too large for par_ratio' + par_block = default_block + ('f',) * (par_width - len(default_block)) + par_head = par_block * par_attn + layer_types = par_head + ('f',) * (par_depth - len(par_head)) + elif exists(sandwich_coef): + assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' + layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef + else: + layer_types = default_block * depth + + self.layer_types = layer_types + self.num_attn_layers = len(list(filter(equals('a'), layer_types))) + + for layer_type in self.layer_types: + if layer_type == 'a': + layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) + elif layer_type == 'c': + layer = Attention(dim, heads=heads, **attn_kwargs) + elif layer_type == 'f': + layer = FeedForward(dim, **ff_kwargs) + layer = layer if not macaron else Scale(0.5, layer) + else: + raise Exception(f'invalid layer type {layer_type}') + + if isinstance(layer, Attention) and exists(branch_fn): + layer = branch_fn(layer) + + if gate_residual: + residual_fn = GRUGating(dim) + else: + residual_fn = Residual() + + self.layers.append(nn.ModuleList([ + norm_fn(), + layer, + residual_fn + ])) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + mems=None, + return_hiddens=False + ): + hiddens = [] + intermediates = [] + prev_attn = None + prev_cross_attn = None + + mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers + + for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): + is_last = ind == (len(self.layers) - 1) + + if layer_type == 'a': + hiddens.append(x) + layer_mem = mems.pop(0) + + residual = x + + if self.pre_norm: + x = norm(x) + + if layer_type == 'a': + out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos, + prev_attn=prev_attn, mem=layer_mem) + elif layer_type == 'c': + out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn) + elif layer_type == 'f': + out = block(x) + + x = residual_fn(out, residual) + + if layer_type in ('a', 'c'): + intermediates.append(inter) + + if layer_type == 'a' and self.residual_attn: + prev_attn = inter.pre_softmax_attn + elif layer_type == 'c' and self.cross_residual_attn: + prev_cross_attn = inter.pre_softmax_attn + + if not self.pre_norm and not is_last: + x = norm(x) + + if return_hiddens: + intermediates = LayerIntermediates( + hiddens=hiddens, + attn_intermediates=intermediates + ) + + return x, intermediates + + return x + + +class Encoder(AttentionLayers): + def __init__(self, **kwargs): + assert 'causal' not in kwargs, 'cannot set causality on encoder' + super().__init__(causal=False, **kwargs) + + + +class TransformerWrapper(nn.Module): + def __init__( + self, + *, + num_tokens, + max_seq_len, + attn_layers, + emb_dim=None, + max_mem_len=0., + emb_dropout=0., + num_memory_tokens=None, + tie_embedding=False, + use_pos_emb=True + ): + super().__init__() + assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' + + dim = attn_layers.dim + emb_dim = default(emb_dim, dim) + + self.max_seq_len = max_seq_len + self.max_mem_len = max_mem_len + self.num_tokens = num_tokens + + self.token_emb = nn.Embedding(num_tokens, emb_dim) + self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( + use_pos_emb and not attn_layers.has_pos_emb) else always(0) + self.emb_dropout = nn.Dropout(emb_dropout) + + self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() + self.attn_layers = attn_layers + self.norm = nn.LayerNorm(dim) + + self.init_() + + self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() + + # memory tokens (like [cls]) from Memory Transformers paper + num_memory_tokens = default(num_memory_tokens, 0) + self.num_memory_tokens = num_memory_tokens + if num_memory_tokens > 0: + self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) + + # let funnel encoder know number of memory tokens, if specified + if hasattr(attn_layers, 'num_memory_tokens'): + attn_layers.num_memory_tokens = num_memory_tokens + + def init_(self): + nn.init.normal_(self.token_emb.weight, std=0.02) + + def forward( + self, + x, + return_embeddings=False, + mask=None, + return_mems=False, + return_attn=False, + mems=None, + **kwargs + ): + b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens + x = self.token_emb(x) + x += self.pos_emb(x) + x = self.emb_dropout(x) + + x = self.project_emb(x) + + if num_mem > 0: + mem = repeat(self.memory_tokens, 'n d -> b n d', b=b) + x = torch.cat((mem, x), dim=1) + + # auto-handle masking after appending memory tokens + if exists(mask): + mask = F.pad(mask, (num_mem, 0), value=True) + + x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) + x = self.norm(x) + + mem, x = x[:, :num_mem], x[:, num_mem:] + + out = self.to_logits(x) if not return_embeddings else x + + if return_mems: + hiddens = intermediates.hiddens + new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens + new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems)) + return out, new_mems + + if return_attn: + attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) + return out, attn_maps + + return out \ No newline at end of file diff --git a/src/videogen_hub/pipelines/dynamicrafter/utils.py b/src/videogen_hub/pipelines/dynamicrafter/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c73b93e006c4250161b427e4d1fff512ca046f7c --- /dev/null +++ b/src/videogen_hub/pipelines/dynamicrafter/utils.py @@ -0,0 +1,77 @@ +import importlib +import numpy as np +import cv2 +import torch +import torch.distributed as dist + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") + return total_params + + +def check_istarget(name, para_list): + """ + name: full name of source para + para_list: partial name of target para + """ + istarget=False + for para in para_list: + if para in name: + return True + return istarget + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def load_npz_from_dir(data_dir): + data = [np.load(os.path.join(data_dir, data_name))['arr_0'] for data_name in os.listdir(data_dir)] + data = np.concatenate(data, axis=0) + return data + + +def load_npz_from_paths(data_paths): + data = [np.load(data_path)['arr_0'] for data_path in data_paths] + data = np.concatenate(data, axis=0) + return data + + +def resize_numpy_image(image, max_resolution=512 * 512, resize_short_edge=None): + h, w = image.shape[:2] + if resize_short_edge is not None: + k = resize_short_edge / min(h, w) + else: + k = max_resolution / (h * w) + k = k**0.5 + h = int(np.round(h * k / 64)) * 64 + w = int(np.round(w * k / 64)) * 64 + image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4) + return image + + +def setup_dist(args): + if dist.is_initialized(): + return + torch.cuda.set_device(args.local_rank) + torch.distributed.init_process_group( + 'nccl', + init_method='env://' + ) \ No newline at end of file diff --git a/src/videogen_hub/pipelines/lavie/__init__.py b/src/videogen_hub/pipelines/lavie/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/__init__.py @@ -0,0 +1 @@ + diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/__init__.py b/src/videogen_hub/pipelines/lavie/lavie_src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/__init__.py @@ -0,0 +1 @@ + diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/base/__init__.py b/src/videogen_hub/pipelines/lavie/lavie_src/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/base/__init__.py @@ -0,0 +1 @@ + diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/base/configs/__init__.py b/src/videogen_hub/pipelines/lavie/lavie_src/base/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/base/configs/sample.yaml b/src/videogen_hub/pipelines/lavie/lavie_src/base/configs/sample.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b8ca2403c99e432d6f8a05e945bc144cd0fbb6d8 --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/base/configs/sample.yaml @@ -0,0 +1,33 @@ +# path: +ckpt_path: "../pretrained_models/lavie_base.pt" +output_folder: "../res/base/" +pretrained_path: "../pretrained_models" + +# model config: +model: UNet +video_length: 16 +image_size: [320, 512] + +# beta schedule +beta_start: 0.0001 +beta_end: 0.02 +beta_schedule: "linear" + +# model speedup +use_compile: False +use_fp16: True + +# sample config: +seed: #400 +run_time: 0 +guidance_scale: 7.5 +sample_method: 'ddpm' +num_sampling_steps: 50 +text_prompt: [ + 'a teddy bear walking on the street, 2k, high quality', + 'a panda taking a selfie, 2k, high quality', + 'a polar bear playing drum kit in NYC Times Square, 4k, high resolution', + 'jungle river at sunset, ultra quality', + 'a shark swimming in clear Carribean ocean, 2k, high quality', + 'a Corgi walking in the park at sunrise, oil painting style' + ] diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/base/download.py b/src/videogen_hub/pipelines/lavie/lavie_src/base/download.py new file mode 100644 index 0000000000000000000000000000000000000000..26de159d098d2086a35f1504477eb5c01a35f540 --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/base/download.py @@ -0,0 +1,18 @@ +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import os + + +def find_model(model_name): + """ + Finds a pre-trained model, downloading it if necessary. Alternatively, loads a model from a local path. + """ + checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage) + if "ema" in checkpoint: # supports checkpoints from train.py + print('Ema existing!') + checkpoint = checkpoint["ema"] + return checkpoint diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/base/models/__init__.py b/src/videogen_hub/pipelines/lavie/lavie_src/base/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f93daa1bf5131bd1251a44e99aeeb0d127ea71f1 --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/base/models/__init__.py @@ -0,0 +1,33 @@ +import os +import sys +sys.path.append(os.path.split(sys.path[0])[0]) + +from .unet import UNet3DConditionModel +from torch.optim.lr_scheduler import LambdaLR + +def customized_lr_scheduler(optimizer, warmup_steps=5000): # 5000 from u-vit + from torch.optim.lr_scheduler import LambdaLR + def fn(step): + if warmup_steps > 0: + return min(step / warmup_steps, 1) + else: + return 1 + return LambdaLR(optimizer, fn) + + +def get_lr_scheduler(optimizer, name, **kwargs): + if name == 'warmup': + return customized_lr_scheduler(optimizer, **kwargs) + elif name == 'cosine': + from torch.optim.lr_scheduler import CosineAnnealingLR + return CosineAnnealingLR(optimizer, **kwargs) + else: + raise NotImplementedError(name) + +def get_models(args, sd_path): + + if 'UNet' in args.model: + return UNet3DConditionModel.from_pretrained_2d(sd_path, subfolder="unet") + else: + raise '{} Model Not Supported!'.format(args.model) + diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/base/models/attention.py b/src/videogen_hub/pipelines/lavie/lavie_src/base/models/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..fb5b0ff4a7fb9a8d667a9b502017208293ab2609 --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/base/models/attention.py @@ -0,0 +1,707 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py +import os +import sys +sys.path.append(os.path.split(sys.path[0])[0]) + +from dataclasses import dataclass +from typing import Optional + +import math +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput +from diffusers.utils.import_utils import is_xformers_available +from diffusers.models.attention import FeedForward, AdaLayerNorm +from rotary_embedding_torch import RotaryEmbedding +from typing import Callable, Optional +from einops import rearrange, repeat + +try: + from diffusers.models.modeling_utils import ModelMixin +except: + from diffusers.modeling_utils import ModelMixin # 0.11.1 + + +@dataclass +class Transformer3DModelOutput(BaseOutput): + sample: torch.FloatTensor + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + +def exists(x): + return x is not None + + +class CrossAttention(nn.Module): + r""" + copy from diffuser 0.11.1 + A cross attention layer. + Parameters: + query_dim (`int`): The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + use_relative_position: bool = False, + ): + super().__init__() + inner_dim = dim_head * heads + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + + self.scale = dim_head**-0.5 + + self.heads = heads + self.dim_head = dim_head + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + self._slice_size = None + self._use_memory_efficient_attention_xformers = False + self.added_kv_proj_dim = added_kv_proj_dim + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) + else: + self.group_norm = None + + self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(inner_dim, query_dim)) + self.to_out.append(nn.Dropout(dropout)) + + self.use_relative_position = use_relative_position + if self.use_relative_position: + self.rotary_emb = RotaryEmbedding(min(32, dim_head)) + + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def reshape_for_scores(self, tensor): + # split heads and dims + # tensor should be [b (h w)] f (d nd) + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).contiguous() + return tensor + + def same_batch_dim_to_heads(self, tensor): + batch_size, head_size, seq_len, dim = tensor.shape # [b (h w)] nd f d + tensor = tensor.reshape(batch_size, seq_len, dim * head_size) + return tensor + + def set_attention_slice(self, slice_size): + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + self._slice_size = slice_size + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, use_image_num=None): + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) # [b (h w)] f (nd * d) + + # print('before reshpape query shape', query.shape) + dim = query.shape[-1] + if not self.use_relative_position: + query = self.reshape_heads_to_batch_dim(query) # [b (h w) nd] f d + # print('after reshape query shape', query.shape) + + if self.added_kv_proj_dim is not None: + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) + + key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) + else: + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + if not self.use_relative_position: + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + + def _attention(self, query, key, value, attention_mask=None): + if self.upcast_attention: + query = query.float() + key = key.float() + + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + + # cast back to the original dtype + attention_probs = attention_probs.to(value.dtype) + + # compute attention output + hidden_states = torch.bmm(attention_probs, value) + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + + return hidden_states + + def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask): + batch_size_attention = query.shape[0] + hidden_states = torch.zeros( + (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype + ) + slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] + for i in range(hidden_states.shape[0] // slice_size): + start_idx = i * slice_size + end_idx = (i + 1) * slice_size + + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + + if self.upcast_attention: + query_slice = query_slice.float() + key_slice = key_slice.float() + + attn_slice = torch.baddbmm( + torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device), + query_slice, + key_slice.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + if attention_mask is not None: + attn_slice = attn_slice + attention_mask[start_idx:end_idx] + + if self.upcast_softmax: + attn_slice = attn_slice.float() + + attn_slice = attn_slice.softmax(dim=-1) + + # cast back to the original dtype + attn_slice = attn_slice.to(value.dtype) + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def _memory_efficient_attention_xformers(self, query, key, value, attention_mask): + # TODO attention_mask + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + +class Transformer3DModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + use_first_frame: bool = False, + use_relative_position: bool = False, + rotary_emb: bool = None, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # Define input layers + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + + # Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + if use_linear_projection: + self.proj_out = nn.Linear(in_channels, inner_dim) + else: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, use_image_num=None, return_dict: bool = True): + # Input + assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." + + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous() + encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length).contiguous() + + batch, channel, height, weight = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = self.proj_in(hidden_states) + + # Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + video_length=video_length, + use_image_num=use_image_num, + ) + + # Output + if not self.use_linear_projection: + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + + output = hidden_states + residual + + output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length + use_image_num).contiguous() + if not return_dict: + return (output,) + + return Transformer3DModelOutput(sample=output) + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + use_first_frame: bool = False, + use_relative_position: bool = False, + rotary_emb: bool = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + self.use_ada_layer_norm = num_embeds_ada_norm is not None + self.use_first_frame = use_first_frame + + # Spatial-Attn + self.attn1 = CrossAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=None, + upcast_attention=upcast_attention, + ) + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + # Text Cross-Attn + if cross_attention_dim is not None: + self.attn2 = CrossAttention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + else: + self.attn2 = None + + if cross_attention_dim is not None: + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + else: + self.norm2 = None + + # Temp + self.attn_temp = TemporalAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=None, + upcast_attention=upcast_attention, + rotary_emb=rotary_emb, + ) + self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + nn.init.zeros_(self.attn_temp.to_out[0].weight.data) + + + # Feed-forward + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) + self.norm3 = nn.LayerNorm(dim) + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, op=None): + + if not is_xformers_available(): + print("Here is how to install it") + raise ModuleNotFoundError( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers", + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" + " available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + if self.attn2 is not None: + self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None, use_image_num=None): + # SparseCausal-Attention + norm_hidden_states = ( + self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) + ) + + if self.only_cross_attention: + hidden_states = ( + self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states + ) + else: + hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, use_image_num=use_image_num) + hidden_states + + if self.attn2 is not None: + # Cross-Attention + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + hidden_states = ( + self.attn2( + norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + ) + + hidden_states + ) + + # Temporal Attention + if self.training: + d = hidden_states.shape[1] + hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length + use_image_num).contiguous() + hidden_states_video = hidden_states[:, :video_length, :] + hidden_states_image = hidden_states[:, video_length:, :] + norm_hidden_states_video = ( + self.norm_temp(hidden_states_video, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states_video) + ) + hidden_states_video = self.attn_temp(norm_hidden_states_video) + hidden_states_video + hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1) + hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d).contiguous() + else: + d = hidden_states.shape[1] + hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length + use_image_num).contiguous() + norm_hidden_states = ( + self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states) + ) + hidden_states = self.attn_temp(norm_hidden_states) + hidden_states + hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d).contiguous() + + # Feed-forward + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + + return hidden_states + +class TemporalAttention(CrossAttention): + def __init__(self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + rotary_emb=None): + super().__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, added_kv_proj_dim, norm_num_groups) + # relative time positional embeddings + self.time_rel_pos_bias = RelativePositionBias(heads=heads, max_distance=32) # realistically will not be able to generate that many frames of video... yet + self.rotary_emb = rotary_emb + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): + time_rel_pos_bias = self.time_rel_pos_bias(hidden_states.shape[1], device=hidden_states.device) + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) # [b (h w)] f (nd * d) + dim = query.shape[-1] + + if self.added_kv_proj_dim is not None: + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) + + key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) + else: + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask, time_rel_pos_bias) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + def _attention(self, query, key, value, attention_mask=None, time_rel_pos_bias=None): + if self.upcast_attention: + query = query.float() + key = key.float() + + # reshape for adding time positional bais + query = self.scale * rearrange(query, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads + key = rearrange(key, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads + value = rearrange(value, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads + + if exists(self.rotary_emb): + query = self.rotary_emb.rotate_queries_or_keys(query) + key = self.rotary_emb.rotate_queries_or_keys(key) + + attention_scores = torch.einsum('... h i d, ... h j d -> ... h i j', query, key) + + attention_scores = attention_scores + time_rel_pos_bias + + if attention_mask is not None: + # add attention mask + attention_scores = attention_scores + attention_mask + + attention_scores = attention_scores - attention_scores.amax(dim = -1, keepdim = True).detach() + + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + # print(attention_probs[0][0]) + + # cast back to the original dtype + attention_probs = attention_probs.to(value.dtype) + + # compute attention output + hidden_states = torch.einsum('... h i j, ... h j d -> ... h i d', attention_probs, value) + hidden_states = rearrange(hidden_states, 'b h f d -> b f (h d)') + return hidden_states + +class RelativePositionBias(nn.Module): + def __init__( + self, + heads=8, + num_buckets=32, + max_distance=128, + ): + super().__init__() + self.num_buckets = num_buckets + self.max_distance = max_distance + self.relative_attention_bias = nn.Embedding(num_buckets, heads) + + @staticmethod + def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128): + ret = 0 + n = -relative_position + + num_buckets //= 2 + ret += (n < 0).long() * num_buckets + n = torch.abs(n) + + max_exact = num_buckets // 2 + is_small = n < max_exact + + val_if_large = max_exact + ( + torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) + ).long() + val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + + ret += torch.where(is_small, n, val_if_large) + return ret + + def forward(self, n, device): + q_pos = torch.arange(n, dtype = torch.long, device = device) + k_pos = torch.arange(n, dtype = torch.long, device = device) + rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1') + rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance) + values = self.relative_attention_bias(rp_bucket) + return rearrange(values, 'i j h -> h i j') # num_heads, num_frames, num_frames \ No newline at end of file diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/base/models/clip.py b/src/videogen_hub/pipelines/lavie/lavie_src/base/models/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..2879f919e9081a373582e4734cded621bab8245d --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/base/models/clip.py @@ -0,0 +1,120 @@ +import numpy +import torch.nn as nn +from transformers import CLIPTokenizer, CLIPTextModel + +import transformers +transformers.logging.set_verbosity_error() + +""" +Will encounter following warning: +- This IS expected if you are initializing CLIPTextModel from the checkpoint of a model trained on another task +or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model). +- This IS NOT expected if you are initializing CLIPTextModel from the checkpoint of a model +that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model). + +https://github.com/CompVis/stable-diffusion/issues/97 +according to this issue, this warning is safe. + +This is expected since the vision backbone of the CLIP model is not needed to run Stable Diffusion. +You can safely ignore the warning, it is not an error. + +This clip usage is from U-ViT and same with Stable Diffusion. +""" + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class FrozenCLIPEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from Hugging Face)""" + # def __init__(self, version="openai/clip-vit-huge-patch14", device="cuda", max_length=77): + def __init__(self, path, device="cuda", max_length=77): + super().__init__() + self.tokenizer = CLIPTokenizer.from_pretrained(path, subfolder="tokenizer") + self.transformer = CLIPTextModel.from_pretrained(path, subfolder='text_encoder') + self.device = device + self.max_length = max_length + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens) + + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + + +class TextEmbedder(nn.Module): + """ + Embeds text prompt into vector representations. Also handles text dropout for classifier-free guidance. + """ + def __init__(self, path, dropout_prob=0.1): + super().__init__() + self.text_encodder = FrozenCLIPEmbedder(path=path) + self.dropout_prob = dropout_prob + + def token_drop(self, text_prompts, force_drop_ids=None): + """ + Drops text to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = numpy.random.uniform(0, 1, len(text_prompts)) < self.dropout_prob + else: + # TODO + drop_ids = force_drop_ids == 1 + labels = list(numpy.where(drop_ids, "", text_prompts)) + # print(labels) + return labels + + def forward(self, text_prompts, train, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + text_prompts = self.token_drop(text_prompts, force_drop_ids) + embeddings = self.text_encodder(text_prompts) + return embeddings + + +if __name__ == '__main__': + + r""" + Returns: + + Examples from CLIPTextModel: + + ```python + >>> from transformers import AutoTokenizer, CLIPTextModel + + >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32") + >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") + + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled (EOS token) states + ```""" + + import torch + + device = "cuda" if torch.cuda.is_available() else "cpu" + + text_encoder = TextEmbedder(path='/mnt/petrelfs/maxin/work/pretrained/stable-diffusion-2-1-base', + dropout_prob=0.00001).to(device) + + text_prompt = [["a photo of a cat", "a photo of a cat"], ["a photo of a dog", "a photo of a cat"], ['a photo of a dog human', "a photo of a cat"]] + output = text_encoder(text_prompts=text_prompt, train=False) + print(output.shape) diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/base/models/resnet.py b/src/videogen_hub/pipelines/lavie/lavie_src/base/models/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..c1307ad9a0fe26b6b1eca913b024c188df2dd86e --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/base/models/resnet.py @@ -0,0 +1,212 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py +import os +import sys +sys.path.append(os.path.split(sys.path[0])[0]) + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange + + +class InflatedConv3d(nn.Conv2d): + def forward(self, x): + video_length = x.shape[2] + + x = rearrange(x, "b c f h w -> (b f) c h w") + x = super().forward(x) + x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) + + return x + + +class Upsample3D(nn.Module): + def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + conv = None + if use_conv_transpose: + raise NotImplementedError + elif use_conv: + conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) + + if name == "conv": + self.conv = conv + else: + self.Conv2d_0 = conv + + def forward(self, hidden_states, output_size=None): + assert hidden_states.shape[1] == self.channels + + if self.use_conv_transpose: + raise NotImplementedError + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.float32) + + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + # if `output_size` is passed we force the interpolation output + # size and do not make use of `scale_factor=2` + if output_size is None: + hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") + else: + hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(dtype) + + if self.use_conv: + if self.name == "conv": + hidden_states = self.conv(hidden_states) + else: + hidden_states = self.Conv2d_0(hidden_states) + + return hidden_states + + +class Downsample3D(nn.Module): + def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if use_conv: + conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) + else: + raise NotImplementedError + + if name == "conv": + self.Conv2d_0 = conv + self.conv = conv + elif name == "Conv2d_0": + self.conv = conv + else: + self.conv = conv + + def forward(self, hidden_states): + assert hidden_states.shape[1] == self.channels + if self.use_conv and self.padding == 0: + raise NotImplementedError + + assert hidden_states.shape[1] == self.channels + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class ResnetBlock3D(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + non_linearity="swish", + time_embedding_norm="default", + output_scale_factor=1.0, + use_in_shortcut=None, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.time_embedding_norm = time_embedding_norm + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + + self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if temb_channels is not None: + if self.time_embedding_norm == "default": + time_emb_proj_out_channels = out_channels + elif self.time_embedding_norm == "scale_shift": + time_emb_proj_out_channels = out_channels * 2 + else: + raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") + + self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) + else: + self.time_emb_proj = None + + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + + self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, input_tensor, temb): + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.conv1(hidden_states) + + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] + + if temb is not None and self.time_embedding_norm == "default": + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + + if temb is not None and self.time_embedding_norm == "scale_shift": + scale, shift = torch.chunk(temb, 2, dim=1) + hidden_states = hidden_states * (1 + scale) + shift + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + + +class Mish(torch.nn.Module): + def forward(self, hidden_states): + return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) \ No newline at end of file diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/base/models/temporal_attention.py b/src/videogen_hub/pipelines/lavie/lavie_src/base/models/temporal_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..6c417bc9e98409c22297a4878da689e38d188a24 --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/base/models/temporal_attention.py @@ -0,0 +1,388 @@ +import torch +from torch import nn +from typing import Optional +from rotary_embedding_torch import RotaryEmbedding +from dataclasses import dataclass +from diffusers.utils import BaseOutput +from diffusers.utils.import_utils import is_xformers_available +import torch.nn.functional as F +from einops import rearrange, repeat +import math + +@dataclass +class Transformer3DModelOutput(BaseOutput): + sample: torch.FloatTensor + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + +def exists(x): + return x is not None + +class CrossAttention(nn.Module): + r""" + copy from diffuser 0.11.1 + A cross attention layer. + Parameters: + query_dim (`int`): The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + use_relative_position: bool = False, + ): + super().__init__() + # print('num head', heads) + inner_dim = dim_head * heads + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + + self.scale = dim_head**-0.5 + + self.heads = heads + self.dim_head = dim_head + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + self._slice_size = None + self._use_memory_efficient_attention_xformers = False # No use xformers for temporal attention + self.added_kv_proj_dim = added_kv_proj_dim + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) + else: + self.group_norm = None + + self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(inner_dim, query_dim)) + self.to_out.append(nn.Dropout(dropout)) + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def reshape_for_scores(self, tensor): + # split heads and dims + # tensor should be [b (h w)] f (d nd) + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).contiguous() + return tensor + + def same_batch_dim_to_heads(self, tensor): + batch_size, head_size, seq_len, dim = tensor.shape # [b (h w)] nd f d + tensor = tensor.reshape(batch_size, seq_len, dim * head_size) + return tensor + + def set_attention_slice(self, slice_size): + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + self._slice_size = slice_size + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, use_image_num=None): + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) # [b (h w)] f (nd * d) + + # print('before reshpape query shape', query.shape) + dim = query.shape[-1] + query = self.reshape_heads_to_batch_dim(query) # [b (h w) nd] f d + # print('after reshape query shape', query.shape) + + if self.added_kv_proj_dim is not None: + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) + + key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) + else: + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + hidden_states = self._attention(query, key, value, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + + def _attention(self, query, key, value, attention_mask=None): + if self.upcast_attention: + query = query.float() + key = key.float() + + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + attention_probs = attention_probs.to(value.dtype) + # compute attention output + hidden_states = torch.bmm(attention_probs, value) + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask): + batch_size_attention = query.shape[0] + hidden_states = torch.zeros( + (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype + ) + slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] + for i in range(hidden_states.shape[0] // slice_size): + start_idx = i * slice_size + end_idx = (i + 1) * slice_size + + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + + if self.upcast_attention: + query_slice = query_slice.float() + key_slice = key_slice.float() + + attn_slice = torch.baddbmm( + torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device), + query_slice, + key_slice.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + if attention_mask is not None: + attn_slice = attn_slice + attention_mask[start_idx:end_idx] + + if self.upcast_softmax: + attn_slice = attn_slice.float() + + attn_slice = attn_slice.softmax(dim=-1) + + # cast back to the original dtype + attn_slice = attn_slice.to(value.dtype) + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def _memory_efficient_attention_xformers(self, query, key, value, attention_mask): + # TODO attention_mask + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + +class TemporalAttention(CrossAttention): + def __init__(self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + rotary_emb=None): + super().__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, added_kv_proj_dim, norm_num_groups) + # relative time positional embeddings + self.time_rel_pos_bias = RelativePositionBias(heads=heads, max_distance=32) # realistically will not be able to generate that many frames of video... yet + self.rotary_emb = rotary_emb + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): + time_rel_pos_bias = self.time_rel_pos_bias(hidden_states.shape[1], device=hidden_states.device) + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) # [b (h w)] f (nd * d) + dim = query.shape[-1] + + if self.added_kv_proj_dim is not None: + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) + + key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) + else: + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask, time_rel_pos_bias) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + + def _attention(self, query, key, value, attention_mask=None, time_rel_pos_bias=None): + if self.upcast_attention: + query = query.float() + key = key.float() + + query = self.scale * rearrange(query, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads + key = rearrange(key, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads + value = rearrange(value, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads + if exists(self.rotary_emb): + query = self.rotary_emb.rotate_queries_or_keys(query) + key = self.rotary_emb.rotate_queries_or_keys(key) + + attention_scores = torch.einsum('... h i d, ... h j d -> ... h i j', query, key) + attention_scores = attention_scores + time_rel_pos_bias + + if attention_mask is not None: + # add attention mask + attention_scores = attention_scores + attention_mask + + attention_scores = attention_scores - attention_scores.amax(dim = -1, keepdim = True).detach() + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + attention_probs = attention_probs.to(value.dtype) + hidden_states = torch.einsum('... h i j, ... h j d -> ... h i d', attention_probs, value) + hidden_states = rearrange(hidden_states, 'b h f d -> b f (h d)') + return hidden_states + +class RelativePositionBias(nn.Module): + def __init__( + self, + heads=8, + num_buckets=32, + max_distance=128, + ): + super().__init__() + self.num_buckets = num_buckets + self.max_distance = max_distance + self.relative_attention_bias = nn.Embedding(num_buckets, heads) + + @staticmethod + def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128): + ret = 0 + n = -relative_position + + num_buckets //= 2 + ret += (n < 0).long() * num_buckets + n = torch.abs(n) + + max_exact = num_buckets // 2 + is_small = n < max_exact + + val_if_large = max_exact + ( + torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) + ).long() + val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + + ret += torch.where(is_small, n, val_if_large) + return ret + + def forward(self, n, device): + q_pos = torch.arange(n, dtype = torch.long, device = device) + k_pos = torch.arange(n, dtype = torch.long, device = device) + rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1') + rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance) + values = self.relative_attention_bias(rp_bucket) + return rearrange(values, 'i j h -> h i j') # num_heads, num_frames, num_frames diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/base/models/transformer_3d.py b/src/videogen_hub/pipelines/lavie/lavie_src/base/models/transformer_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..f4b0aba347ea5408955d4646737306c8e945bf16 --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/base/models/transformer_3d.py @@ -0,0 +1,367 @@ + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.embeddings import ImagePositionalEmbeddings +from diffusers.utils import BaseOutput, deprecate +from diffusers.models.embeddings import PatchEmbed +from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear +from diffusers.models.modeling_utils import ModelMixin +from einops import rearrange, repeat + +try: + from attention import BasicTransformerBlock +except: + from .attention import BasicTransformerBlock + +@dataclass +class Transformer3DModelOutput(BaseOutput): + """ + The output of [`Transformer2DModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability + distributions for the unnoised latent pixels. + """ + + sample: torch.FloatTensor + + +class Transformer3DModel(ModelMixin, ConfigMixin): + """ + A 2D Transformer model for image-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + num_vector_embeds (`int`, *optional*): + The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. + + During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, + rotary_emb=None, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = (in_channels is not None) and (patch_size is None) + self.is_input_vectorized = num_vector_embeds is not None + self.is_input_patches = in_channels is not None and patch_size is not None + + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = ( + f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" + " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." + " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" + " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" + " would be very nice if you could open a Pull request for the `transformer/config.json` file" + ) + deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) + norm_type = "ada_norm" + + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is None." + ) + elif self.is_input_vectorized and self.is_input_patches: + raise ValueError( + f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" + " sure that either `num_vector_embeds` or `num_patches` is None." + ) + elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: + raise ValueError( + f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" + f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." + ) + + # 2. Define input layers + if self.is_input_continuous: + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = LoRACompatibleLinear(in_channels, inner_dim) + else: + self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" + assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" + + self.height = sample_size + self.width = sample_size + self.num_vector_embeds = num_vector_embeds + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width + ) + elif self.is_input_patches: + assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" + + self.height = sample_size + self.width = sample_size + + self.patch_size = patch_size + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + ) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + rotary_emb=rotary_emb, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + if self.is_input_continuous: + # TODO: should use out_channels for continuous projections + if use_linear_projection: + self.proj_out = LoRACompatibleLinear(inner_dim, in_channels) + else: + self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + elif self.is_input_patches: + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) + self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + use_image_num=None, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. + encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 1. Input + if self.is_input_continuous: # True + + assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." + if self.training: + video_length = hidden_states.shape[2] - use_image_num + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous() + encoder_hidden_states_length = encoder_hidden_states.shape[1] + encoder_hidden_states_video = encoder_hidden_states[:, :encoder_hidden_states_length - use_image_num, ...] + encoder_hidden_states_video = repeat(encoder_hidden_states_video, 'b m n c -> b (m f) n c', f=video_length).contiguous() + encoder_hidden_states_image = encoder_hidden_states[:, encoder_hidden_states_length - use_image_num:, ...] + encoder_hidden_states = torch.cat([encoder_hidden_states_video, encoder_hidden_states_image], dim=1) + encoder_hidden_states = rearrange(encoder_hidden_states, 'b m n c -> (b m) n c').contiguous() + else: + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous() + encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length).contiguous() + + batch, _, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + hidden_states = self.proj_in(hidden_states) + elif self.is_input_vectorized: + hidden_states = self.latent_image_embedding(hidden_states) + elif self.is_input_patches: + hidden_states = self.pos_embed(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + video_length=video_length, + use_image_num=use_image_num, + ) + + # 3. Output + if self.is_input_continuous: + if not self.use_linear_projection: + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + elif self.is_input_vectorized: + hidden_states = self.norm_out(hidden_states) + logits = self.out(hidden_states) + # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) + logits = logits.permute(0, 2, 1) + + # log(p(x_0)) + output = F.log_softmax(logits.double(), dim=1).float() + elif self.is_input_patches: + # TODO: cleanup! + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + hidden_states = self.proj_out_2(hidden_states) + + # unpatchify + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + + output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length + use_image_num).contiguous() + if not return_dict: + return (output,) + + return Transformer3DModelOutput(sample=output) diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/base/models/unet.py b/src/videogen_hub/pipelines/lavie/lavie_src/base/models/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..bfb2f473142783e7e6cf8f8fb37c41c56fb2f8de --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/base/models/unet.py @@ -0,0 +1,617 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import os +import sys +sys.path.append(os.path.split(sys.path[0])[0]) + +import math +import json +import torch +import einops +import torch.nn as nn +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput, logging +from diffusers.models.embeddings import TimestepEmbedding, Timesteps + +try: + from diffusers.models.modeling_utils import ModelMixin +except: + from diffusers.modeling_utils import ModelMixin # 0.11.1 + +try: + from .unet_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, + ) + from .resnet import InflatedConv3d +except: + from unet_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, + ) + from resnet import InflatedConv3d + +from rotary_embedding_torch import RotaryEmbedding + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +class RelativePositionBias(nn.Module): + def __init__( + self, + heads=8, + num_buckets=32, + max_distance=128, + ): + super().__init__() + self.num_buckets = num_buckets + self.max_distance = max_distance + self.relative_attention_bias = nn.Embedding(num_buckets, heads) + + @staticmethod + def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128): + ret = 0 + n = -relative_position + + num_buckets //= 2 + ret += (n < 0).long() * num_buckets + n = torch.abs(n) + + max_exact = num_buckets // 2 + is_small = n < max_exact + + val_if_large = max_exact + ( + torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) + ).long() + val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + + ret += torch.where(is_small, n, val_if_large) + return ret + + def forward(self, n, device): + q_pos = torch.arange(n, dtype = torch.long, device = device) + k_pos = torch.arange(n, dtype = torch.long, device = device) + rel_pos = einops.rearrange(k_pos, 'j -> 1 j') - einops.rearrange(q_pos, 'i -> i 1') + rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance) + values = self.relative_attention_bias(rp_bucket) + return einops.rearrange(values, 'i j h -> h i j') # num_heads, num_frames, num_frames + +@dataclass +class UNet3DConditionOutput(BaseOutput): + sample: torch.FloatTensor + + +class UNet3DConditionModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, # 64 + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + mid_block_type: str = "UNetMidBlock3DCrossAttn", + up_block_types: Tuple[str] = ( + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D" + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: Union[int, Tuple[int]] = 8, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + use_first_frame: bool = False, + use_relative_position: bool = False, + ): + super().__init__() + + # print(use_first_frame) + + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + + # input + self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) + + # time + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + else: + self.class_embedding = None + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + # print(only_cross_attention) + # print(type(only_cross_attention)) + # exit() + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + # print(only_cross_attention) + # exit() + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + # print(attention_head_dim) + # exit() + + rotary_emb = RotaryEmbedding(32) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock3DCrossAttn": + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the videos + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + only_cross_attention = list(reversed(only_cross_attention)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) + self.conv_act = nn.SiLU() + self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + + # relative time positional embeddings + self.use_relative_position = use_relative_position + if self.use_relative_position: + self.time_rel_pos_bias = RelativePositionBias(heads=8, max_distance=32) # realistically will not be able to generate that many frames of video... yet + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + + num_slicable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_slicable_layers * [1] + + slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor = None, + class_labels: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + use_image_num: int = 0, + return_dict: bool = True, + ) -> Union[UNet3DConditionOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # time + timesteps = timestep + if not torch.is_tensor(timesteps): + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + # print(emb.shape) # torch.Size([3, 1280]) + # print(class_emb.shape) # torch.Size([3, 1280]) + emb = emb + class_emb + + if self.use_relative_position: + frame_rel_pos_bias = self.time_rel_pos_bias(sample.shape[2], device=sample.device) + else: + frame_rel_pos_bias = None + + # pre-process + sample = self.conv_in(sample) + + # down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + use_image_num=use_image_num, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # mid + sample = self.mid_block( + sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, use_image_num=use_image_num, + ) + + # up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + use_image_num=use_image_num, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + # print(sample.shape) + + if not return_dict: + return (sample,) + sample = UNet3DConditionOutput(sample=sample) + return sample + + def forward_with_cfg(self, + x, + t, + encoder_hidden_states = None, + class_labels: Optional[torch.Tensor] = None, + cfg_scale=4.0, + use_fp16=False): + """ + Forward, but also batches the unconditional forward pass for classifier-free guidance. + """ + # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + half = x[: len(x) // 2] + combined = torch.cat([half, half], dim=0) + if use_fp16: + combined = combined.to(dtype=torch.float16) + model_out = self.forward(combined, t, encoder_hidden_states, class_labels).sample + # For exact reproducibility reasons, we apply classifier-free guidance on only + # three channels by default. The standard approach to cfg applies it to all channels. + # This can be done by uncommenting the following line and commenting-out the line following that. + eps, rest = model_out[:, :4], model_out[:, 4:] + # eps, rest = model_out[:, :3], model_out[:, 3:] # b c f h w + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + return torch.cat([eps, rest], dim=1) + + @classmethod + def from_pretrained_2d(cls, pretrained_model_path, subfolder=None): + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + + + config_file = os.path.join(pretrained_model_path, 'config.json') + if not os.path.isfile(config_file): + raise RuntimeError(f"{config_file} does not exist") + with open(config_file, "r") as f: + config = json.load(f) + config["_class_name"] = cls.__name__ + config["down_block_types"] = [ + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D" + ] + config["up_block_types"] = [ + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D" + ] + + config["use_first_frame"] = False + + from diffusers.utils import WEIGHTS_NAME # diffusion_pytorch_model.bin + + + model = cls.from_config(config) + model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) + if not os.path.isfile(model_file): + raise RuntimeError(f"{model_file} does not exist") + state_dict = torch.load(model_file, map_location="cpu") + for k, v in model.state_dict().items(): + # print(k) + if '_temp' in k: + state_dict.update({k: v}) + if 'attn_fcross' in k: # conpy parms of attn1 to attn_fcross + k = k.replace('attn_fcross', 'attn1') + state_dict.update({k: state_dict[k]}) + if 'norm_fcross' in k: + k = k.replace('norm_fcross', 'norm1') + state_dict.update({k: state_dict[k]}) + + model.load_state_dict(state_dict) + + return model + +if __name__ == '__main__': + import torch + # from xformers.ops import MemoryEfficientAttentionFlashAttentionOp + + device = "cuda" if torch.cuda.is_available() else "cpu" + + pretrained_model_path = "/mnt/petrelfs/maxin/work/pretrained/stable-diffusion-v1-4/" # p cluster + unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet").to(device) + # unet.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp) + unet.enable_xformers_memory_efficient_attention() + unet.enable_gradient_checkpointing() + + unet.train() + + use_image_num = 5 + noisy_latents = torch.randn((2, 4, 16 + use_image_num, 32, 32)).to(device) + bsz = noisy_latents.shape[0] + timesteps = torch.randint(0, 1000, (bsz,)).to(device) + timesteps = timesteps.long() + encoder_hidden_states = torch.randn((bsz, 1 + use_image_num, 77, 768)).to(device) + # class_labels = torch.randn((bsz, )).to(device) + + + model_pred = unet(sample=noisy_latents, timestep=timesteps, + encoder_hidden_states=encoder_hidden_states, + class_labels=None, + use_image_num=use_image_num).sample + print(model_pred.shape) \ No newline at end of file diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/base/models/unet_blocks.py b/src/videogen_hub/pipelines/lavie/lavie_src/base/models/unet_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..849c10539c7039840c93631c5201069119d3c306 --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/base/models/unet_blocks.py @@ -0,0 +1,648 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py +import os +import sys +sys.path.append(os.path.split(sys.path[0])[0]) + +import torch +from torch import nn + +try: + from .attention import Transformer3DModel + from .resnet import Downsample3D, ResnetBlock3D, Upsample3D +except: + from attention import Transformer3DModel + from resnet import Downsample3D, ResnetBlock3D, Upsample3D + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + use_first_frame=False, + use_relative_position=False, + rotary_emb=False, +): + # print(down_block_type) + # print(use_first_frame) + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlock3D": + return DownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "CrossAttnDownBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") + return CrossAttnDownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + use_first_frame=False, + use_relative_position=False, + rotary_emb=False, +): + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlock3D": + return UpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "CrossAttnUpBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") + return CrossAttnUpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock3DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + use_first_frame=False, + use_relative_position=False, + rotary_emb=False, + ): + super().__init__() + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(num_layers): + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + ) + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, use_image_num=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, use_image_num=use_image_num).sample + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class CrossAttnDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + use_first_frame=False, + use_relative_position=False, + rotary_emb=False, + ): + super().__init__() + resnets = [] + attentions = [] + + # print(use_first_frame) + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, use_image_num=None): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + def create_custom_forward_attn(module, return_dict=None, use_image_num=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict, use_image_num=use_image_num) + else: + return module(*inputs, use_image_num=use_image_num) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward_attn(attn, return_dict=False, use_image_num=use_image_num), + hidden_states, + encoder_hidden_states, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, use_image_num=use_image_num).sample + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + use_first_frame=False, + use_relative_position=False, + rotary_emb=False + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + attention_mask=None, + use_image_num=None, + ): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + def create_custom_forward_attn(module, return_dict=None, use_image_num=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict, use_image_num=use_image_num) + else: + return module(*inputs, use_image_num=use_image_num) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward_attn(attn, return_dict=False, use_image_num=use_image_num), + hidden_states, + encoder_hidden_states, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, use_image_num=use_image_num).sample + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/base/models/utils.py b/src/videogen_hub/pipelines/lavie/lavie_src/base/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..76e205d1709568885e58b56888f62c8805ad4f91 --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/base/models/utils.py @@ -0,0 +1,215 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import os +import math +import torch + +import numpy as np +import torch.nn as nn + +from einops import repeat + + +################################################################################# +# Unet Utils # +################################################################################# + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim).contiguous() + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +# class HybridConditioner(nn.Module): + +# def __init__(self, c_concat_config, c_crossattn_config): +# super().__init__() +# self.concat_conditioner = instantiate_from_config(c_concat_config) +# self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + +# def forward(self, c_concat, c_crossattn): +# c_concat = self.concat_conditioner(c_concat) +# c_crossattn = self.crossattn_conditioner(c_crossattn) +# return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += torch.DoubleTensor([matmul_ops]) + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params \ No newline at end of file diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/base/pipelines/__init__.py b/src/videogen_hub/pipelines/lavie/lavie_src/base/pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/base/pipelines/pipeline_videogen.py b/src/videogen_hub/pipelines/lavie/lavie_src/base/pipelines/pipeline_videogen.py new file mode 100644 index 0000000000000000000000000000000000000000..90a1e2dc8949f2cad66f32897b683ef06d43e111 --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/base/pipelines/pipeline_videogen.py @@ -0,0 +1,680 @@ + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union +import einops +import torch +from packaging import version +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + +from diffusers.configuration_utils import FrozenDict +from diffusers.models import AutoencoderKL +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, + #randn_tensor, + replace_example_docstring, + BaseOutput, +) + +try: + from diffusers.utils import randn_tensor +except: + from diffusers.utils.torch_utils import randn_tensor + +try: + from diffusers.pipeline_utils import DiffusionPipeline +except: + from diffusers import DiffusionPipeline + +from dataclasses import dataclass + +import os, sys +sys.path.append(os.path.split(sys.path[0])[0]) +from ..models.unet import UNet3DConditionModel # Fix import issue + +import numpy as np + +@dataclass +class StableDiffusionPipelineOutput(BaseOutput): + video: torch.Tensor + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionPipeline + + >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + ``` +""" + + +class VideoGenPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet3DConditionModel, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + # self.register_to_config(requires_safety_checker=requires_safety_checker) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + cpu_offload(cpu_offloaded_model, device) + + # if self.safety_checker is not None: + # cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + self.final_offload_hook = hook + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def decode_latents(self, latents): + video_length = latents.shape[2] + latents = 1 / 0.18215 * latents + latents = einops.rearrange(latents, "b c f h w -> (b f) c h w") + video = self.vae.decode(latents).sample + video = einops.rearrange(video, "(b f) c h w -> b f h w c", f=video_length) + video = ((video / 2 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().contiguous() + return video + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + video_length: int = 16, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + video_length, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + # cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + + # 8. Post-processing + video = self.decode_latents(latents) + + return StableDiffusionPipelineOutput(video=video) diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/base/pipelines/sample.py b/src/videogen_hub/pipelines/lavie/lavie_src/base/pipelines/sample.py new file mode 100644 index 0000000000000000000000000000000000000000..a79ccc54af929f4194175aadfe3681973e9aba96 --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/base/pipelines/sample.py @@ -0,0 +1,95 @@ +import os +import torch +import argparse +import torchvision + +from pipeline_videogen import VideoGenPipeline + +from download import find_model +from diffusers.schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, EulerDiscreteScheduler +from diffusers.models import AutoencoderKL +from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection +from omegaconf import OmegaConf + +import os, sys +sys.path.append(os.path.split(sys.path[0])[0]) +from models import get_models +import imageio + +def main(args): + if args.seed is not None: + torch.manual_seed(args.seed) + torch.set_grad_enabled(False) + device = "cuda" if torch.cuda.is_available() else "cpu" + + sd_path = args.pretrained_path + "/stable-diffusion-v1-4" + unet = get_models(args, sd_path).to(device, dtype=torch.float16) + state_dict = find_model(args.ckpt_path) + unet.load_state_dict(state_dict) + + vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae", torch_dtype=torch.float16).to(device) + tokenizer_one = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer") + text_encoder_one = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder", torch_dtype=torch.float16).to(device) # huge + + # set eval mode + unet.eval() + vae.eval() + text_encoder_one.eval() + + if args.sample_method == 'ddim': + scheduler = DDIMScheduler.from_pretrained(sd_path, + subfolder="scheduler", + beta_start=args.beta_start, + beta_end=args.beta_end, + beta_schedule=args.beta_schedule) + elif args.sample_method == 'eulerdiscrete': + scheduler = EulerDiscreteScheduler.from_pretrained(sd_path, + subfolder="scheduler", + beta_start=args.beta_start, + beta_end=args.beta_end, + beta_schedule=args.beta_schedule) + elif args.sample_method == 'ddpm': + scheduler = DDPMScheduler.from_pretrained(sd_path, + subfolder="scheduler", + beta_start=args.beta_start, + beta_end=args.beta_end, + beta_schedule=args.beta_schedule) + else: + raise NotImplementedError + + videogen_pipeline = VideoGenPipeline(vae=vae, + text_encoder=text_encoder_one, + tokenizer=tokenizer_one, + scheduler=scheduler, + unet=unet).to(device) + videogen_pipeline.enable_xformers_memory_efficient_attention() + + if not os.path.exists(args.output_folder): + os.makedirs(args.output_folder) + + video_grids = [] + for prompt in args.text_prompt: + print('Processing the ({}) prompt'.format(prompt)) + videos = videogen_pipeline(prompt, + video_length=args.video_length, + height=args.image_size[0], + width=args.image_size[1], + num_inference_steps=args.num_sampling_steps, + guidance_scale=args.guidance_scale).video + imageio.mimwrite(args.output_folder + prompt.replace(' ', '_') + '.mp4', videos[0], fps=8, quality=9) # highest quality is 10, lowest is 0 + + print('save path {}'.format(args.output_folder)) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, default="") + args = parser.parse_args() + config = args.config + + # Overwrite config with command line arguments + if args.optional_args: + modified_config = OmegaConf.from_dotlist(args.optional_args) + config = OmegaConf.merge(config, modified_config) + + main(OmegaConf.load(config)) + diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/base/pipelines/sample.sh b/src/videogen_hub/pipelines/lavie/lavie_src/base/pipelines/sample.sh new file mode 100644 index 0000000000000000000000000000000000000000..c8dd5fdada1cf921e5c250bd357e765fd9e42b5a --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/base/pipelines/sample.sh @@ -0,0 +1,2 @@ +export CUDA_VISIBLE_DEVICES=6 +python pipelines/sample.py --config configs/sample.yaml \ No newline at end of file diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/__init__.py b/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/configs/__init__.py b/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/configs/sample.yaml b/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/configs/sample.yaml new file mode 100644 index 0000000000000000000000000000000000000000..968cdd6b2240e3dd96b95f2dc8f165b2c1f9f7e0 --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/configs/sample.yaml @@ -0,0 +1,37 @@ +args: + input_folder: "../res/base/" + ckpt_path: "../pretrained_models/lavie_interpolation.pt" + pretrained_path: "../pretrained_models" + output_folder: "../res/interpolation/" + seed_list: + - 3418 + + fps_list: + - 24 + + # model config: + model: TSR + num_frames: 61 + image_size: [320, 512] + num_sampling_steps: 50 + vae: mse + use_timecross_transformer: False + frame_interval: 1 + + # sample config: + seed: 0 + cfg_scale: 4.0 + run_time: 12 + use_compile: False + enable_xformers_memory_efficient_attention: True + num_sample: 1 + + additional_prompt: ", 4k." + negative_prompt: "None" + do_classifier_free_guidance: True + use_ddim_sample_loop: True + + researve_frame: 3 + mask_type: "tsr" + use_concat: True + copy_no_mask: True diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/datasets/__init__.py b/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1332521f6336d8ce42f026e0aeba66fdd8138a0a --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/datasets/__init__.py @@ -0,0 +1 @@ +from datasets import video_transforms diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/datasets/video_transforms.py b/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/datasets/video_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..463a10f082288f486a546bbff17b136824723589 --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/datasets/video_transforms.py @@ -0,0 +1,109 @@ +import torch +import random +import numbers +from torchvision.transforms import RandomCrop, RandomResizedCrop + +def _is_tensor_video_clip(clip): + if not torch.is_tensor(clip): + raise TypeError("clip should be Tensor. Got %s" % type(clip)) + + if not clip.ndimension() == 4: + raise ValueError("clip should be 4D. Got %dD" % clip.dim()) + + return True + + +def to_tensor(clip): + """ + Convert tensor data type from uint8 to float, divide value by 255.0 and + permute the dimensions of clip tensor + Args: + clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) + Return: + clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) + """ + _is_tensor_video_clip(clip) + if not clip.dtype == torch.uint8: + raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype)) + # return clip.float().permute(3, 0, 1, 2) / 255.0 + return clip.float() / 255.0 + + +def resize(clip, target_size, interpolation_mode): + if len(target_size) != 2: + raise ValueError(f"target size should be tuple (height, width), instead got {target_size}") + return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False) + + +class ToTensorVideo: + """ + Convert tensor data type from uint8 to float, divide value by 255.0 and + permute the dimensions of clip tensor + """ + + def __init__(self): + pass + + def __call__(self, clip): + """ + Args: + clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) + Return: + clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) + """ + return to_tensor(clip) + + def __repr__(self) -> str: + return self.__class__.__name__ + + +class ResizeVideo: + ''' + Resize to the specified size + ''' + def __init__( + self, + size, + interpolation_mode="bilinear", + ): + if isinstance(size, tuple): + if len(size) != 2: + raise ValueError(f"size should be tuple (height, width), instead got {size}") + self.size = size + else: + self.size = (size, size) + + self.interpolation_mode = interpolation_mode + + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + Returns: + torch.tensor: scale resized video clip. + size is (T, C, h, w) + """ + clip_resize = resize(clip, target_size=self.size, interpolation_mode=self.interpolation_mode) + return clip_resize + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" + + +class TemporalRandomCrop(object): + """Temporally crop the given frame indices at a random location. + + Args: + size (int): Desired length of frames will be seen in the model. + """ + + def __init__(self, size): + self.size = size + + def __call__(self, total_frames): + rand_end = max(0, total_frames - self.size - 1) + begin_index = random.randint(0, rand_end) + end_index = min(begin_index + self.size, total_frames) + return begin_index, end_index + diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/diffusion/__init__.py b/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c9dbf6cf0bd6b9d1a8f65e0a31e9a84cacc03189 --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/diffusion/__init__.py @@ -0,0 +1,47 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +from . import gaussian_diffusion as gd +from .respace import SpacedDiffusion, space_timesteps + + +def create_diffusion( + timestep_respacing, + noise_schedule="linear", + use_kl=False, + sigma_small=False, + predict_xstart=False, + # learn_sigma=True, + learn_sigma=False, # for unet + rescale_learned_sigmas=False, + diffusion_steps=1000 +): + betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) + if use_kl: + loss_type = gd.LossType.RESCALED_KL + elif rescale_learned_sigmas: + loss_type = gd.LossType.RESCALED_MSE + else: + loss_type = gd.LossType.MSE + if timestep_respacing is None or timestep_respacing == "": + timestep_respacing = [diffusion_steps] + return SpacedDiffusion( + use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), + betas=betas, + model_mean_type=( + gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X + ), + model_var_type=( + ( + gd.ModelVarType.FIXED_LARGE + if not sigma_small + else gd.ModelVarType.FIXED_SMALL + ) + if not learn_sigma + else gd.ModelVarType.LEARNED_RANGE + ), + loss_type=loss_type + # rescale_timesteps=rescale_timesteps, + ) diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/diffusion/diffusion_utils.py b/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/diffusion/diffusion_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e493a6a3ecb91e553a53cc7eadee5cc0d1753060 --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/diffusion/diffusion_utils.py @@ -0,0 +1,88 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +import torch as th +import numpy as np + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, th.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for th.exp(). + logvar1, logvar2 = [ + x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + th.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * th.exp(-logvar2) + ) + + +def approx_standard_normal_cdf(x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. + """ + return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) + + +def continuous_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a continuous Gaussian distribution. + :param x: the targets + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + centered_x = x - means + inv_stdv = th.exp(-log_scales) + normalized_x = centered_x * inv_stdv + log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) + return log_probs + + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. + :param x: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + assert x.shape == means.shape == log_scales.shape + centered_x = x - means + inv_stdv = th.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = approx_standard_normal_cdf(min_in) + log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = th.where( + x < -0.999, + log_cdf_plus, + th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == x.shape + return log_probs diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/diffusion/gaussian_diffusion.py b/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/diffusion/gaussian_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..854c5c4daceace3826b786d98220c9d1b10611c8 --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/diffusion/gaussian_diffusion.py @@ -0,0 +1,1000 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + + +import math + +import numpy as np +import torch as th +import enum + +from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +class ModelMeanType(enum.Enum): + """ + Which type of output the model predicts. + """ + + PREVIOUS_X = enum.auto() # the model predicts x_{t-1} + START_X = enum.auto() # the model predicts x_0 + EPSILON = enum.auto() # the model predicts epsilon + + +class ModelVarType(enum.Enum): + """ + What is used as the model's output variance. + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. + """ + + LEARNED = enum.auto() + FIXED_SMALL = enum.auto() + FIXED_LARGE = enum.auto() + LEARNED_RANGE = enum.auto() + + +class LossType(enum.Enum): + MSE = enum.auto() # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = ( + enum.auto() + ) # use raw MSE loss (with RESCALED_KL when learning variances) + KL = enum.auto() # use the variational lower-bound + RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + + +def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + warmup_time = int(num_diffusion_timesteps * warmup_frac) + betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64) + return betas + + +def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): + """ + This is the deprecated API for creating beta schedules. + See get_named_beta_schedule() for the new library of schedules. + """ + if beta_schedule == "quad": + betas = ( + np.linspace( + beta_start ** 0.5, + beta_end ** 0.5, + num_diffusion_timesteps, + dtype=np.float64, + ) + ** 2 + ) + elif beta_schedule == "linear": + betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "warmup10": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) + elif beta_schedule == "warmup50": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) + elif beta_schedule == "const": + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 + betas = 1.0 / np.linspace( + num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 + ) + else: + raise NotImplementedError(beta_schedule) + assert betas.shape == (num_diffusion_timesteps,) + return betas + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + """ + Get a pre-defined beta schedule for the given name. + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. + scale = 1000 / num_diffusion_timesteps + return get_beta_schedule( + "linear", + beta_start=scale * 0.0001, + beta_end=scale * 0.02, + num_diffusion_timesteps=num_diffusion_timesteps, + ) + elif schedule_name == "squaredcos_cap_v2": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +class GaussianDiffusion: + """ + Utilities for training and sampling diffusion models. + Original ported from this codebase: + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + """ + + def __init__( + self, + *, + betas, + model_mean_type, + model_var_type, + loss_type + ): + + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + self.loss_type = loss_type + + # Use float64 for accuracy. + betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = ( + betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = np.log( + np.append(self.posterior_variance[1], self.posterior_variance[1:]) + ) if len(self.posterior_variance) > 1 else np.array([]) + + self.posterior_mean_coef1 = ( + betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + self.posterior_mean_coef2 = ( + (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) + ) + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + In other words, sample from q(x_t | x_0). + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. + """ + if noise is None: + noise = th.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None, + mask=None, x_start=None, use_concat=False, + copy_no_mask=False, ): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. + """ + if model_kwargs is None: + model_kwargs = {} + + B, F, C = x.shape[:3] + assert t.shape == (B,) + # model_output = model(x, t, **model_kwargs) + if copy_no_mask: + if use_concat: + try: + model_output = model(th.concat([x, x_start], dim=1), t, **model_kwargs).sample + except: + # print(f'x.shape = {x.shape}, x_start.shape = {x_start.shape}') + # ) + # x.shape = torch.Size([2, 4, 61, 32, 32]), x_start.shape = torch.Size([2, 4, 61, 32, 32] + # print(f'x[0,0,:,0,0] = {x[0,0,:,0,0]}, \nx_start[0,0,:,0,0] = {x_start[0,0,:,0,0]}') + model_output = model(th.concat([x, x_start], dim=1), t, **model_kwargs) + else: + try: + model_output = model(x, t, **model_kwargs).sample # for tav unet + except: + model_output = model(x, t, **model_kwargs) + else: + if use_concat: + try: + model_output = model(th.concat([x, mask, x_start], dim=1), t, **model_kwargs).sample + except: + model_output = model(th.concat([x, mask, x_start], dim=1), t, **model_kwargs) + else: + try: + model_output = model(x, t, **model_kwargs).sample # for tav unet + except: + model_output = model(x, t, **model_kwargs) + if isinstance(model_output, tuple): + model_output, extra = model_output + else: + extra = None + + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + assert model_output.shape == (B, F, C * 2, *x.shape[3:]) + model_output, model_var_values = th.split(model_output, C, dim=2) + min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) + max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = th.exp(model_log_variance) + else: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. + ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart( + self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output, mask=mask, x_start=x_start, use_concat=use_concat) + ) + model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) + + assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + "extra": extra, + } + + def _predict_xstart_from_eps(self, x_t, t, eps, mask=None, x_start=None, use_concat=False): # (x_t=x, t=t, eps=model_output) + assert x_t.shape == eps.shape + if not use_concat: + if mask is not None: + if x_start is None: + return ( + (_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps )* mask + x_t * (1-mask) + ) + else: + # breakpoint() + if (t == 0).any(): + print('t=0') + x_unknown = _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t \ + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + return x_start * (1-mask) + x_unknown * mask + else: + x_known = self.q_sample(x_start, t-1) + x_unknown = _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t \ + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + return ( + x_known * (1-mask) + x_unknown * mask + ) + else: + return ( + (_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps ) + ) + else: + return ( + (_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps ) + ) + + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). + """ + gradient = cond_fn(x, t, **model_kwargs) + new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() + return new_mean + + def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + See condition_mean() for details on cond_fn. + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). + """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t) + return out + + def p_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + mask=None, + x_start=None, + use_concat=False + ): + """ + Sample x_{t-1} from the model at the given timestep. + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + mask=mask, + x_start=x_start, + use_concat=use_concat, + ) + noise = th.randn_like(x) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + if cond_fn is not None: + out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) + sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + mask=None, + x_start=None, + use_concat=False, + ): + """ + Generate samples from the model. + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). + :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. + :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :return: a non-differentiable batch of samples. + """ + final = None + for sample in self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + mask=mask, + x_start=x_start, + use_concat=use_concat + ): + final = sample + return final["sample"] + + def p_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + mask=None, + x_start=None, + use_concat=False + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + Arguments are the same as p_sample_loop(). + Returns a generator over dicts, where each dict is the return value of + p_sample(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): # loop + out = self.p_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + mask=mask, + x_start=x_start, + use_concat=use_concat + ) + yield out + img = out["sample"] + + def ddim_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + mask=None, + x_start=None, + use_concat=False, + copy_no_mask=False, + ): + """ + Sample x_{t-1} from the model using DDIM. + Same usage as p_sample(). + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + mask=mask, + x_start=x_start, + use_concat=use_concat, + copy_no_mask=copy_no_mask, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + sigma = ( + eta + * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) + * th.sqrt(1 - alpha_bar / alpha_bar_prev) + ) + # Equation 12. + noise = th.randn_like(x) + mean_pred = ( + out["pred_xstart"] * th.sqrt(alpha_bar_prev) + + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps + ) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def ddim_reverse_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t+1} from the model using DDIM reverse ODE. + """ + assert eta == 0.0, "Reverse ODE only for deterministic path" + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x + - out["pred_xstart"] + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) + alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) + + # Equation 12. reversed + mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps + + return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} + + def ddim_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + mask=None, + x_start=None, + use_concat=False, + copy_no_mask=False, + ): + """ + Generate samples from the model using DDIM. + Same usage as p_sample_loop(). + """ + final = None + for sample in self.ddim_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, + mask=mask, + x_start=x_start, + use_concat=use_concat, + copy_no_mask=copy_no_mask, + ): + final = sample + return final["sample"] + + def ddim_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + mask=None, + x_start=None, + use_concat=False, + copy_no_mask=False, + ): + """ + Use DDIM to sample from the model and yield intermediate samples from + each timestep of DDIM. + Same usage as p_sample_loop_progressive(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.ddim_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + eta=eta, + mask=mask, + x_start=x_start, + use_concat=use_concat, + copy_no_mask=copy_no_mask, + ) + yield out + img = out["sample"] + + def _vb_terms_bpd( + self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None + ): + """ + Get a term for the variational lower-bound. + The resulting units are bits (rather than nats, as one might expect). + This allows for comparison to other papers. + :return: a dict with the following keys: + - 'output': a shape [N] tensor of NLLs or KLs. + - 'pred_xstart': the x_0 predictions. + """ + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + ) + out = self.p_mean_variance( + model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs + ) + kl = normal_kl( + true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] + ) + kl = mean_flat(kl) / np.log(2.0) + + decoder_nll = -discretized_gaussian_log_likelihood( + x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] + ) + assert decoder_nll.shape == x_start.shape + decoder_nll = mean_flat(decoder_nll) / np.log(2.0) + + # At the first timestep return the decoder NLL, + # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) + output = th.where((t == 0), decoder_nll, kl) + return {"output": output, "pred_xstart": out["pred_xstart"]} + + def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, mask=None, t_head=None, copy_no_mask=False): + """ + Compute training losses for a single timestep. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. + """ + # mask could be here + if model_kwargs is None: + model_kwargs = {} + if noise is None: + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start, t, noise=noise) + x_t = th.cat([x_t[:, :4], x_start[:, 4:]], dim=1) + # mask is used for (0,0,0,1,1,1,...) which means the diffusion model can see the first 3 frames of the input video + # print(f'training_losses(): mask = {mask}') # None + + if mask is not None: + x_t = x_t*mask + x_start*(1-mask) + + # noise augmentation + if copy_no_mask: + if t_head is not None: + noise_aug = self.q_sample(x_start[:, 4:], t_head) # noise aug on copied_video + x_t = th.cat([x_t[:, :4], noise_aug], dim=1) + else: + if t_head is not None: + noise_aug = self.q_sample(x_start[:, 5:], t_head) # b, 4, f, h, w + noise_aug = noise_aug * (x_start[:, 4].unsqueeze(1).expand(-1, 4, -1, -1, -1) == 0) # use mask to zero out augmented noises + x_t = th.cat([x_t[:, :5], noise_aug], dim=1) + terms = {} + # for i in [0,1,2,3,4,5,6,7]: + # print(f'x_t[0,{i},:,0,0] = {x_t[0,i,:,0,0]}') + + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + terms["loss"] = self._vb_terms_bpd( + model=model, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + model_kwargs=model_kwargs, + )["output"] + if self.loss_type == LossType.RESCALED_KL: + terms["loss"] *= self.num_timesteps + elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + # print(f'self.loss_type = {self.loss_type}') # LossType.MSE + # model_output = model(x_t, t, **model_kwargs) + try: + model_output = model(x_t, t, **model_kwargs).sample # for tav unet + except: + model_output = model(x_t, t, **model_kwargs) + + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, F, C = x_t.shape[:3] + assert model_output.shape == (B, F, C * 2, *x_t.shape[3:]) + model_output, model_var_values = th.split(model_output, C, dim=2) + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = th.cat([model_output.detach(), model_var_values], dim=2) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. + terms["vb"] *= self.num_timesteps / 1000.0 + + # print(f'self.model_mean_type = {self.model_mean_type}') # ModelMeanType.EPSILON + target = { + ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + )[0], + ModelMeanType.START_X: x_start, + ModelMeanType.EPSILON: noise, + }[self.model_mean_type] + # assert model_output.shape == target.shape == x_start.shape + # if mask is not None: + # nonzero_idx = th.nonzero(1-mask) + terms["mse"] = mean_flat((target[:,:4] - model_output) ** 2) + # else: + # terms["mse"] = mean_flat((target - model_output) ** 2) + if "vb" in terms: + terms["loss"] = terms["mse"] + terms["vb"] + else: + terms["loss"] = terms["mse"] + else: + raise NotImplementedError(self.loss_type) + + return terms + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl( + mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 + ) + return mean_flat(kl_prior) / np.log(2.0) + + def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): + """ + Compute the entire variational lower-bound, measured in bits-per-dim, + as well as other related quantities. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param clip_denoised: if True, clip denoised samples. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - total_bpd: the total variational lower-bound, per batch element. + - prior_bpd: the prior term in the lower-bound. + - vb: an [N x T] tensor of terms in the lower-bound. + - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. + - mse: an [N x T] tensor of epsilon MSEs for each timestep. + """ + device = x_start.device + batch_size = x_start.shape[0] + + vb = [] + xstart_mse = [] + mse = [] + for t in list(range(self.num_timesteps))[::-1]: + t_batch = th.tensor([t] * batch_size, device=device) + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) + # Calculate VLB term at the current timestep + with th.no_grad(): + out = self._vb_terms_bpd( + model, + x_start=x_start, + x_t=x_t, + t=t_batch, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + ) + vb.append(out["output"]) + xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) + eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) + mse.append(mean_flat((eps - noise) ** 2)) + + vb = th.stack(vb, dim=1) + xstart_mse = th.stack(xstart_mse, dim=1) + mse = th.stack(mse, dim=1) + + prior_bpd = self._prior_bpd(x_start) + total_bpd = vb.sum(dim=1) + prior_bpd + return { + "total_bpd": total_bpd, + "prior_bpd": prior_bpd, + "vb": vb, + "xstart_mse": xstart_mse, + "mse": mse, + } + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. + """ + res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res + th.zeros(broadcast_shape, device=timesteps.device) diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/diffusion/respace.py b/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/diffusion/respace.py new file mode 100644 index 0000000000000000000000000000000000000000..31c0092380367db477f154a39bf172ff64712fc4 --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/diffusion/respace.py @@ -0,0 +1,130 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +import torch +import numpy as np +import torch as th + +from .gaussian_diffusion import GaussianDiffusion + + +def space_timesteps(num_timesteps, section_counts): + """ + Create a list of timesteps to use from an original diffusion process, + given the number of timesteps we want to take from equally-sized portions + of the original process. + For example, if there's 300 timesteps and the section counts are [10,15,20] + then the first 100 timesteps are strided to be 10 timesteps, the second 100 + are strided to be 15 timesteps, and the final 100 are strided to be 20. + If the stride is a string starting with "ddim", then the fixed striding + from the DDIM paper is used, and only one section is allowed. + :param num_timesteps: the number of diffusion steps in the original + process to divide up. + :param section_counts: either a list of numbers, or a string containing + comma-separated numbers, indicating the step count + per section. As a special case, use "ddimN" where N + is a number of steps to use the striding from the + DDIM paper. + :return: a set of diffusion steps from the original process to use. + """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError( + f"cannot create exactly {num_timesteps} steps with an integer stride" + ) + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError( + f"cannot divide section of {size} steps into {section_count}" + ) + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +class SpacedDiffusion(GaussianDiffusion): + """ + A diffusion process which can skip steps in a base diffusion process. + :param use_timesteps: a collection (sequence or set) of timesteps from the + original diffusion process to retain. + :param kwargs: the kwargs to create the base diffusion process. + """ + + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + + base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = np.array(new_betas) + super().__init__(**kwargs) + + def p_mean_variance( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + # @torch.compile + def training_losses( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().training_losses(self._wrap_model(model), *args, **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) + + def _wrap_model(self, model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel( + model, self.timestep_map, self.original_num_steps + ) + + def _scale_timesteps(self, t): + # Scaling is done by the wrapped model. + return t + + +class _WrappedModel: + def __init__(self, model, timestep_map, original_num_steps): + self.model = model + self.timestep_map = timestep_map + # self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def __call__(self, x, ts, **kwargs): + map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) + new_ts = map_tensor[ts] + # if self.rescale_timesteps: + # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) + return self.model(x, new_ts, **kwargs) diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/diffusion/timestep_sampler.py b/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/diffusion/timestep_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..a3f369847677d8dbaaadb8297691b1be92cf189f --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/diffusion/timestep_sampler.py @@ -0,0 +1,150 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +from abc import ABC, abstractmethod + +import numpy as np +import torch as th +import torch.distributed as dist + + +def create_named_schedule_sampler(name, diffusion): + """ + Create a ScheduleSampler from a library of pre-defined samplers. + :param name: the name of the sampler. + :param diffusion: the diffusion object to sample for. + """ + if name == "uniform": + return UniformSampler(diffusion) + elif name == "loss-second-moment": + return LossSecondMomentResampler(diffusion) + else: + raise NotImplementedError(f"unknown schedule sampler: {name}") + + +class ScheduleSampler(ABC): + """ + A distribution over timesteps in the diffusion process, intended to reduce + variance of the objective. + By default, samplers perform unbiased importance sampling, in which the + objective's mean is unchanged. + However, subclasses may override sample() to change how the resampled + terms are reweighted, allowing for actual changes in the objective. + """ + + @abstractmethod + def weights(self): + """ + Get a numpy array of weights, one per diffusion step. + The weights needn't be normalized, but must be positive. + """ + + def sample(self, batch_size, device): + """ + Importance-sample timesteps for a batch. + :param batch_size: the number of timesteps. + :param device: the torch device to save to. + :return: a tuple (timesteps, weights): + - timesteps: a tensor of timestep indices. + - weights: a tensor of weights to scale the resulting losses. + """ + w = self.weights() + p = w / np.sum(w) + indices_np = np.random.choice(len(p), size=(batch_size,), p=p) + indices = th.from_numpy(indices_np).long().to(device) + weights_np = 1 / (len(p) * p[indices_np]) + weights = th.from_numpy(weights_np).float().to(device) + return indices, weights + + +class UniformSampler(ScheduleSampler): + def __init__(self, diffusion): + self.diffusion = diffusion + self._weights = np.ones([diffusion.num_timesteps]) + + def weights(self): + return self._weights + + +class LossAwareSampler(ScheduleSampler): + def update_with_local_losses(self, local_ts, local_losses): + """ + Update the reweighting using losses from a model. + Call this method from each rank with a batch of timesteps and the + corresponding losses for each of those timesteps. + This method will perform synchronization to make sure all of the ranks + maintain the exact same reweighting. + :param local_ts: an integer Tensor of timesteps. + :param local_losses: a 1D Tensor of losses. + """ + batch_sizes = [ + th.tensor([0], dtype=th.int32, device=local_ts.device) + for _ in range(dist.get_world_size()) + ] + dist.all_gather( + batch_sizes, + th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), + ) + + # Pad all_gather batches to be the maximum batch size. + batch_sizes = [x.item() for x in batch_sizes] + max_bs = max(batch_sizes) + + timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] + loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] + dist.all_gather(timestep_batches, local_ts) + dist.all_gather(loss_batches, local_losses) + timesteps = [ + x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] + ] + losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] + self.update_with_all_losses(timesteps, losses) + + @abstractmethod + def update_with_all_losses(self, ts, losses): + """ + Update the reweighting using losses from a model. + Sub-classes should override this method to update the reweighting + using losses from the model. + This method directly updates the reweighting without synchronizing + between workers. It is called by update_with_local_losses from all + ranks with identical arguments. Thus, it should have deterministic + behavior to maintain state across workers. + :param ts: a list of int timesteps. + :param losses: a list of float losses, one per timestep. + """ + + +class LossSecondMomentResampler(LossAwareSampler): + def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): + self.diffusion = diffusion + self.history_per_term = history_per_term + self.uniform_prob = uniform_prob + self._loss_history = np.zeros( + [diffusion.num_timesteps, history_per_term], dtype=np.float64 + ) + self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) + + def weights(self): + if not self._warmed_up(): + return np.ones([self.diffusion.num_timesteps], dtype=np.float64) + weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) + weights /= np.sum(weights) + weights *= 1 - self.uniform_prob + weights += self.uniform_prob / len(weights) + return weights + + def update_with_all_losses(self, ts, losses): + for t, loss in zip(ts, losses): + if self._loss_counts[t] == self.history_per_term: + # Shift out the oldest loss term. + self._loss_history[t, :-1] = self._loss_history[t, 1:] + self._loss_history[t, -1] = loss + else: + self._loss_history[t, self._loss_counts[t]] = loss + self._loss_counts[t] += 1 + + def _warmed_up(self): + return (self._loss_counts == self.history_per_term).all() diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/download.py b/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/download.py new file mode 100644 index 0000000000000000000000000000000000000000..124ec5da032c21879127d6a698bd628420e837f1 --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/download.py @@ -0,0 +1,9 @@ +import os +import torch + +def find_model(model_name): + checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage) + if "ema" in checkpoint: # supports checkpoints from train.py + checkpoint = checkpoint["ema"] + return checkpoint + diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/models/__init__.py b/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..507091b8c274533eca620577bae301541d5a91f6 --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/models/__init__.py @@ -0,0 +1,33 @@ +import os +import sys +sys.path.append(os.path.split(sys.path[0])[0]) + +from .unet import UNet3DConditionModel +from torch.optim.lr_scheduler import LambdaLR + +def customized_lr_scheduler(optimizer, warmup_steps=5000): # 5000 from u-vit + from torch.optim.lr_scheduler import LambdaLR + def fn(step): + if warmup_steps > 0: + return min(step / warmup_steps, 1) + else: + return 1 + return LambdaLR(optimizer, fn) + + +def get_lr_scheduler(optimizer, name, **kwargs): + if name == 'warmup': + return customized_lr_scheduler(optimizer, **kwargs) + elif name == 'cosine': + from torch.optim.lr_scheduler import CosineAnnealingLR + return CosineAnnealingLR(optimizer, **kwargs) + else: + raise NotImplementedError(name) + +def get_models(args, ckpt_path): + + if 'TSR' in args.model: + return UNet3DConditionModel.from_pretrained_2d(ckpt_path, subfolder="unet", use_concat=args.use_concat, copy_no_mask=args.copy_no_mask) + else: + raise '{} Model Not Supported!'.format(args.model) + diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/models/attention.py b/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/models/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..1c05879560bd827c2096e557dcaa17fadc8c1bb0 --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/models/attention.py @@ -0,0 +1,665 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py +import os +import sys +sys.path.append(os.path.split(sys.path[0])[0]) + +from dataclasses import dataclass +from typing import Optional + +import math +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput +from diffusers.utils.import_utils import is_xformers_available +from diffusers.models.attention import FeedForward, AdaLayerNorm + +from einops import rearrange, repeat + +try: + from diffusers.models.modeling_utils import ModelMixin +except: + from diffusers.modeling_utils import ModelMixin # 0.11.1 + + +@dataclass +class Transformer3DModelOutput(BaseOutput): + sample: torch.FloatTensor + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +class CrossAttention(nn.Module): + r""" + copy from diffuser 0.11.1 + A cross attention layer. + Parameters: + query_dim (`int`): The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + use_relative_position: bool = False, + ): + super().__init__() + inner_dim = dim_head * heads + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + + self.scale = dim_head**-0.5 + + self.heads = heads + self.dim_head = dim_head + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + self._slice_size = None + self._use_memory_efficient_attention_xformers = False + self.added_kv_proj_dim = added_kv_proj_dim + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) + else: + self.group_norm = None + + self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(inner_dim, query_dim)) + self.to_out.append(nn.Dropout(dropout)) + + # print(use_relative_position) + self.use_relative_position = use_relative_position + if self.use_relative_position: + # print(dim_head) + # print(heads) + # adopt https://github.com/huggingface/transformers/blob/8a817e1ecac6a420b1bdc701fcc33535a3b96ff5/src/transformers/models/bert/modeling_bert.py#L265 + self.max_position_embeddings = 32 + self.distance_embedding = nn.Embedding(2 * self.max_position_embeddings - 1, dim_head) + + self.dropout = nn.Dropout(dropout) + + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def reshape_for_scores(self, tensor): + # split heads and dims + # tensor should be [b (h w)] f (d nd) + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).contiguous() + return tensor + + def same_batch_dim_to_heads(self, tensor): + batch_size, head_size, seq_len, dim = tensor.shape # [b (h w)] nd f d + tensor = tensor.reshape(batch_size, seq_len, dim * head_size) + return tensor + + def set_attention_slice(self, slice_size): + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + self._slice_size = slice_size + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) # [b (h w)] f (nd * d) + # if self.use_relative_position: + # print('before attention query shape', query.shape) + dim = query.shape[-1] + query = self.reshape_heads_to_batch_dim(query) # [b (h w) nd] f d + # if self.use_relative_position: + # print('before attention query shape', query.shape) + + if self.added_kv_proj_dim is not None: + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) + + key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) + else: + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + + def _attention(self, query, key, value, attention_mask=None): + if self.upcast_attention: + query = query.float() + key = key.float() + + if self.use_relative_position: + query = self.reshape_for_scores(self.reshape_batch_dim_to_heads(query)) + key = self.reshape_for_scores(self.reshape_batch_dim_to_heads(key)) + value = self.reshape_for_scores(self.reshape_batch_dim_to_heads(value)) + + # torch.baddbmm only accepte 3-D tensor + # https://runebook.dev/zh/docs/pytorch/generated/torch.baddbmm + attention_scores = self.scale * torch.matmul(query, key.transpose(-1, -2)) + + # print('attention_scores shape', attention_scores.shape) + + # print(query.shape) # [b (h w)] nd f d + query_length, key_length = query.shape[2], key.shape[2] + # print('query shape', query.shape) + # print('key shape', key.shape) + position_ids_l = torch.arange(query_length, dtype=torch.long, device=query.device).view(-1, 1) # hidden_states.device + position_ids_r = torch.arange(key_length, dtype=torch.long, device=key.device).view(1, -1) # hidden_states.device + distance = position_ids_l - position_ids_r + # print('distance shape', distance.shape) + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility + # print('positional_embedding shape', positional_embedding.shape) + relative_position_scores_query = torch.einsum("bhld, lrd -> bhlr", query, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd, lrd -> bhlr", key, positional_embedding) + # print('relative_position_scores_key shape', relative_position_scores_key.shape) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + # print(attention_scores.shape) + + attention_scores = attention_scores / math.sqrt(self.dim_head) + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # cast back to the original dtype + attention_probs = attention_probs.to(value.dtype) + + # compute attention output + hidden_states = torch.matmul(attention_probs, value) + # print(hidden_states.shape) + hidden_states = self.same_batch_dim_to_heads(hidden_states) + # print(hidden_states.shape) + # exit() + + else: + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + # print(attention_probs.shape) + + # cast back to the original dtype + attention_probs = attention_probs.to(value.dtype) + # print(attention_probs.shape) + + # compute attention output + hidden_states = torch.bmm(attention_probs, value) + # print(hidden_states.shape) + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + # print(hidden_states.shape) + # exit() + return hidden_states + + def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask): + batch_size_attention = query.shape[0] + hidden_states = torch.zeros( + (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype + ) + slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] + for i in range(hidden_states.shape[0] // slice_size): + start_idx = i * slice_size + end_idx = (i + 1) * slice_size + + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + + if self.upcast_attention: + query_slice = query_slice.float() + key_slice = key_slice.float() + + attn_slice = torch.baddbmm( + torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device), + query_slice, + key_slice.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + if attention_mask is not None: + attn_slice = attn_slice + attention_mask[start_idx:end_idx] + + if self.upcast_softmax: + attn_slice = attn_slice.float() + + attn_slice = attn_slice.softmax(dim=-1) + + # cast back to the original dtype + attn_slice = attn_slice.to(value.dtype) + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def _memory_efficient_attention_xformers(self, query, key, value, attention_mask): + # TODO attention_mask + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + +class Transformer3DModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + use_first_frame: bool = False, + use_relative_position: bool = False, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # Define input layers + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + + # Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + if use_linear_projection: + self.proj_out = nn.Linear(in_channels, inner_dim) + else: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): + # Input + assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length) + + batch, channel, height, weight = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = self.proj_in(hidden_states) + + # Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + video_length=video_length + ) + + # Output + if not self.use_linear_projection: + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + + output = hidden_states + residual + + output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) + if not return_dict: + return (output,) + + return Transformer3DModelOutput(sample=output) + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + use_first_frame: bool = False, + use_relative_position: bool = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + # print(only_cross_attention) + self.use_ada_layer_norm = num_embeds_ada_norm is not None + self.use_first_frame = use_first_frame + + # SC-Attn + if use_first_frame: + self.attn1 = SparseCausalAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + # print(cross_attention_dim) + else: + self.attn1 = CrossAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=None, + upcast_attention=upcast_attention, + ) + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + # Cross-Attn + if cross_attention_dim is not None: + self.attn2 = CrossAttention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + else: + self.attn2 = None + + if cross_attention_dim is not None: + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + else: + self.norm2 = None + + # Feed-forward + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) + self.norm3 = nn.LayerNorm(dim) + + # Temp-Attn + self.attn_temp = CrossAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + use_relative_position=use_relative_position + ) + nn.init.zeros_(self.attn_temp.to_out[0].weight.data) + self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, attention_op=None): + if not is_xformers_available(): + print("Here is how to install it") + raise ModuleNotFoundError( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers", + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" + " available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + if self.attn2 is not None: + self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None): + # SparseCausal-Attention + norm_hidden_states = ( + self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) + ) + + if self.only_cross_attention: + hidden_states = ( + self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states + ) + else: + if self.use_first_frame: + hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states + else: + hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states + + if self.attn2 is not None: + # Cross-Attention + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + hidden_states = ( + self.attn2( + norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + ) + + hidden_states + ) + + # Feed-forward + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + + # Temporal-Attention + d = hidden_states.shape[1] + hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) + norm_hidden_states = ( + self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states) + ) + hidden_states = self.attn_temp(norm_hidden_states) + hidden_states + hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) + + return hidden_states + + +class SparseCausalAttention(CrossAttention): + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) + dim = query.shape[-1] + query = self.reshape_heads_to_batch_dim(query) + + if self.added_kv_proj_dim is not None: + raise NotImplementedError + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + former_frame_index = torch.arange(video_length) - 1 + former_frame_index[0] = 0 + + key = rearrange(key, "(b f) d c -> b f d c", f=video_length) + key = torch.cat([key[:, [0] * video_length], key[:, former_frame_index]], dim=2) + key = rearrange(key, "b f d c -> (b f) d c") + + value = rearrange(value, "(b f) d c -> b f d c", f=video_length) + value = torch.cat([value[:, [0] * video_length], value[:, former_frame_index]], dim=2) + value = rearrange(value, "b f d c -> (b f) d c") + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/models/clip.py b/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/models/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..0d77f4d1a725a5d7ed0c4e10a69602e80b91fec1 --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/models/clip.py @@ -0,0 +1,124 @@ +import numpy +import torch.nn as nn +from transformers import CLIPTokenizer, CLIPTextModel + +import transformers +transformers.logging.set_verbosity_error() + +""" +Will encounter following warning: +- This IS expected if you are initializing CLIPTextModel from the checkpoint of a model trained on another task +or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model). +- This IS NOT expected if you are initializing CLIPTextModel from the checkpoint of a model +that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model). + +https://github.com/CompVis/stable-diffusion/issues/97 +according to this issue, this warning is safe. + +This is expected since the vision backbone of the CLIP model is not needed to run Stable Diffusion. +You can safely ignore the warning, it is not an error. + +This clip usage is from U-ViT and same with Stable Diffusion. +""" + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class FrozenCLIPEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from Hugging Face)""" + def __init__(self, sd_path, device="cuda", max_length=77): + super().__init__() + self.tokenizer = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer", use_fast=False) + self.transformer = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder") + self.device = device + self.max_length = max_length + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens) + + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + + +class TextEmbedder(nn.Module): + """ + Embeds text prompt into vector representations. Also handles text dropout for classifier-free guidance. + """ + def __init__(self, args, dropout_prob=0.1): + super().__init__() + self.text_encodder = FrozenCLIPEmbedder(args) + self.dropout_prob = dropout_prob + + def token_drop(self, text_prompts, force_drop_ids=None): + """ + Drops text to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = numpy.random.uniform(0, 1, len(text_prompts)) < self.dropout_prob + else: + # TODO + drop_ids = force_drop_ids == 1 + labels = list(numpy.where(drop_ids, "None", text_prompts)) + # print(labels) + return labels + + def forward(self, text_prompts, train, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + text_prompts = self.token_drop(text_prompts, force_drop_ids) + embeddings = self.text_encodder(text_prompts) + return embeddings + + +if __name__ == '__main__': + + r""" + Returns: + + Examples from CLIPTextModel: + + ```python + >>> from transformers import AutoTokenizer, CLIPTextModel + + >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32") + >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") + + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled (EOS token) states + ```""" + + import torch + + device = "cuda" if torch.cuda.is_available() else "cpu" + + text_encoder = TextEmbedder(dropout_prob=0.00001).to(device) + text_encoder1 = FrozenCLIPEmbedder().to(device) + + text_prompt = ["a photo of a cat", "a photo of a dog", 'a photo of a dog human'] + # text_prompt = ('None', 'None', 'None') + output = text_encoder(text_prompts=text_prompt, train=True) + output1 = text_encoder1(text_prompt) + # print(output) + print(output.shape) + print(output1.shape) + print((output==output1).all()) diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/models/resnet.py b/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/models/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..c1307ad9a0fe26b6b1eca913b024c188df2dd86e --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/models/resnet.py @@ -0,0 +1,212 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py +import os +import sys +sys.path.append(os.path.split(sys.path[0])[0]) + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange + + +class InflatedConv3d(nn.Conv2d): + def forward(self, x): + video_length = x.shape[2] + + x = rearrange(x, "b c f h w -> (b f) c h w") + x = super().forward(x) + x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) + + return x + + +class Upsample3D(nn.Module): + def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + conv = None + if use_conv_transpose: + raise NotImplementedError + elif use_conv: + conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) + + if name == "conv": + self.conv = conv + else: + self.Conv2d_0 = conv + + def forward(self, hidden_states, output_size=None): + assert hidden_states.shape[1] == self.channels + + if self.use_conv_transpose: + raise NotImplementedError + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.float32) + + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + # if `output_size` is passed we force the interpolation output + # size and do not make use of `scale_factor=2` + if output_size is None: + hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") + else: + hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(dtype) + + if self.use_conv: + if self.name == "conv": + hidden_states = self.conv(hidden_states) + else: + hidden_states = self.Conv2d_0(hidden_states) + + return hidden_states + + +class Downsample3D(nn.Module): + def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if use_conv: + conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) + else: + raise NotImplementedError + + if name == "conv": + self.Conv2d_0 = conv + self.conv = conv + elif name == "Conv2d_0": + self.conv = conv + else: + self.conv = conv + + def forward(self, hidden_states): + assert hidden_states.shape[1] == self.channels + if self.use_conv and self.padding == 0: + raise NotImplementedError + + assert hidden_states.shape[1] == self.channels + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class ResnetBlock3D(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + non_linearity="swish", + time_embedding_norm="default", + output_scale_factor=1.0, + use_in_shortcut=None, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.time_embedding_norm = time_embedding_norm + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + + self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if temb_channels is not None: + if self.time_embedding_norm == "default": + time_emb_proj_out_channels = out_channels + elif self.time_embedding_norm == "scale_shift": + time_emb_proj_out_channels = out_channels * 2 + else: + raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") + + self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) + else: + self.time_emb_proj = None + + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + + self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, input_tensor, temb): + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.conv1(hidden_states) + + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] + + if temb is not None and self.time_embedding_norm == "default": + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + + if temb is not None and self.time_embedding_norm == "scale_shift": + scale, shift = torch.chunk(temb, 2, dim=1) + hidden_states = hidden_states * (1 + scale) + shift + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + + +class Mish(torch.nn.Module): + def forward(self, hidden_states): + return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) \ No newline at end of file diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/models/unet.py b/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/models/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..66b161f18f6552a4a9f7838c461c9ed10c26c7bf --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/models/unet.py @@ -0,0 +1,576 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import os +import sys +sys.path.append(os.path.split(sys.path[0])[0]) + +import json + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput, logging +from diffusers.models.embeddings import TimestepEmbedding, Timesteps + +try: + from diffusers.models.modeling_utils import ModelMixin +except: + from diffusers.modeling_utils import ModelMixin # 0.11.1 + +try: + from .unet_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, + ) + from .resnet import InflatedConv3d +except: + from unet_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, + ) + from resnet import InflatedConv3d + + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNet3DConditionOutput(BaseOutput): + sample: torch.FloatTensor + + +class UNet3DConditionModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, # 64 + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + mid_block_type: str = "UNetMidBlock3DCrossAttn", + up_block_types: Tuple[str] = ( + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D" + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: Union[int, Tuple[int]] = 8, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + use_first_frame: bool = False, + use_relative_position: bool = False, + ): + super().__init__() + + # print(use_first_frame) + + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + + # input + self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) + + # time + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + else: + self.class_embedding = None + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + # print(only_cross_attention) + # exit() + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + # print(attention_head_dim) + # exit() + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock3DCrossAttn": + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + ) + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the videos + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + only_cross_attention = list(reversed(only_cross_attention)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) + self.conv_act = nn.SiLU() + self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + + num_slicable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_slicable_layers * [1] + + slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor = None, + class_labels: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet3DConditionOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # time + timesteps = timestep + if not torch.is_tensor(timesteps): + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + # print(emb.shape) # torch.Size([3, 1280]) + # print(class_emb.shape) # torch.Size([3, 1280]) + emb = emb + class_emb + + # pre-process + sample = self.conv_in(sample) + + # down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # mid + sample = self.mid_block( + sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + ) + + # up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + # print(sample.shape) + + if not return_dict: + return (sample,) + sample = UNet3DConditionOutput(sample=sample) + return sample + + def forward_with_cfg(self, + x, + t, + encoder_hidden_states = None, + class_labels: Optional[torch.Tensor] = None, + cfg_scale=4.0): + """ + Forward, but also batches the unconditional forward pass for classifier-free guidance. + """ + # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + half = x[: len(x) // 2] + combined = torch.cat([half, half], dim=0) + model_out = self.forward(combined, t, encoder_hidden_states, class_labels).sample + # For exact reproducibility reasons, we apply classifier-free guidance on only + # three channels by default. The standard approach to cfg applies it to all channels. + # This can be done by uncommenting the following line and commenting-out the line following that. + # eps, rest = model_out[:, :4], model_out[:, 4:] + eps, rest = model_out[:, :4], model_out[:, 4:] # b c f h w + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + return torch.cat([eps, rest], dim=1) + + @classmethod + def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, use_concat=False, copy_no_mask=False): + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + + + config_file = os.path.join(pretrained_model_path, 'config.json') + if not os.path.isfile(config_file): + raise RuntimeError(f"{config_file} does not exist") + with open(config_file, "r") as f: + config = json.load(f) + config["_class_name"] = cls.__name__ + config["down_block_types"] = [ + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D" + ] + config["up_block_types"] = [ + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D" + ] + + config["use_first_frame"] = True + + if copy_no_mask: + config["in_channels"] = 8 + else: + if use_concat: + config["in_channels"] = 9 + + + from diffusers.utils import WEIGHTS_NAME # diffusion_pytorch_model.bin + + + model = cls.from_config(config) + model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) + if not os.path.isfile(model_file): + raise RuntimeError(f"{model_file} does not exist") + state_dict = torch.load(model_file, map_location="cpu") + + + if use_concat: + new_state_dict = {} + conv_in_weight = state_dict["conv_in.weight"] + + print(f'from_pretrained_2d copy_no_mask = {copy_no_mask}') + if copy_no_mask: + new_conv_in_channel = 8 + new_conv_in_list = [0, 1, 2, 3, 4, 5, 6, 7] + else: + new_conv_in_channel = 9 + new_conv_in_list = [0, 1, 2, 3, 4, 5, 6, 7, 8] + new_conv_weight = torch.zeros((conv_in_weight.shape[0], new_conv_in_channel, *conv_in_weight.shape[2:]), dtype=conv_in_weight.dtype) + + for i, j in zip([0, 1, 2, 3], new_conv_in_list): + new_conv_weight[:, j] = conv_in_weight[:, i] + new_state_dict["conv_in.weight"] = new_conv_weight + new_state_dict["conv_in.bias"] = state_dict["conv_in.bias"] + for k, v in model.state_dict().items(): + # print(k) + if '_temp.' in k: + new_state_dict.update({k: v}) + elif 'conv_in' in k: + continue + else: + new_state_dict[k] = v + # # tmp + # if 'class_embedding' in k: + # state_dict.update({k: v}) + # breakpoint() + model.load_state_dict(new_state_dict) + else: + for k, v in model.state_dict().items(): + # print(k) + if '_temp.' in k: + state_dict.update({k: v}) + model.load_state_dict(state_dict) + return model + +if __name__ == '__main__': + import torch + + device = "cuda" if torch.cuda.is_available() else "cpu" + + pretrained_model_path = "/nvme/maxin/work/large-dit-video/pretrained/stable-diffusion-v1-4/" # 43 + unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet").to(device) + + noisy_latents = torch.randn((3, 4, 16, 32, 32)).to(device) + bsz = noisy_latents.shape[0] + timesteps = torch.randint(0, 1000, (bsz,)).to(device) + timesteps = timesteps.long() + encoder_hidden_states = torch.randn((bsz, 77, 768)).to(device) + class_labels = torch.randn((bsz, )).to(device) + + model_pred = unet(sample=noisy_latents, timestep=timesteps, + encoder_hidden_states=encoder_hidden_states, + class_labels=class_labels).sample + print(model_pred.shape) \ No newline at end of file diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/models/unet_blocks.py b/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/models/unet_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..0bfed25db240043dcc42b7e4459f5ca52d3cd902 --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/models/unet_blocks.py @@ -0,0 +1,619 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py +import os +import sys +sys.path.append(os.path.split(sys.path[0])[0]) + +import torch +from torch import nn + +try: + from .attention import Transformer3DModel + from .resnet import Downsample3D, ResnetBlock3D, Upsample3D +except: + from attention import Transformer3DModel + from resnet import Downsample3D, ResnetBlock3D, Upsample3D + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + use_first_frame=False, + use_relative_position=False, +): + # print(down_block_type) + # print(use_first_frame) + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlock3D": + return DownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "CrossAttnDownBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") + return CrossAttnDownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + use_first_frame=False, + use_relative_position=False, +): + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlock3D": + return UpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "CrossAttnUpBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") + return CrossAttnUpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock3DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + use_first_frame=False, + use_relative_position=False, + ): + super().__init__() + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(num_layers): + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + ) + ) + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class CrossAttnDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + use_first_frame=False, + use_relative_position=False, + ): + super().__init__() + resnets = [] + attentions = [] + + # print(use_first_frame) + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + use_first_frame=False, + use_relative_position=False, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + attention_mask=None, + ): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/models/utils.py b/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b6cfac1ff6c2c99c87920372482251dc2b2fce34 --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/models/utils.py @@ -0,0 +1,215 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import os +import math +import torch + +import numpy as np +import torch.nn as nn + +from einops import repeat + + +################################################################################# +# Unet Utils # +################################################################################# + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim).contiguous() + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +# class HybridConditioner(nn.Module): + +# def __init__(self, c_concat_config, c_crossattn_config): +# super().__init__() +# self.concat_conditioner = instantiate_from_config(c_concat_config) +# self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + +# def forward(self, c_concat, c_crossattn): +# c_concat = self.concat_conioner(c_concat) +# c_crossattn = self.crossattn_conditioner(c_crossattn) +# return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += torch.DoubleTensor([matmul_ops]) + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params \ No newline at end of file diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/sample.py b/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/sample.py new file mode 100644 index 0000000000000000000000000000000000000000..ce8ece5acab34e64bb1f5430c494feed25e90401 --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/sample.py @@ -0,0 +1,309 @@ +""" +we introduce a temporal interpolation network to enhance the smoothness of generated videos and synthesize richer temporal details. +This network takes a 16-frame base video as input and produces an upsampled output consisting of 61 frames. +""" + +import os +import sys +import math +try: + import utils + + from diffusion import create_diffusion + from download import find_model +except: + sys.path.append(os.path.split(sys.path[0])[0]) + + import utils + + from diffusion import create_diffusion + from download import find_model + +import torch +import argparse +import torchvision + +from einops import rearrange +from models import get_models +from torchvision.utils import save_image +from diffusers.models import AutoencoderKL +from models.clip import TextEmbedder +from omegaconf import OmegaConf +from PIL import Image +import numpy as np +from torchvision import transforms +sys.path.append("..") +from datasets import video_transforms +from decord import VideoReader +from utils import mask_generation, mask_generation_before +from natsort import natsorted +from diffusers.utils.import_utils import is_xformers_available + +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True + + +def get_input(args): + input_path = args.input_path + transform_video = transforms.Compose([ + video_transforms.ToTensorVideo(), # TCHW + video_transforms.ResizeVideo((args.image_h, args.image_w)), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) + ]) + temporal_sample_func = video_transforms.TemporalRandomCrop(args.num_frames * args.frame_interval) + if input_path is not None: + print(f'loading video from {input_path}') + if os.path.isdir(input_path): + file_list = os.listdir(input_path) + video_frames = [] + for file in file_list: + if file.endswith('jpg') or file.endswith('png'): + image = torch.as_tensor(np.array(Image.open(file), dtype=np.uint8, copy=True)).unsqueeze(0) + video_frames.append(image) + else: + continue + n = 0 + video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w + video_frames = transform_video(video_frames) + return video_frames, n + elif os.path.isfile(input_path): + _, full_file_name = os.path.split(input_path) + file_name, extention = os.path.splitext(full_file_name) + if extention == '.mp4': + video_reader = VideoReader(input_path) + total_frames = len(video_reader) + start_frame_ind, end_frame_ind = temporal_sample_func(total_frames) + frame_indice = np.linspace(start_frame_ind, end_frame_ind-1, args.num_frames, dtype=int) + video_frames = torch.from_numpy(video_reader.get_batch(frame_indice).asnumpy()).permute(0, 3, 1, 2).contiguous() + video_frames = transform_video(video_frames) + n = 3 + del video_reader + return video_frames, n + else: + raise TypeError(f'{extention} is not supported !!') + else: + raise ValueError('Please check your path input!!') + else: + print('given video is None, using text to video') + video_frames = torch.zeros(16,3,args.latent_h,args.latent_w,dtype=torch.uint8) + args.mask_type = 'all' + video_frames = transform_video(video_frames) + n = 0 + return video_frames, n + + +def auto_inpainting(args, video_input, masked_video, mask, prompt, vae, text_encoder, diffusion, model, device,): + + + b,f,c,h,w=video_input.shape + latent_h = args.image_size[0] // 8 + latent_w = args.image_size[1] // 8 + + z = torch.randn(1, 4, args.num_frames, args.latent_h, args.latent_w, device=device) # b,c,f,h,w + + masked_video = rearrange(masked_video, 'b f c h w -> (b f) c h w').contiguous() + masked_video = vae.encode(masked_video).latent_dist.sample().mul_(0.18215) + masked_video = rearrange(masked_video, '(b f) c h w -> b c f h w', b=b).contiguous() + mask = torch.nn.functional.interpolate(mask[:,:,0,:], size=(latent_h, latent_w)).unsqueeze(1) + + + masked_video = torch.cat([masked_video] * 2) if args.do_classifier_free_guidance else masked_video + mask = torch.cat([mask] * 2) if args.do_classifier_free_guidance else mask + z = torch.cat([z] * 2) if args.do_classifier_free_guidance else z + + prompt_all = [prompt] + [args.negative_prompt] if args.do_classifier_free_guidance else [prompt] + text_prompt = text_encoder(text_prompts=prompt_all, train=False) + model_kwargs = dict(encoder_hidden_states=text_prompt, class_labels=None) + + if args.use_ddim_sample_loop: + samples = diffusion.ddim_sample_loop( + model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, \ + progress=True, device=device, mask=mask, x_start=masked_video, use_concat=args.use_concat + ) + else: + samples = diffusion.p_sample_loop( + model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, \ + progress=True, device=device, mask=mask, x_start=masked_video, use_concat=args.use_concat + ) # torch.Size([2, 4, 16, 32, 32]) + samples, _ = samples.chunk(2, dim=0) # [1, 4, 16, 32, 32] + + video_clip = samples[0].permute(1, 0, 2, 3).contiguous() # [16, 4, 32, 32] + video_clip = vae.decode(video_clip / 0.18215).sample # [16, 3, 256, 256] + return video_clip + + +def auto_inpainting_copy_no_mask(args, video_input, prompt, vae, text_encoder, diffusion, model, device,): + + b,f,c,h,w=video_input.shape + latent_h = args.image_size[0] // 8 + latent_w = args.image_size[1] // 8 + + video_input = rearrange(video_input, 'b f c h w -> (b f) c h w').contiguous() + video_input = vae.encode(video_input).latent_dist.sample().mul_(0.18215) + video_input = rearrange(video_input, '(b f) c h w -> b c f h w', b=b).contiguous() + + lr_indice = torch.IntTensor([i for i in range(0,62,4)]).to(device) + copied_video = torch.index_select(video_input, 2, lr_indice) + copied_video = torch.repeat_interleave(copied_video, 4, dim=2) + copied_video = copied_video[:,:,1:-2,:,:] + copied_video = torch.cat([copied_video] * 2) if args.do_classifier_free_guidance else copied_video + + torch.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + z = torch.randn(1, 4, args.num_frames, args.latent_h, args.latent_w, device=device) # b,c,f,h,w + z = torch.cat([z] * 2) if args.do_classifier_free_guidance else z + + prompt_all = [prompt] + [args.negative_prompt] if args.do_classifier_free_guidance else [prompt] + text_prompt = text_encoder(text_prompts=prompt_all, train=False) + model_kwargs = dict(encoder_hidden_states=text_prompt, class_labels=None) + + torch.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + if args.use_ddim_sample_loop: + samples = diffusion.ddim_sample_loop( + model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, \ + progress=True, device=device, mask=None, x_start=copied_video, use_concat=args.use_concat, copy_no_mask=args.copy_no_mask, + ) + else: + raise ValueError(f'We only have ddim sampling implementation for now') + + samples, _ = samples.chunk(2, dim=0) # [1, 4, 16, 32, 32] + + video_clip = samples[0].permute(1, 0, 2, 3).contiguous() # [16, 4, 32, 32] + video_clip = vae.decode(video_clip / 0.18215).sample # [16, 3, 256, 256] + return video_clip + + + +def main(args): + + for seed in args.seed_list: + + args.seed = seed + torch.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + # print(f'torch.seed() = {torch.seed()}') + + print('sampling begins') + torch.set_grad_enabled(False) + device = "cuda" if torch.cuda.is_available() else "cpu" + # device = "cpu" + + ckpt_path = args.ckpt_path + sd_path = args.pretrained_path + "/stable-diffusion-v1-4" + for ckpt in [ckpt_path]: + + ckpt_num = str(ckpt_path).zfill(7) + + # Load model: + latent_h = args.image_size[0] // 8 + latent_w = args.image_size[1] // 8 + args.image_h = args.image_size[0] + args.image_w = args.image_size[1] + args.latent_h = latent_h + args.latent_w = latent_w + print(f'args.copy_no_mask = {args.copy_no_mask}') + model = get_models(args, sd_path).to(device) + + if args.use_compile: + model = torch.compile(model) + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + model.enable_xformers_memory_efficient_attention() + # model.enable_vae_slicing() # ziqi added + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + # Auto-download a pre-trained model or load a custom checkpoint from train.py: + print(f'loading model from {ckpt_path}') + + # load ckpt + state_dict = find_model(ckpt_path) + print('loading succeed') + model.load_state_dict(state_dict) + + torch.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + + model.eval() # important! + diffusion = create_diffusion(str(args.num_sampling_steps)) + vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae").to(device) + text_encoder = TextEmbedder(sd_path).to(device) + + video_list = os.listdir(args.input_folder) + args.input_path_list = [os.path.join(args.input_folder, video) for video in video_list] + for input_path in args.input_path_list: + + args.input_path = input_path + + print(f'=======================================') + if not args.input_path.endswith('.mp4'): + print(f'Skipping {args.input_path}') + continue + + print(f'args.input_path = {args.input_path}') + + torch.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + + # Labels to condition the model with (feel free to change): + video_name = args.input_path.split('/')[-1].split('.mp4')[0] + args.prompt = [video_name] + print(f'args.prompt = {args.prompt}') + prompts = args.prompt + class_name = [p + args.additional_prompt for p in prompts] + + if not os.path.exists(os.path.join(args.output_folder)): + os.makedirs(os.path.join(args.output_folder)) + video_input, researve_frames = get_input(args) # f,c,h,w + video_input = video_input.to(device).unsqueeze(0) # b,f,c,h,w + if args.copy_no_mask: + pass + else: + mask = mask_generation_before(args.mask_type, video_input.shape, video_input.dtype, device) # b,f,c,h,w + + if args.copy_no_mask: + pass + else: + if args.mask_type == 'tsr': + masked_video = video_input * (mask == 0) + else: + masked_video = video_input * (mask == 0) + + all_video = [] + if researve_frames != 0: + all_video.append(video_input) + for idx, prompt in enumerate(class_name): + if idx == 0: + if args.copy_no_mask: + video_clip = auto_inpainting_copy_no_mask(args, video_input, prompt, vae, text_encoder, diffusion, model, device,) + video_clip_ = video_clip.unsqueeze(0) + all_video.append(video_clip_) + else: + video_clip = auto_inpainting(args, video_input, masked_video, mask, prompt, vae, text_encoder, diffusion, model, device,) + video_clip_ = video_clip.unsqueeze(0) + all_video.append(video_clip_) + else: + raise NotImplementedError + masked_video = video_input * (mask == 0) + video_clip = auto_inpainting_copy_no_mask(args, video_clip.unsqueeze(0), masked_video, mask, prompt, vae, text_encoder, diffusion, model, device,) + video_clip_ = video_clip.unsqueeze(0) + all_video.append(video_clip_[:, 3:]) + video_ = ((video_clip * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1) + for fps in args.fps_list: + save_path = args.output_folder + if not os.path.exists(os.path.join(save_path)): + os.makedirs(os.path.join(save_path)) + local_save_path = os.path.join(save_path, f'{video_name}.mp4') + print(f'save in {local_save_path}') + torchvision.io.write_video(local_save_path, video_, fps=fps) + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, required=True) + args = parser.parse_args() + main(**OmegaConf.load(args.config)) + + diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/utils.py b/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3219c0af49cb1bd039abe6899aab0ec85a4a4b1e --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/interpolation/utils.py @@ -0,0 +1,371 @@ +import os +import math +import torch +import logging +import subprocess +import numpy as np +import torch.distributed as dist + +# from torch._six import inf +from torch import inf +from PIL import Image +from typing import Union, Iterable +from collections import OrderedDict + + +_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]] + +################################################################################# +# Training Helper Functions # +################################################################################# + +################################################################################# +# Training Clip Gradients # +################################################################################# + +def get_grad_norm( + parameters: _tensor_or_tensors, norm_type: float = 2.0) -> torch.Tensor: + r""" + Copy from torch.nn.utils.clip_grad_norm_ + + Clips gradient norm of an iterable of parameters. + + The norm is computed over all gradients together, as if they were + concatenated into a single vector. Gradients are modified in-place. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + error_if_nonfinite (bool): if True, an error is thrown if the total + norm of the gradients from :attr:`parameters` is ``nan``, + ``inf``, or ``-inf``. Default: False (will switch to True in the future) + + Returns: + Total norm of the parameter gradients (viewed as a single vector). + """ + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + grads = [p.grad for p in parameters if p.grad is not None] + norm_type = float(norm_type) + if len(grads) == 0: + return torch.tensor(0.) + device = grads[0].device + if norm_type == inf: + norms = [g.detach().abs().max().to(device) for g in grads] + total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms)) + else: + total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type) + return total_norm + +def clip_grad_norm_( + parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0, + error_if_nonfinite: bool = False, clip_grad = True) -> torch.Tensor: + r""" + Copy from torch.nn.utils.clip_grad_norm_ + + Clips gradient norm of an iterable of parameters. + + The norm is computed over all gradients together, as if they were + concatenated into a single vector. Gradients are modified in-place. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + error_if_nonfinite (bool): if True, an error is thrown if the total + norm of the gradients from :attr:`parameters` is ``nan``, + ``inf``, or ``-inf``. Default: False (will switch to True in the future) + + Returns: + Total norm of the parameter gradients (viewed as a single vector). + """ + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + grads = [p.grad for p in parameters if p.grad is not None] + max_norm = float(max_norm) + norm_type = float(norm_type) + if len(grads) == 0: + return torch.tensor(0.) + device = grads[0].device + if norm_type == inf: + norms = [g.detach().abs().max().to(device) for g in grads] + total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms)) + else: + total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type) + # print(total_norm) + + if clip_grad: + if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()): + raise RuntimeError( + f'The total norm of order {norm_type} for gradients from ' + '`parameters` is non-finite, so it cannot be clipped. To disable ' + 'this error and scale the gradients by the non-finite norm anyway, ' + 'set `error_if_nonfinite=False`') + clip_coef = max_norm / (total_norm + 1e-6) + # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so + # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization + # when the gradients do not reside in CPU memory. + clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + for g in grads: + g.detach().mul_(clip_coef_clamped.to(g.device)) + # gradient_cliped = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type) + # print(gradient_cliped) + return total_norm + +################################################################################# +# Training Logger # +################################################################################# + +def create_logger(logging_dir): + """ + Create a logger that writes to a log file and stdout. + """ + if dist.get_rank() == 0: # real logger + logging.basicConfig( + level=logging.INFO, + # format='[\033[34m%(asctime)s\033[0m] %(message)s', + format='[%(asctime)s] %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")] + ) + logger = logging.getLogger(__name__) + + else: # dummy logger (does nothing) + logger = logging.getLogger(__name__) + logger.addHandler(logging.NullHandler()) + return logger + +def create_accelerate_logger(logging_dir, is_main_process=False): + """ + Create a logger that writes to a log file and stdout. + """ + if is_main_process: # real logger + logging.basicConfig( + level=logging.INFO, + # format='[\033[34m%(asctime)s\033[0m] %(message)s', + format='[%(asctime)s] %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")] + ) + logger = logging.getLogger(__name__) + else: # dummy logger (does nothing) + logger = logging.getLogger(__name__) + logger.addHandler(logging.NullHandler()) + return logger + + +def create_tensorboard(tensorboard_dir): + """ + Create a tensorboard that saves losses. + """ + if dist.get_rank() == 0: # real tensorboard + # tensorboard + writer = SummaryWriter(tensorboard_dir) + + return writer + +def write_tensorboard(writer, *args): + ''' + write the loss information to a tensorboard file. + Only for pytorch DDP mode. + ''' + if dist.get_rank() == 0: # real tensorboard + writer.add_scalar(args[0], args[1], args[2]) + +################################################################################# +# EMA Update/ DDP Training Utils # +################################################################################# + +@torch.no_grad() +def update_ema(ema_model, model, decay=0.9999): + """ + Step the EMA model towards the current model. + """ + ema_params = OrderedDict(ema_model.named_parameters()) + model_params = OrderedDict(model.named_parameters()) + + for name, param in model_params.items(): + # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed + ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) + +def requires_grad(model, flag=True): + """ + Set requires_grad flag for all parameters in a model. + """ + for p in model.parameters(): + p.requires_grad = flag + +def cleanup(): + """ + End DDP training. + """ + dist.destroy_process_group() + + +def setup_distributed(backend="nccl", port=None): + """Initialize distributed training environment. + support both slurm and torch.distributed.launch + see torch.distributed.init_process_group() for more details + """ + num_gpus = torch.cuda.device_count() + + print(f'Hahahahahaha') + if "SLURM_JOB_ID" in os.environ: + rank = int(os.environ["SLURM_PROCID"]) + world_size = int(os.environ["SLURM_NTASKS"]) + node_list = os.environ["SLURM_NODELIST"] + addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1") + # specify master port + if port is not None: + os.environ["MASTER_PORT"] = str(port) + elif "MASTER_PORT" not in os.environ: + # os.environ["MASTER_PORT"] = "29566" + os.environ["MASTER_PORT"] = str(29566 + num_gpus) + if "MASTER_ADDR" not in os.environ: + os.environ["MASTER_ADDR"] = addr + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["LOCAL_RANK"] = str(rank % num_gpus) + os.environ["RANK"] = str(rank) + else: + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + # torch.cuda.set_device(rank % num_gpus) + + print(f'before dist.init_process_group') + + dist.init_process_group( + backend=backend, + world_size=world_size, + rank=rank, + ) + print(f'after dist.init_process_group') + +################################################################################# +# Testing Utils # +################################################################################# + +def save_video_grid(video, nrow=None): + b, t, h, w, c = video.shape + + if nrow is None: + nrow = math.ceil(math.sqrt(b)) + ncol = math.ceil(b / nrow) + padding = 1 + video_grid = torch.zeros((t, (padding + h) * nrow + padding, + (padding + w) * ncol + padding, c), dtype=torch.uint8) + + print(video_grid.shape) + for i in range(b): + r = i // ncol + c = i % ncol + start_r = (padding + h) * r + start_c = (padding + w) * c + video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i] + + return video_grid + + +################################################################################# +# MMCV Utils # +################################################################################# + + +def collect_env(): + # Copyright (c) OpenMMLab. All rights reserved. + from mmcv.utils import collect_env as collect_base_env + from mmcv.utils import get_git_hash + """Collect the information of the running environments.""" + + env_info = collect_base_env() + env_info['MMClassification'] = get_git_hash()[:7] + + for name, val in env_info.items(): + print(f'{name}: {val}') + + print(torch.cuda.get_arch_list()) + print(torch.version.cuda) + +################################################################################# +# Long video generation Utils # +################################################################################# + +def mask_generation(mask_type, shape, dtype, device): + b, c, f, h, w = shape + if mask_type.startswith('random'): + num = float(mask_type.split('random')[-1]) + mask_f = torch.ones(1, 1, f, 1, 1, dtype=dtype, device=device) + indices = torch.randperm(f, device=device)[:int(f*num)] + mask_f[0, 0, indices, :, :] = 0 + mask = mask_f.expand(b, c, -1, h, w) + elif mask_type.startswith('first'): + num = int(mask_type.split('first')[-1]) + mask_f = torch.cat([torch.zeros(1, 1, num, 1, 1, dtype=dtype, device=device), + torch.ones(1, 1, f-num, 1, 1, dtype=dtype, device=device)], dim=2) + mask = mask_f.expand(b, c, -1, h, w) + else: + raise ValueError(f"Invalid mask type: {mask_type}") + return mask + + + +def mask_generation_before(mask_type, shape, dtype, device): + b, f, c, h, w = shape + if mask_type.startswith('random'): + num = float(mask_type.split('random')[-1]) + mask_f = torch.ones(1, f, 1, 1, 1, dtype=dtype, device=device) + indices = torch.randperm(f, device=device)[:int(f*num)] + mask_f[0, indices, :, :, :] = 0 + mask = mask_f.expand(b, -1, c, h, w) + elif mask_type.startswith('first'): + num = int(mask_type.split('first')[-1]) + mask_f = torch.cat([torch.zeros(1, num, 1, 1, 1, dtype=dtype, device=device), + torch.ones(1, f-num, 1, 1, 1, dtype=dtype, device=device)], dim=1) + mask = mask_f.expand(b, -1, c, h, w) + elif mask_type.startswith('uniform'): + p = float(mask_type.split('uniform')[-1]) + mask_f = torch.ones(1, f, 1, 1, 1, dtype=dtype, device=device) + mask_f[0, torch.rand(f, device=device) < p, :, :, :] = 0 + print(f'mask_f: = {mask_f}') + mask = mask_f.expand(b, -1, c, h, w) + print(f'mask.shape: = {mask.shape}, mask: = {mask}') + elif mask_type.startswith('all'): + mask = torch.ones(b,f,c,h,w,dtype=dtype,device=device) + elif mask_type.startswith('onelast'): + num = int(mask_type.split('onelast')[-1]) + mask_one = torch.zeros(1,1,1,1,1, dtype=dtype, device=device) + mask_mid = torch.ones(1,f-2*num,1,1,1,dtype=dtype, device=device) + mask_last = torch.zeros_like(mask_one) + mask = torch.cat([mask_one]*num + [mask_mid] + [mask_last]*num, dim=1) + # breakpoint() + mask = mask.expand(b, -1, c, h, w) + elif mask_type.startswith('interpolate'): + mask_f = [] + for i in range(4): + mask_zero = torch.zeros(1,1,1,1,1, dtype=dtype, device=device) + mask_f.append(mask_zero) + mask_one = torch.ones(1,3,1,1,1, dtype=dtype, device=device) + mask_f.append(mask_one) + mask = torch.cat(mask_f, dim=1) + print(f'mask={mask}') + elif mask_type.startswith('tsr'): + mask_f = [] + mask_zero = torch.zeros(1,1,1,1,1, dtype=dtype, device=device) + mask_one = torch.ones(1,3,1,1,1, dtype=dtype, device=device) + for i in range(15): + mask_f.append(mask_zero) # not masked + mask_f.append(mask_one) # masked + mask_f.append(mask_zero) # not masked + mask = torch.cat(mask_f, dim=1) + # print(f'before mask.shape = {mask.shape}, mask = {mask}') # [1, 61, 1, 1, 1] + mask = mask.expand(b, -1, c, h, w) + # print(f'after mask.shape = {mask.shape}, mask = {mask}') # [4, 61, 3, 256, 256] + else: + raise ValueError(f"Invalid mask type: {mask_type}") + + return mask diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/vsr/__init__.py b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/vsr/configs/__init__.py b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/vsr/configs/sample.yaml b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/configs/sample.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c012c24de570b1d6324a21b952d308be8f05c18d --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/configs/sample.yaml @@ -0,0 +1,7 @@ +ckpt_path: "../pretrained_models/lavie_vsr.pt" +pretrained_path: "../pretrained_models" +input_path: "../res/base" +output_path: "../res/vsr" +noise_level: 50 +guidance_scale: 5 +inference_steps: 50 diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/vsr/configs/unet_3d_config.json b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/configs/unet_3d_config.json new file mode 100644 index 0000000000000000000000000000000000000000..ccebb1884d60b9f309b0b3fd45f9d284034e1728 --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/configs/unet_3d_config.json @@ -0,0 +1,66 @@ +{ + "_class_name": "UNet3DVSRModel", + "_diffusers_version": "0.9.0.dev0", + "_name_or_path": "hf-models/stable-diffusion-x4-upscaler/unet", + "act_fn": "silu", + "attention_head_dim": 8, + "block_out_channels": [ + 256, + 512, + 512, + 1024 + ], + "center_input_sample": false, + "cross_attention_dim": 1024, + "down_block_types": [ + "DownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D" + ], + "downsample_padding": 1, + "dual_cross_attention": false, + "flip_sin_to_cos": true, + "freq_shift": 0, + "in_channels": 7, + "layers_per_block": 2, + "mid_block_scale_factor": 1, + "norm_eps": 1e-05, + "norm_num_groups": 32, + "num_class_embeds": 1000, + "only_cross_attention": [ + true, + true, + true, + false + ], + "out_channels": 4, + "sample_size": 128, + "up_block_types": [ + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "UpBlock3D" + ], + "use_linear_projection": true, + + "down_temporal_idx": [0, 1, 2, 3], + "mid_temporal": true, + "up_temporal_idx": [0, 1, 2, 3], + "temporal_module_config": { + "num_attention_layers": 1, + "attention_block_types": [ + "", + "" + ], + "cross_frame_attention_mode": "0_i-1_i", + "temporal_shift_fold_div": 2, + "temporal_shift_direction": "right", + "use_dcn_warpping": false, + "use_deformable_conv": true, + "attention_dim_div": 2 + }, + "use_first_frame": false, + "video_condition": false, + "freeze_pretrained_2d_upsampler": true +} diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/vsr/configs/vae_config.json b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/configs/vae_config.json new file mode 100644 index 0000000000000000000000000000000000000000..3dca8898737c4bbb92ccd502b080507483ad1728 --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/configs/vae_config.json @@ -0,0 +1,28 @@ +{ + "_class_name": "AutoencoderKL", + "_diffusers_version": "0.9.0.dev0", + "_name_or_path": "hf-models/stable-diffusion-x4-upscaler/vae", + "act_fn": "silu", + "block_out_channels": [ + 128, + 256, + 512 + ], + "down_block_types": [ + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D" + ], + "in_channels": 3, + "latent_channels": 4, + "layers_per_block": 2, + "norm_num_groups": 32, + "out_channels": 3, + "sample_size": 256, + "up_block_types": [ + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D" + ], + "scaling_factor": 0.08333 +} diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/vsr/diffusion/__init__.py b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d171551ccbeff3f7d38a6e89624815f7ebd1e4db --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/diffusion/__init__.py @@ -0,0 +1,54 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +from . import gaussian_diffusion as gd +from .respace import SpacedDiffusion, space_timesteps + + +# !important +def create_diffusion( + timestep_respacing="", + noise_schedule="linear", # 'linear' for training + use_kl=False, + rescale_learned_sigmas=False, + prediction_type='v_prediction', + variance_type='fixed_small', + beta_start=0.0001, + beta_end=0.02, + # beta_start=0.00085, + # beta_end=0.012, + diffusion_steps=1000 +): + betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps, beta_start=beta_start, beta_end=beta_end) + if prediction_type == 'epsilon': + model_mean_type = gd.ModelMeanType.EPSILON # EPSILON type for stable-diffusion-2-1 512 + elif prediction_type == 'xstart': + model_mean_type = gd.ModelMeanType.START_X + elif prediction_type == 'v_prediction': + model_mean_type = gd.ModelMeanType.PREVIOUS_V # gd.ModelMeanType.PREVIOUS_V for stable-diffusion-2-1 768/x4-upscaler + + if variance_type == 'fixed_small': + model_var_type = gd.ModelVarType.FIXED_SMALL + elif variance_type == 'fixed_large': + model_var_type = gd.ModelVarType.FIXED_LARGE + elif variance_type == 'learned_range': + model_var_type = gd.ModelVarType.LEARNED_RANGE + + if use_kl: + loss_type = gd.LossType.RESCALED_KL + elif rescale_learned_sigmas: + loss_type = gd.LossType.RESCALED_MSE + else: + loss_type = gd.LossType.MSE + if timestep_respacing is None or timestep_respacing == "": + timestep_respacing = [diffusion_steps] + return SpacedDiffusion( + use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), + betas=betas, + model_mean_type=(model_mean_type), + model_var_type=(model_var_type), + loss_type=loss_type + # rescale_timesteps=rescale_timesteps, + ) diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/vsr/diffusion/diffusion_utils.py b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/diffusion/diffusion_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e493a6a3ecb91e553a53cc7eadee5cc0d1753060 --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/diffusion/diffusion_utils.py @@ -0,0 +1,88 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +import torch as th +import numpy as np + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, th.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for th.exp(). + logvar1, logvar2 = [ + x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + th.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * th.exp(-logvar2) + ) + + +def approx_standard_normal_cdf(x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. + """ + return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) + + +def continuous_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a continuous Gaussian distribution. + :param x: the targets + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + centered_x = x - means + inv_stdv = th.exp(-log_scales) + normalized_x = centered_x * inv_stdv + log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) + return log_probs + + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. + :param x: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + assert x.shape == means.shape == log_scales.shape + centered_x = x - means + inv_stdv = th.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = approx_standard_normal_cdf(min_in) + log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = th.where( + x < -0.999, + log_cdf_plus, + th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == x.shape + return log_probs diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/vsr/diffusion/gaussian_diffusion.py b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/diffusion/gaussian_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..be197ae8b899ecb5e763226b8c56de335fa123d6 --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/diffusion/gaussian_diffusion.py @@ -0,0 +1,923 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + + +import math + +import numpy as np +import torch as th +import enum + +from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +class ModelMeanType(enum.Enum): + """ + Which type of output the model predicts. + """ + + PREVIOUS_V = enum.auto() # v-parameterization for VSR; (see section 2.4 https://imagen.research.google/video/paper.pdf) + PREVIOUS_X = enum.auto() # the model predicts x_{t-1} + START_X = enum.auto() # the model predicts x_0 + EPSILON = enum.auto() # the model predicts epsilon + + +class ModelVarType(enum.Enum): + """ + What is used as the model's output variance. + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. + """ + + LEARNED = enum.auto() + FIXED_SMALL = enum.auto() + FIXED_LARGE = enum.auto() + LEARNED_RANGE = enum.auto() + + +class LossType(enum.Enum): + MSE = enum.auto() # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = ( + enum.auto() + ) # use raw MSE loss (with RESCALED_KL when learning variances) + KL = enum.auto() # use the variational lower-bound + RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + + +def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + warmup_time = int(num_diffusion_timesteps * warmup_frac) + betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64) + return betas + + +def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): + """ + This is the deprecated API for creating beta schedules. + See get_named_beta_schedule() for the new library of schedules. + """ + if beta_schedule == "scaled_linear": + betas = (np.linspace(beta_start**0.5, beta_end**0.5, num_diffusion_timesteps, dtype=np.float64)** 2) + elif beta_schedule == "linear": + betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "warmup10": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) + elif beta_schedule == "warmup50": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) + elif beta_schedule == "const": + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 + betas = 1.0 / np.linspace( + num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 + ) + else: + raise NotImplementedError(beta_schedule) + assert betas.shape == (num_diffusion_timesteps,) + return betas + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps, beta_start, beta_end): + """ + Get a pre-defined beta schedule for the given name. + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name in ["linear", "scaled_linear"]: + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. + scale = 1000 / num_diffusion_timesteps + return get_beta_schedule( + schedule_name, + beta_start=scale * beta_start, + beta_end=scale * beta_end, + num_diffusion_timesteps=num_diffusion_timesteps, + ) + elif schedule_name == "squaredcos_cap_v2": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +class GaussianDiffusion: + """ + Utilities for training and sampling diffusion models. + Original ported from this codebase: + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + """ + + def __init__( + self, + *, + betas, + model_mean_type, + model_var_type, + loss_type + ): + + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + self.loss_type = loss_type + + # Use float64 for accuracy. + betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = ( + betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = np.log( + np.append(self.posterior_variance[1], self.posterior_variance[1:]) + ) if len(self.posterior_variance) > 1 else np.array([]) + + self.posterior_mean_coef1 = ( + betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + self.posterior_mean_coef2 = ( + (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) + ) + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + In other words, sample from q(x_t | x_0). + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. + """ + if noise is None: + noise = th.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def get_v(self, x_start, noise, t): + # v-prediction parameterization + # training loss type for stable-diffusion-2-1 768/x4-upscaler + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * noise - + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start + ) + + def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. + """ + if model_kwargs is None: + model_kwargs = {} + + B, F, C = x.shape[:3] + assert t.shape == (B,) + # model_output = model(x, t, **model_kwargs) + try: + model_output = model(x, t, **model_kwargs).sample # for tav unet + # print(model_output.shape) + except: + model_output = model(x, t, **model_kwargs) + + # for v prediction + # if self.model_mean_type == ModelMeanType.PREVIOUS_V: + # model_output = self._predict_eps_from_z_and_v(x, t, model_output) + if isinstance(model_output, tuple): + model_output, extra = model_output + else: + extra = None + + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + assert model_output.shape == (B, F, C * 2, *x.shape[3:]) + model_output, model_var_values = th.split(model_output, C, dim=2) + min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) + max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = th.exp(model_log_variance) + else: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. + ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: + if self.model_mean_type == ModelMeanType.PREVIOUS_V: + pred_xstart = process_xstart( + self._predict_xstart_from_z_and_v(x_t=x, t=t, v=model_output) + ) + else: + pred_xstart = process_xstart( + self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) + ) + + if self.model_mean_type == ModelMeanType.PREVIOUS_V: + eps = self._predict_eps_from_z_and_v(x_t=x, t=t, v=model_output) + else: + eps = model_output + + model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) + + assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + return { + "eps": eps, + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + "extra": extra, + } + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + ) + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + # for v prediction + def _predict_xstart_from_z_and_v(self, x_t, t, v): + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v + ) + + # for v prediction + def _predict_eps_from_z_and_v(self, x_t, t, v): + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t + ) + + + def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). + """ + gradient = cond_fn(x, t, **model_kwargs) + new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() + return new_mean + + def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + See condition_mean() for details on cond_fn. + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). + """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t) + return out + + def p_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + ): + """ + Sample x_{t-1} from the model at the given timestep. + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + noise = th.randn_like(x) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + if cond_fn is not None: + out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) + sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model. + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). + :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. + :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :return: a non-differentiable batch of samples. + """ + final = None + for sample in self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + ): + final = sample + return final["sample"] + + def p_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + Arguments are the same as p_sample_loop(). + Returns a generator over dicts, where each dict is the return value of + p_sample(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.p_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + ) + yield out + img = out["sample"] + + def ddim_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t-1} from the model using DDIM. + Same usage as p_sample(). + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + if self.model_mean_type == ModelMeanType.PREVIOUS_V: + eps = out["eps"] + else: + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + sigma = ( + eta + * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) + * th.sqrt(1 - alpha_bar / alpha_bar_prev) + ) + # Equation 12. + noise = th.randn_like(x) + mean_pred = ( + out["pred_xstart"] * th.sqrt(alpha_bar_prev) + + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps + ) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def ddim_reverse_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t+1} from the model using DDIM reverse ODE. + """ + assert eta == 0.0, "Reverse ODE only for deterministic path" + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x + - out["pred_xstart"] + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) + alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) + + # Equation 12. reversed + mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps + + return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} + + def ddim_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Generate samples from the model using DDIM. + Same usage as p_sample_loop(). + """ + final = None + for sample in self.ddim_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, + ): + final = sample + return final["sample"] + + def ddim_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Use DDIM to sample from the model and yield intermediate samples from + each timestep of DDIM. + Same usage as p_sample_loop_progressive(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.ddim_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + eta=eta, + ) + yield out + img = out["sample"] + + def _vb_terms_bpd( + self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None + ): + """ + Get a term for the variational lower-bound. + The resulting units are bits (rather than nats, as one might expect). + This allows for comparison to other papers. + :return: a dict with the following keys: + - 'output': a shape [N] tensor of NLLs or KLs. + - 'pred_xstart': the x_0 predictions. + """ + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + ) + out = self.p_mean_variance( + model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs + ) + kl = normal_kl( + true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] + ) + kl = mean_flat(kl) / np.log(2.0) + + decoder_nll = -discretized_gaussian_log_likelihood( + x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] + ) + assert decoder_nll.shape == x_start.shape + decoder_nll = mean_flat(decoder_nll) / np.log(2.0) + + # At the first timestep return the decoder NLL, + # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) + output = th.where((t == 0), decoder_nll, kl) + return {"output": output, "pred_xstart": out["pred_xstart"]} + + + def training_losses(self, model, x_start, t, loss_mask=None, model_kwargs=None, noise=None): + """ + Compute training losses for a single timestep. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. + """ + if model_kwargs is None: + model_kwargs = {} + if noise is None: + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start, t, noise=noise) + + terms = {} + + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + terms["loss"] = self._vb_terms_bpd( + model=model, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + model_kwargs=model_kwargs, + )["output"] + if self.loss_type == LossType.RESCALED_KL: + terms["loss"] *= self.num_timesteps + elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: # loss_type=LossType.MSE by default + try: + model_output = model(x_t, t, **model_kwargs).sample # for tav unet + except: + model_output = model(x_t, t, **model_kwargs) + + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, F, C = x_t.shape[:3] + assert model_output.shape == (B, F, C * 2, *x_t.shape[3:]) + model_output, model_var_values = th.split(model_output, C, dim=2) + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = th.cat([model_output.detach(), model_var_values], dim=2) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. + terms["vb"] *= self.num_timesteps / 1000.0 + print("=======") + + target = { + ModelMeanType.PREVIOUS_V: self.get_v(x_start=x_start, noise=noise, t=t), + ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + )[0], + ModelMeanType.START_X: x_start, + ModelMeanType.EPSILON: noise, + }[self.model_mean_type] + assert model_output.shape == target.shape == x_start.shape + if loss_mask is not None: + terms["mse"] = mean_flat(((target - model_output) ** 2) * loss_mask) + else: + terms["mse"] = mean_flat((target - model_output) ** 2) + + if "vb" in terms: + terms["loss"] = terms["mse"] + terms["vb"] + else: + terms["loss"] = terms["mse"] + else: + raise NotImplementedError(self.loss_type) + + return terms + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl( + mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 + ) + return mean_flat(kl_prior) / np.log(2.0) + + def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): + """ + Compute the entire variational lower-bound, measured in bits-per-dim, + as well as other related quantities. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param clip_denoised: if True, clip denoised samples. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - total_bpd: the total variational lower-bound, per batch element. + - prior_bpd: the prior term in the lower-bound. + - vb: an [N x T] tensor of terms in the lower-bound. + - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. + - mse: an [N x T] tensor of epsilon MSEs for each timestep. + """ + device = x_start.device + batch_size = x_start.shape[0] + + vb = [] + xstart_mse = [] + mse = [] + for t in list(range(self.num_timesteps))[::-1]: + t_batch = th.tensor([t] * batch_size, device=device) + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) + # Calculate VLB term at the current timestep + with th.no_grad(): + out = self._vb_terms_bpd( + model, + x_start=x_start, + x_t=x_t, + t=t_batch, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + ) + vb.append(out["output"]) + xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) + eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) + mse.append(mean_flat((eps - noise) ** 2)) + + vb = th.stack(vb, dim=1) + xstart_mse = th.stack(xstart_mse, dim=1) + mse = th.stack(mse, dim=1) + + prior_bpd = self._prior_bpd(x_start) + total_bpd = vb.sum(dim=1) + prior_bpd + return { + "total_bpd": total_bpd, + "prior_bpd": prior_bpd, + "vb": vb, + "xstart_mse": xstart_mse, + "mse": mse, + } + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. + """ + res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res + th.zeros(broadcast_shape, device=timesteps.device) diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/vsr/diffusion/respace.py b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/diffusion/respace.py new file mode 100644 index 0000000000000000000000000000000000000000..31c0092380367db477f154a39bf172ff64712fc4 --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/diffusion/respace.py @@ -0,0 +1,130 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +import torch +import numpy as np +import torch as th + +from .gaussian_diffusion import GaussianDiffusion + + +def space_timesteps(num_timesteps, section_counts): + """ + Create a list of timesteps to use from an original diffusion process, + given the number of timesteps we want to take from equally-sized portions + of the original process. + For example, if there's 300 timesteps and the section counts are [10,15,20] + then the first 100 timesteps are strided to be 10 timesteps, the second 100 + are strided to be 15 timesteps, and the final 100 are strided to be 20. + If the stride is a string starting with "ddim", then the fixed striding + from the DDIM paper is used, and only one section is allowed. + :param num_timesteps: the number of diffusion steps in the original + process to divide up. + :param section_counts: either a list of numbers, or a string containing + comma-separated numbers, indicating the step count + per section. As a special case, use "ddimN" where N + is a number of steps to use the striding from the + DDIM paper. + :return: a set of diffusion steps from the original process to use. + """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError( + f"cannot create exactly {num_timesteps} steps with an integer stride" + ) + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError( + f"cannot divide section of {size} steps into {section_count}" + ) + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +class SpacedDiffusion(GaussianDiffusion): + """ + A diffusion process which can skip steps in a base diffusion process. + :param use_timesteps: a collection (sequence or set) of timesteps from the + original diffusion process to retain. + :param kwargs: the kwargs to create the base diffusion process. + """ + + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + + base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = np.array(new_betas) + super().__init__(**kwargs) + + def p_mean_variance( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + # @torch.compile + def training_losses( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().training_losses(self._wrap_model(model), *args, **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) + + def _wrap_model(self, model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel( + model, self.timestep_map, self.original_num_steps + ) + + def _scale_timesteps(self, t): + # Scaling is done by the wrapped model. + return t + + +class _WrappedModel: + def __init__(self, model, timestep_map, original_num_steps): + self.model = model + self.timestep_map = timestep_map + # self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def __call__(self, x, ts, **kwargs): + map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) + new_ts = map_tensor[ts] + # if self.rescale_timesteps: + # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) + return self.model(x, new_ts, **kwargs) diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/vsr/diffusion/scheduling_ddim.py b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/diffusion/scheduling_ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..8564006cac2d477a645178c01bb8ac18b871eacd --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/diffusion/scheduling_ddim.py @@ -0,0 +1,462 @@ +# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion +# and https://github.com/hojonathanho/diffusion + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput + +try: + from diffusers.utils import randn_tensor +except: + from diffusers.utils.torch_utils import randn_tensor + +from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM +class DDIMSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor: + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class DDIMScheduler(SchedulerMixin, ConfigMixin): + """ + Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising + diffusion probabilistic models (DDPMs) with non-Markovian guidance. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + For more details, see the original paper: https://arxiv.org/abs/2010.02502 + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + clip_sample (`bool`, default `True`): + option to clip predicted sample for numerical stability. + clip_sample_range (`float`, default `1.0`): + the maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + set_alpha_to_one (`bool`, default `True`): + each diffusion step uses the value of alphas product at that step and at the previous one. For the final + step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the value of alpha at step 0. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) + thresholding (`bool`, default `False`): + whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). + Note that the thresholding method is unsuitable for latent-space diffusion models (such as + stable-diffusion). + dynamic_thresholding_ratio (`float`, default `0.995`): + the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen + (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`. + sample_max_value (`float`, default `1.0`): + the threshold value for dynamic thresholding. Valid only when `thresholding=True`. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + clip_sample: bool = True, + set_alpha_to_one: bool = True, + steps_offset: int = 0, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + clip_sample_range: float = 1.0, + sample_max_value: float = 1.0, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + + def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): input sample + timestep (`int`, optional): current timestep + + Returns: + `torch.FloatTensor`: scaled input sample + """ + return sample + + def _get_variance(self, timestep, prev_timestep): + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, height, width = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * height * width) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, height, width) + sample = sample.to(dtype) + + return sample + + # def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + # """ + # Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + # Args: + # num_inference_steps (`int`): + # the number of diffusion steps used when generating samples with a pre-trained model. + # """ + + # if num_inference_steps > self.config.num_train_timesteps: + # raise ValueError( + # f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + # f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + # f" maximal {self.config.num_train_timesteps} timesteps." + # ) + + # self.num_inference_steps = num_inference_steps + # step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # # creates integer timesteps by multiplying by ratio + # # casting to int to avoid issues when num_inference_step is power of 3 + # timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + # self.timesteps = torch.from_numpy(timesteps).to(device) + # self.timesteps += self.config.steps_offset + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + """ + + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.linspace(self.config.steps_offset, self.config.num_train_timesteps, num_inference_steps) + timesteps = timesteps.round()[::-1].copy().astype(np.int64) + self.timesteps = torch.from_numpy(timesteps).to(device) + self.timesteps += self.config.steps_offset + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + eta: float = 0.0, + use_clipped_model_output: bool = False, + generator=None, + variance_noise: Optional[torch.FloatTensor] = None, + return_dict: bool = True, + ) -> Union[DDIMSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + eta (`float`): weight of noise for added noise in diffusion step. + use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped + predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when + `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would + coincide with the one provided as input and `use_clipped_model_output` will have not effect. + generator: random number generator. + variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we + can directly provide the noise for the variance itself. This is useful for methods such as + CycleDiffusion. (https://arxiv.org/abs/2210.05559) + return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + # 1. get previous step value (=t-1) + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + # print('===========', self.config.prediction_type) + # self.config.prediction_type = "v_prediction" + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + + # 4. Clip or threshold "predicted x_0" + if self.config.thresholding: + pred_original_sample = self._threshold_sample(pred_original_sample) + elif self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = self._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + + if use_clipped_model_output: + # the pred_epsilon is always re-derived from the clipped x_0 in Glide + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon + + # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + if eta > 0: + if variance_noise is not None and generator is not None: + raise ValueError( + "Cannot pass both generator and variance_noise. Please make sure that either `generator` or" + " `variance_noise` stays `None`." + ) + + if variance_noise is None: + variance_noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + ) + variance = std_dev_t * variance_noise + + prev_sample = prev_sample + variance + + if not return_dict: + return (prev_sample,) + + return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity + def get_velocity( + self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/vsr/diffusion/timestep_sampler.py b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/diffusion/timestep_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..a3f369847677d8dbaaadb8297691b1be92cf189f --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/diffusion/timestep_sampler.py @@ -0,0 +1,150 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +from abc import ABC, abstractmethod + +import numpy as np +import torch as th +import torch.distributed as dist + + +def create_named_schedule_sampler(name, diffusion): + """ + Create a ScheduleSampler from a library of pre-defined samplers. + :param name: the name of the sampler. + :param diffusion: the diffusion object to sample for. + """ + if name == "uniform": + return UniformSampler(diffusion) + elif name == "loss-second-moment": + return LossSecondMomentResampler(diffusion) + else: + raise NotImplementedError(f"unknown schedule sampler: {name}") + + +class ScheduleSampler(ABC): + """ + A distribution over timesteps in the diffusion process, intended to reduce + variance of the objective. + By default, samplers perform unbiased importance sampling, in which the + objective's mean is unchanged. + However, subclasses may override sample() to change how the resampled + terms are reweighted, allowing for actual changes in the objective. + """ + + @abstractmethod + def weights(self): + """ + Get a numpy array of weights, one per diffusion step. + The weights needn't be normalized, but must be positive. + """ + + def sample(self, batch_size, device): + """ + Importance-sample timesteps for a batch. + :param batch_size: the number of timesteps. + :param device: the torch device to save to. + :return: a tuple (timesteps, weights): + - timesteps: a tensor of timestep indices. + - weights: a tensor of weights to scale the resulting losses. + """ + w = self.weights() + p = w / np.sum(w) + indices_np = np.random.choice(len(p), size=(batch_size,), p=p) + indices = th.from_numpy(indices_np).long().to(device) + weights_np = 1 / (len(p) * p[indices_np]) + weights = th.from_numpy(weights_np).float().to(device) + return indices, weights + + +class UniformSampler(ScheduleSampler): + def __init__(self, diffusion): + self.diffusion = diffusion + self._weights = np.ones([diffusion.num_timesteps]) + + def weights(self): + return self._weights + + +class LossAwareSampler(ScheduleSampler): + def update_with_local_losses(self, local_ts, local_losses): + """ + Update the reweighting using losses from a model. + Call this method from each rank with a batch of timesteps and the + corresponding losses for each of those timesteps. + This method will perform synchronization to make sure all of the ranks + maintain the exact same reweighting. + :param local_ts: an integer Tensor of timesteps. + :param local_losses: a 1D Tensor of losses. + """ + batch_sizes = [ + th.tensor([0], dtype=th.int32, device=local_ts.device) + for _ in range(dist.get_world_size()) + ] + dist.all_gather( + batch_sizes, + th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), + ) + + # Pad all_gather batches to be the maximum batch size. + batch_sizes = [x.item() for x in batch_sizes] + max_bs = max(batch_sizes) + + timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] + loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] + dist.all_gather(timestep_batches, local_ts) + dist.all_gather(loss_batches, local_losses) + timesteps = [ + x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] + ] + losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] + self.update_with_all_losses(timesteps, losses) + + @abstractmethod + def update_with_all_losses(self, ts, losses): + """ + Update the reweighting using losses from a model. + Sub-classes should override this method to update the reweighting + using losses from the model. + This method directly updates the reweighting without synchronizing + between workers. It is called by update_with_local_losses from all + ranks with identical arguments. Thus, it should have deterministic + behavior to maintain state across workers. + :param ts: a list of int timesteps. + :param losses: a list of float losses, one per timestep. + """ + + +class LossSecondMomentResampler(LossAwareSampler): + def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): + self.diffusion = diffusion + self.history_per_term = history_per_term + self.uniform_prob = uniform_prob + self._loss_history = np.zeros( + [diffusion.num_timesteps, history_per_term], dtype=np.float64 + ) + self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) + + def weights(self): + if not self._warmed_up(): + return np.ones([self.diffusion.num_timesteps], dtype=np.float64) + weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) + weights /= np.sum(weights) + weights *= 1 - self.uniform_prob + weights += self.uniform_prob / len(weights) + return weights + + def update_with_all_losses(self, ts, losses): + for t, loss in zip(ts, losses): + if self._loss_counts[t] == self.history_per_term: + # Shift out the oldest loss term. + self._loss_history[t, :-1] = self._loss_history[t, 1:] + self._loss_history[t, -1] = loss + else: + self._loss_history[t, self._loss_counts[t]] = loss + self._loss_counts[t] += 1 + + def _warmed_up(self): + return (self._loss_counts == self.history_per_term).all() diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/vsr/models/__init__.py b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..23fe8e3be56d23ed269e7e9dfd89987eb68e47ea --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/models/__init__.py @@ -0,0 +1,28 @@ +from .unet import UNet3DVSRModel +from torch.optim.lr_scheduler import LambdaLR + +def customized_lr_scheduler(optimizer, warmup_steps=5000): # 5000 from u-vit + from torch.optim.lr_scheduler import LambdaLR + def fn(step): + if warmup_steps > 0: + return min(step / warmup_steps, 1) + else: + return 1 + return LambdaLR(optimizer, fn) + + +def get_lr_scheduler(optimizer, name, **kwargs): + if name == 'warmup': + return customized_lr_scheduler(optimizer, **kwargs) + elif name == 'cosine': + from torch.optim.lr_scheduler import CosineAnnealingLR + return CosineAnnealingLR(optimizer, **kwargs) + else: + raise NotImplementedError(name) + +def get_models(): + config_path = "./configs/unet_3d_config.json" + pretrained_model_path = "./pretrained_models/upscaler4x/unet/diffusion_pytorch_model.bin" + return UNet3DVSRModel.from_pretrained_2d(config_path, pretrained_model_path) + + \ No newline at end of file diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/vsr/models/attention.py b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/models/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..6b4f4514752d23e916d50c68e18a7a672a3b1fcf --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/models/attention.py @@ -0,0 +1,826 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py +import os +import sys +sys.path.append(os.path.split(sys.path[0])[0]) + +from dataclasses import dataclass +from typing import Optional + +import math +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import BaseOutput +from diffusers.utils.import_utils import is_xformers_available +from diffusers.models.attention import FeedForward, AdaLayerNorm + +try: + from .resnet import ResnetBlock3DCNN +except: + from resnet import ResnetBlock3DCNN + +from rotary_embedding_torch import RotaryEmbedding +from typing import Callable, Optional +from einops import rearrange, repeat + +@dataclass +class Transformer3DModelOutput(BaseOutput): + sample: torch.FloatTensor + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + +def exists(x): + return x is not None + + +class CrossAttention(nn.Module): + r""" + copy from diffuser 0.11.1 + A cross attention layer. + Parameters: + query_dim (`int`): The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + use_relative_position: bool = False, + ): + super().__init__() + # print('num head', heads) + inner_dim = dim_head * heads + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + + self.scale = dim_head**-0.5 + + self.heads = heads + self.dim_head = dim_head + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + self._slice_size = None + self._use_memory_efficient_attention_xformers = False + self.added_kv_proj_dim = added_kv_proj_dim + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) + else: + self.group_norm = None + + self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(inner_dim, query_dim)) + self.to_out.append(nn.Dropout(dropout)) + + # print(use_relative_position) + self.use_relative_position = use_relative_position + if self.use_relative_position: + self.rotary_emb = RotaryEmbedding(min(32, dim_head)) + # # print(dim_head) + # # print(heads) + # # adopt https://github.com/huggingface/transformers/blob/8a817e1ecac6a420b1bdc701fcc33535a3b96ff5/src/transformers/models/bert/modeling_bert.py#L265 + # self.max_position_embeddings = 32 + # self.distance_embedding = nn.Embedding(2 * self.max_position_embeddings - 1, dim_head) + + # self.dropout = nn.Dropout(dropout) + + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def reshape_for_scores(self, tensor): + # split heads and dims + # tensor should be [b (h w)] f (d nd) + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).contiguous() + return tensor + + def same_batch_dim_to_heads(self, tensor): + batch_size, head_size, seq_len, dim = tensor.shape # [b (h w)] nd f d + tensor = tensor.reshape(batch_size, seq_len, dim * head_size) + return tensor + + def set_attention_slice(self, slice_size): + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + self._slice_size = slice_size + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, use_image_num=None): + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) # [b (h w)] f (nd * d) + + # print('before reshpape query shape', query.shape) + dim = query.shape[-1] + if not self.use_relative_position: + query = self.reshape_heads_to_batch_dim(query) # [b (h w) nd] f d + # print('after reshape query shape', query.shape) + + if self.added_kv_proj_dim is not None: + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) + + key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) + else: + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + if not self.use_relative_position: + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + + def _attention(self, query, key, value, attention_mask=None): + if self.upcast_attention: + query = query.float() + key = key.float() + + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + # print('query shape', query.shape) + # print('key shape', key.shape) + # print('value shape', value.shape) + + if attention_mask is not None: + # print('attention_mask', attention_mask.shape) + # print('attention_scores', attention_scores.shape) + # exit() + attention_scores = attention_scores + attention_mask + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + # print(attention_probs.shape) + + # cast back to the original dtype + attention_probs = attention_probs.to(value.dtype) + # print(attention_probs.shape) + + # compute attention output + hidden_states = torch.bmm(attention_probs, value) + # print(hidden_states.shape) + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + # print(hidden_states.shape) + # exit() + return hidden_states + + def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask): + batch_size_attention = query.shape[0] + hidden_states = torch.zeros( + (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype + ) + slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] + for i in range(hidden_states.shape[0] // slice_size): + start_idx = i * slice_size + end_idx = (i + 1) * slice_size + + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + + if self.upcast_attention: + query_slice = query_slice.float() + key_slice = key_slice.float() + + attn_slice = torch.baddbmm( + torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device), + query_slice, + key_slice.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + if attention_mask is not None: + attn_slice = attn_slice + attention_mask[start_idx:end_idx] + + if self.upcast_softmax: + attn_slice = attn_slice.float() + + attn_slice = attn_slice.softmax(dim=-1) + + # cast back to the original dtype + attn_slice = attn_slice.to(value.dtype) + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def _memory_efficient_attention_xformers(self, query, key, value, attention_mask): + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + +class Transformer3DModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + use_first_frame: bool = False, + use_relative_position: bool = False, + rotary_emb: bool = None, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # Define input layers + self.in_channels = in_channels + + # add 3D CNN for VSR + # if only_cross_attention == False: # x8 down + # self.resblock_temporal = ResnetBlock3DCNN(in_channels=in_channels, kernel=(3,3,3), temb_channels=None) + # else: + # self.resblock_temporal = ResnetBlock3DCNN(in_channels=in_channels, kernel=(5,1,1), temb_channels=None) + + self.resblock_temporal = ResnetBlock3DCNN(in_channels=in_channels, kernel=(3,1,1), temb_channels=None) + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + + # Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + if use_linear_projection: + self.proj_out = nn.Linear(in_channels, inner_dim) + else: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): + # Input + assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length) + + batch, channel, height, weight = hidden_states.shape + + # 3D CNN for VSR + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length) + hidden_states = self.resblock_temporal(hidden_states) + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w", f=video_length) + + residual = hidden_states + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = self.proj_in(hidden_states) + + # Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + video_length=video_length + ) + + # Output + if not self.use_linear_projection: + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + + output = hidden_states + residual + + output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) + if not return_dict: + return (output,) + + return Transformer3DModelOutput(sample=output) + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + use_first_frame: bool = False, + use_relative_position: bool = False, + rotary_emb: bool = None, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + # print(only_cross_attention) + self.use_ada_layer_norm = num_embeds_ada_norm is not None + self.use_first_frame = use_first_frame # False for VSR + + # SC-Attn + if use_first_frame and only_cross_attention == False: + self.attn1 = SparseCausalAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + # print(cross_attention_dim) + else: + self.attn1 = CrossAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + # Cross-Attn + if cross_attention_dim is not None: + self.attn2 = CrossAttention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + else: + self.attn2 = None + + if cross_attention_dim is not None: + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + else: + self.norm2 = None + + + # Temp-Attn for VSR + self.attn_temporal = TemporalAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=None, + upcast_attention=upcast_attention, + rotary_emb=rotary_emb, + ) + nn.init.zeros_(self.attn_temporal.to_out[0].weight.data) + self.norm_temporal = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + # Feed-forward + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) + self.norm3 = nn.LayerNorm(dim) + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, attention_op: None): + if not is_xformers_available(): + print("Here is how to install it") + raise ModuleNotFoundError( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers", + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" + " available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + if self.attn2 is not None: + self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None): + # SparseCausal-Attention + norm_hidden_states = ( + self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) + ) + + if self.only_cross_attention: # Cross-Attention + hidden_states = ( + self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states + ) + else: # Self-Attention + if self.use_first_frame: + hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states + else: + hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states + + + if self.attn2 is not None: + # Cross-Attention + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + hidden_states = ( + self.attn2( + norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + ) + + hidden_states + ) + + + # Temporal-Attention for VSR + d = hidden_states.shape[1] + hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length).contiguous() + norm_hidden_states = ( + self.norm_temporal(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temporal(hidden_states) + ) + hidden_states = self.attn_temporal(norm_hidden_states) + hidden_states + hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d).contiguous() + + # Feed-forward + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + return hidden_states + + +class SparseCausalAttention(CrossAttention): + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) + dim = query.shape[-1] + query = self.reshape_heads_to_batch_dim(query) + + if self.added_kv_proj_dim is not None: + raise NotImplementedError + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + former_frame_index = torch.arange(video_length) - 1 + former_frame_index[0] = 0 + + key = rearrange(key, "(b f) d c -> b f d c", f=video_length) + key = torch.cat([key[:, [0] * video_length], key[:, former_frame_index]], dim=2) + key = rearrange(key, "b f d c -> (b f) d c") + + value = rearrange(value, "(b f) d c -> b f d c", f=video_length) + value = torch.cat([value[:, [0] * video_length], value[:, former_frame_index]], dim=2) + value = rearrange(value, "b f d c -> (b f) d c") + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + +class TemporalAttention(CrossAttention): + def __init__(self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + rotary_emb=None): + super().__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, added_kv_proj_dim, norm_num_groups) + # relative time positional embeddings + self.time_rel_pos_bias = RelativePositionBias(heads=heads, max_distance=32) # realistically will not be able to generate that many frames of video... yet + self.rotary_emb = rotary_emb + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): + time_rel_pos_bias = self.time_rel_pos_bias(hidden_states.shape[1], device=hidden_states.device) + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) # [b (h w)] f (nd * d) + dim = query.shape[-1] + + if self.added_kv_proj_dim is not None: + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) + + key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) + else: + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask, time_rel_pos_bias) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + + def _attention(self, query, key, value, attention_mask=None, time_rel_pos_bias=None): + if self.upcast_attention: + query = query.float() + key = key.float() + + # print('query shape', query.shape) + # print('key shape', key.shape) + # print('value shape', value.shape) + # reshape for adding time positional bais + query = self.scale * rearrange(query, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads + key = rearrange(key, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads + value = rearrange(value, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads + # print('query shape', query.shape) + # print('key shape', key.shape) + # print('value shape', value.shape) + + # torch.baddbmm only accepte 3-D tensor + # https://runebook.dev/zh/docs/pytorch/generated/torch.baddbmm + # attention_scores = self.scale * torch.matmul(query, key.transpose(-1, -2)) + if exists(self.rotary_emb): + query = self.rotary_emb.rotate_queries_or_keys(query) + key = self.rotary_emb.rotate_queries_or_keys(key) + + attention_scores = torch.einsum('... h i d, ... h j d -> ... h i j', query, key) + # print('attention_scores shape', attention_scores.shape) + # print('time_rel_pos_bias shape', time_rel_pos_bias.shape) + # print('attention_mask shape', attention_mask.shape) + + attention_scores = attention_scores + time_rel_pos_bias + # print(attention_scores.shape) + + # bert from huggin face + # attention_scores = attention_scores / math.sqrt(self.dim_head) + + # # Normalize the attention scores to probabilities. + # attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + if attention_mask is not None: + # add attention mask + attention_scores = attention_scores + attention_mask + + # vdm + attention_scores = attention_scores - attention_scores.amax(dim = -1, keepdim = True).detach() + + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + # print(attention_probs[0][0]) + + # cast back to the original dtype + attention_probs = attention_probs.to(value.dtype) + + # compute attention output + # hidden_states = torch.matmul(attention_probs, value) + hidden_states = torch.einsum('... h i j, ... h j d -> ... h i d', attention_probs, value) + # print(hidden_states.shape) + # hidden_states = self.same_batch_dim_to_heads(hidden_states) + hidden_states = rearrange(hidden_states, 'b h f d -> b f (h d)') + # print(hidden_states.shape) + # exit() + return hidden_states + +class RelativePositionBias(nn.Module): + def __init__( + self, + heads=8, + num_buckets=32, + max_distance=128, + ): + super().__init__() + self.num_buckets = num_buckets + self.max_distance = max_distance + self.relative_attention_bias = nn.Embedding(num_buckets, heads) + + @staticmethod + def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128): + ret = 0 + n = -relative_position + + num_buckets //= 2 + ret += (n < 0).long() * num_buckets + n = torch.abs(n) + + max_exact = num_buckets // 2 + is_small = n < max_exact + + val_if_large = max_exact + ( + torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) + ).long() + val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + + ret += torch.where(is_small, n, val_if_large) + return ret + + def forward(self, n, device): + q_pos = torch.arange(n, dtype = torch.long, device = device) + k_pos = torch.arange(n, dtype = torch.long, device = device) + rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1') + rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance) + values = self.relative_attention_bias(rp_bucket) + return rearrange(values, 'i j h -> h i j') # num_heads, num_frames, num_frames diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/vsr/models/autoencoder_kl.py b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/models/autoencoder_kl.py new file mode 100644 index 0000000000000000000000000000000000000000..1983c3f1cefdbccf47659d79e4b4180fb9cc6455 --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/models/autoencoder_kl.py @@ -0,0 +1,334 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput + +try: + from diffusers.utils import apply_forward_hook +except: + from diffusers.utils.accelerate_utils import apply_forward_hook + +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder + + +@dataclass +class AutoencoderKLOutput(BaseOutput): + """ + Output of AutoencoderKL encoding method. + + Args: + latent_dist (`DiagonalGaussianDistribution`): + Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`. + `DiagonalGaussianDistribution` allows for sampling latents from the distribution. + """ + + latent_dist: "DiagonalGaussianDistribution" + + +class AutoencoderKL(ModelMixin, ConfigMixin): + r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma + and Max Welling. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to : + obj:`(64,)`): Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): TODO + scaling_factor (`float`, *optional*, defaults to 0.18215): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ("DownEncoderBlock2D",), + up_block_types: Tuple[str] = ("UpDecoderBlock2D",), + block_out_channels: Tuple[int] = (64,), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 4, + norm_num_groups: int = 32, + sample_size: int = 32, + scaling_factor: float = 0.18215, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + ) + + # pass init params to Decoder + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + ) + + self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) + self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) + + self.use_slicing = False + self.use_tiling = False + + # only relevant if vae tiling is enabled + self.tile_sample_min_size = self.config.sample_size + sample_size = ( + self.config.sample_size[0] + if isinstance(self.config.sample_size, (list, tuple)) + else self.config.sample_size + ) + self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) + self.tile_overlap_factor = 0.25 + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (Encoder, Decoder)): + module.gradient_checkpointing = value + + def enable_tiling(self, use_tiling: bool = True): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow + the processing of larger images. + """ + self.use_tiling = use_tiling + + def disable_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.enable_tiling(False) + + def enable_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + @apply_forward_hook + def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: + if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size): + return self.tiled_encode(x, return_dict=return_dict) + + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size): + return self.tiled_decode(z, return_dict=return_dict) + + z = self.post_quant_conv(z) + dec = self.decoder(z) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def blend_v(self, a, b, blend_extent): + for y in range(min(a.shape[2], b.shape[2], blend_extent)): + b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent) + return b + + def blend_h(self, a, b, blend_extent): + for x in range(min(a.shape[3], b.shape[3], blend_extent)): + b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) + return b + + def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: + r"""Encode a batch of images using a tiled encoder. + + Args: + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is: + different from non-tiled encoding due to each tile using a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + look of the output, but they should be much less noticeable. + x (`torch.FloatTensor`): Input batch of images. return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`AutoencoderKLOutput`] instead of a plain tuple. + """ + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. + rows = [] + for i in range(0, x.shape[2], overlap_size): + row = [] + for j in range(0, x.shape[3], overlap_size): + tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = self.encoder(tile) + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + moments = torch.cat(result_rows, dim=2) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + r"""Decode a batch of images using a tiled decoder. + + Args: + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled decoding is: + different from non-tiled decoding due to each tile using a different decoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + look of the output, but they should be much less noticeable. + z (`torch.FloatTensor`): Input batch of latent vectors. return_dict (`bool`, *optional*, defaults to + `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + row_limit = self.tile_sample_min_size - blend_extent + + # Split z into overlapping 64x64 tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, z.shape[2], overlap_size): + row = [] + for j in range(0, z.shape[3], overlap_size): + tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size] + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + dec = torch.cat(result_rows, dim=2) + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.FloatTensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.FloatTensor]: + r""" + Args: + sample (`torch.FloatTensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/vsr/models/clip.py b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/models/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..13ec05118f729cc27e6b7b4107a300ecfbdf2d77 --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/models/clip.py @@ -0,0 +1,127 @@ +import numpy +import torch.nn as nn +from transformers import CLIPTokenizer, CLIPTextModel +from diffusers import StableDiffusionUpscalePipeline + +""" +Will encounter following warning: +- This IS expected if you are initializing CLIPTextModel from the checkpoint of a model trained on another task +or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model). +- This IS NOT expected if you are initializing CLIPTextModel from the checkpoint of a model +that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model). + +https://github.com/CompVis/stable-diffusion/issues/97 +according to this issue, this warning is safe. + +This is expected since the vision backbone of the CLIP model is not needed to run Stable Diffusion. +You can safely ignore the warning, it is not an error. + +This clip usage is from U-ViT and same with Stable Diffusion. +""" + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class FrozenCLIPEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from Hugging Face)""" + def __init__(self, device="cuda", max_length=77): + super().__init__() + # self.tokenizer = CLIPTokenizer.from_pretrained('laion/CLIP-ViT-H-14-laion2B-s32B-b79K') + # self.text_encoder = CLIPTextModel.from_pretrained('laion/CLIP-ViT-H-14-laion2B-s32B-b79K') + # TBD: change to https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/blob/main/config.json + # model_id = "stabilityai/stable-diffusion-x4-upscaler" # For VSR + # upscale_pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_id) + upscale_pipeline = StableDiffusionUpscalePipeline.from_pretrained('./pretrained_models/upscaler4x') + self.tokenizer = upscale_pipeline.tokenizer + self.text_encoder = upscale_pipeline.text_encoder + self.device = device + self.max_length = max_length + self.freeze() + + def freeze(self): + self.text_encoder = self.text_encoder.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, padding="max_length", max_length=self.max_length, + return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.text_encoder(input_ids=tokens) + + # return outputs.last_hidden_state + return outputs[0] + + def encode(self, text): + return self(text) + + +class TextEmbedder(nn.Module): + """ + Embeds text prompt into vector representations. Also handles text dropout for classifier-free guidance. + """ + def __init__(self, dropout_prob=0.1): + super().__init__() + self.text_encodder = FrozenCLIPEmbedder() + self.dropout_prob = dropout_prob + + def token_drop(self, text_prompts, force_drop_ids=None): + """ + Drops text to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = numpy.random.uniform(0, 1, len(text_prompts)) < self.dropout_prob + else: + drop_ids = force_drop_ids == 1 + labels = list(numpy.where(drop_ids, "None", text_prompts)) + # print(labels) + return labels + + def forward(self, text_prompts, train, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + text_prompts = self.token_drop(text_prompts, force_drop_ids) + embeddings = self.text_encodder(text_prompts) + return embeddings + + +if __name__ == '__main__': + + r""" + Returns: + + Examples from CLIPTextModel: + + ```python + >>> from transformers import AutoTokenizer, CLIPTextModel + + >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32") + >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") + + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled (EOS token) states + ```""" + + import torch + + device = "cuda" if torch.cuda.is_available() else "cpu" + + text_encoder = TextEmbedder(dropout_prob=0.00001).to(device) + text_encoder1 = FrozenCLIPEmbedder().to(device) + + text_prompt = ["a photo of a cat", "a photo of a dog", 'a photo of a dog human'] + # text_prompt = ('None', 'None', 'None') + output = text_encoder(text_prompts=text_prompt, train=True) + output1 = text_encoder1(text_prompt) + # print(output) + print(output.shape) + print(output1.shape) + print((output==output1).all()) \ No newline at end of file diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/vsr/models/diffusers_attention.py b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/models/diffusers_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..c104074eeab835cec0a9ce95f1b5753d05a67566 --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/models/diffusers_attention.py @@ -0,0 +1,983 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.embeddings import ImagePositionalEmbeddings +from diffusers.utils import BaseOutput +from diffusers.utils.import_utils import is_xformers_available +from diffusers.models.attention import FeedForward, AdaLayerNorm +from diffusers.models.attention_processor import Attention + + +@dataclass +class Transformer2DModelOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions + for the unnoised latent pixels. + """ + + sample: torch.FloatTensor + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +class Transformer2DModel(ModelMixin, ConfigMixin): + """ + Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual + embeddings) inputs. + + When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard + transformer action. Finally, reshape to image. + + When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional + embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict + classes of unnoised image. + + Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised + image do not contain a prediction for the masked pixel as the unnoised image cannot be masked. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input and output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + num_vector_embeds (`int`, *optional*): + Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + The number of diffusion steps used during training. Note that this is fixed at training time as it is used + to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for + up to but not more than steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # 1. Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = in_channels is not None + self.is_input_vectorized = num_vector_embeds is not None + + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is None." + ) + elif not self.is_input_continuous and not self.is_input_vectorized: + raise ValueError( + f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is not None." + ) + + # 2. Define input layers + if self.is_input_continuous: + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" + assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" + + self.height = sample_size + self.width = sample_size + self.num_vector_embeds = num_vector_embeds + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width + ) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + if self.is_input_continuous: + if use_linear_projection: + self.proj_out = nn.Linear(in_channels, inner_dim) + else: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): + """ + Args: + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input + hidden_states + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`] + if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample + tensor. + """ + # 1. Input + if self.is_input_continuous: + batch, channel, height, weight = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = self.proj_in(hidden_states) + elif self.is_input_vectorized: + hidden_states = self.latent_image_embedding(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timestep) + + # 3. Output + if self.is_input_continuous: + if not self.use_linear_projection: + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + + output = hidden_states + residual + elif self.is_input_vectorized: + hidden_states = self.norm_out(hidden_states) + logits = self.out(hidden_states) + # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) + logits = logits.permute(0, 2, 1) + + # log(p(x_0)) + output = F.log_softmax(logits.double(), dim=1).float() + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted + to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + Uses three q, k, v linear layers to compute attention. + + Parameters: + channels (`int`): The number of channels in the input and output. + num_head_channels (`int`, *optional*): + The number of channels in each head. If None, then `num_heads` = 1. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm. + rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by. + eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. + """ + + # IMPORTANT;TODO(Patrick, William) - this class will be deprecated soon. Do not use it anymore + + def __init__( + self, + channels: int, + num_head_channels: Optional[int] = None, + norm_num_groups: int = 32, + rescale_output_factor: float = 1.0, + eps: float = 1e-5, + ): + super().__init__() + self.channels = channels + + self.num_heads = channels // num_head_channels if num_head_channels is not None else 1 + self.num_head_size = num_head_channels + self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True) + + # define q,k,v as linear layers + self.query = nn.Linear(channels, channels) + self.key = nn.Linear(channels, channels) + self.value = nn.Linear(channels, channels) + + self.rescale_output_factor = rescale_output_factor + self.proj_attn = nn.Linear(channels, channels, 1) + + self._use_memory_efficient_attention_xformers = False + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, attention_op: None): + if not is_xformers_available(): + raise ModuleNotFoundError( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers", + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" + " available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.num_heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.num_heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def forward(self, hidden_states): + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.query(hidden_states) + key_proj = self.key(hidden_states) + value_proj = self.value(hidden_states) + + scale = 1 / math.sqrt(self.channels / self.num_heads) + + query_proj = self.reshape_heads_to_batch_dim(query_proj) + key_proj = self.reshape_heads_to_batch_dim(key_proj) + value_proj = self.reshape_heads_to_batch_dim(value_proj) + + if self._use_memory_efficient_attention_xformers: + # Memory efficient attention + hidden_states = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None) + hidden_states = hidden_states.to(query_proj.dtype) + else: + attention_scores = torch.baddbmm( + torch.empty( + query_proj.shape[0], + query_proj.shape[1], + key_proj.shape[1], + dtype=query_proj.dtype, + device=query_proj.device, + ), + query_proj, + key_proj.transpose(-1, -2), + beta=0, + alpha=scale, + ) + attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype) + hidden_states = torch.bmm(attention_probs, value_proj) + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + + # compute next hidden_states + hidden_states = self.proj_attn(hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + self.use_ada_layer_norm = num_embeds_ada_norm is not None + + # 1. Self-Attn + self.attn1 = CrossAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) + + # 2. Cross-Attn + if cross_attention_dim is not None: + self.attn2 = CrossAttention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) # is self-attn if encoder_hidden_states is none + else: + self.attn2 = None + + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + if cross_attention_dim is not None: + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + else: + self.norm2 = None + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim) + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, attention_op: None): + if not is_xformers_available(): + print("Here is how to install it") + raise ModuleNotFoundError( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers", + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" + " available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + if self.attn2 is not None: + self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None): + # 1. Self-Attention + norm_hidden_states = ( + self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) + ) + + if self.only_cross_attention: + hidden_states = ( + self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states + ) + else: + hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states + + if self.attn2 is not None: + # 2. Cross-Attention + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + hidden_states = ( + self.attn2( + norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + ) + + hidden_states + ) + + # 3. Feed-forward + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + + return hidden_states + + +class CrossAttention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + ): + super().__init__() + inner_dim = dim_head * heads + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + + self.scale = dim_head**-0.5 + + self.heads = heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + self._slice_size = None + self._use_memory_efficient_attention_xformers = False + self.added_kv_proj_dim = added_kv_proj_dim + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) + else: + self.group_norm = None + + self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(inner_dim, query_dim)) + self.to_out.append(nn.Dropout(dropout)) + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def set_attention_slice(self, slice_size): + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + self._slice_size = slice_size + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) + dim = query.shape[-1] + query = self.reshape_heads_to_batch_dim(query) + + if self.added_kv_proj_dim is not None: + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) + + key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) + else: + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + def _attention(self, query, key, value, attention_mask=None): + if self.upcast_attention: + query = query.float() + key = key.float() + + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + + # cast back to the original dtype + attention_probs = attention_probs.to(value.dtype) + + # compute attention output + hidden_states = torch.bmm(attention_probs, value) + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask): + batch_size_attention = query.shape[0] + hidden_states = torch.zeros( + (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype + ) + slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] + for i in range(hidden_states.shape[0] // slice_size): + start_idx = i * slice_size + end_idx = (i + 1) * slice_size + + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + + if self.upcast_attention: + query_slice = query_slice.float() + key_slice = key_slice.float() + + attn_slice = torch.baddbmm( + torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device), + query_slice, + key_slice.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + if attention_mask is not None: + attn_slice = attn_slice + attention_mask[start_idx:end_idx] + + if self.upcast_softmax: + attn_slice = attn_slice.float() + + attn_slice = attn_slice.softmax(dim=-1) + + # cast back to the original dtype + attn_slice = attn_slice.to(value.dtype) + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def _memory_efficient_attention_xformers(self, query, key, value, attention_mask): + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim) + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim) + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(nn.Linear(inner_dim, dim_out)) + + def forward(self, hidden_states): + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +class GELU(nn.Module): + r""" + GELU activation function + """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out) + + def gelu(self, gate): + if gate.device.type != "mps": + return F.gelu(gate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states + + +# feedforward +class GEGLU(nn.Module): + r""" + A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def gelu(self, gate): + if gate.device.type != "mps": + return F.gelu(gate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) + return hidden_states * self.gelu(gate) + + +class ApproximateGELU(nn.Module): + """ + The approximate form of Gaussian Error Linear Unit (GELU) + + For more details, see section 2: https://arxiv.org/abs/1606.08415 + """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out) + + def forward(self, x): + x = self.proj(x) + return x * torch.sigmoid(1.702 * x) + + +class AdaLayerNorm(nn.Module): + """ + Norm layer modified to incorporate timestep embeddings. + """ + + def __init__(self, embedding_dim, num_embeddings): + super().__init__() + self.emb = nn.Embedding(num_embeddings, embedding_dim) + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, embedding_dim * 2) + self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False) + + def forward(self, x, timestep): + emb = self.linear(self.silu(self.emb(timestep))) + scale, shift = torch.chunk(emb, 2) + x = self.norm(x) * (1 + scale) + shift + return x + + +class DualTransformer2DModel(nn.Module): + """ + Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input and output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + num_vector_embeds (`int`, *optional*): + Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + The number of diffusion steps used during training. Note that this is fixed at training time as it is used + to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for + up to but not more than steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + """ + + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + ): + super().__init__() + self.transformers = nn.ModuleList( + [ + Transformer2DModel( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + in_channels=in_channels, + num_layers=num_layers, + dropout=dropout, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attention_bias=attention_bias, + sample_size=sample_size, + num_vector_embeds=num_vector_embeds, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + ) + for _ in range(2) + ] + ) + + # Variables that can be set by a pipeline: + + # The ratio of transformer1 to transformer2's output states to be combined during inference + self.mix_ratio = 0.5 + + # The shape of `encoder_hidden_states` is expected to be + # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)` + self.condition_lengths = [77, 257] + + # Which transformer to use to encode which condition. + # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])` + self.transformer_index_for_condition = [1, 0] + + def forward( + self, hidden_states, encoder_hidden_states, timestep=None, attention_mask=None, return_dict: bool = True + ): + """ + Args: + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input + hidden_states + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + attention_mask (`torch.FloatTensor`, *optional*): + Optional attention mask to be applied in CrossAttention + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`] + if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample + tensor. + """ + input_states = hidden_states + + encoded_states = [] + tokens_start = 0 + # attention_mask is not used yet + for i in range(2): + # for each of the two transformers, pass the corresponding condition tokens + condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]] + transformer_index = self.transformer_index_for_condition[i] + encoded_state = self.transformers[transformer_index]( + input_states, + encoder_hidden_states=condition_state, + timestep=timestep, + return_dict=False, + )[0] + encoded_states.append(encoded_state - input_states) + tokens_start += self.condition_lengths[i] + + output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio) + output_states = output_states + input_states + + if not return_dict: + return (output_states,) + + return Transformer2DModelOutput(sample=output_states) \ No newline at end of file diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/vsr/models/pipeline_stable_diffusion_upscale_video_3d.py b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/models/pipeline_stable_diffusion_upscale_video_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..e83ca6bf4c7cfa75c260e3d1cc088f8bdc5eb240 --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/models/pipeline_stable_diffusion_upscale_video_3d.py @@ -0,0 +1,780 @@ + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, List, Optional, Union + +import numpy as np +import math +import PIL +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from diffusers.loaders import TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.schedulers import DDPMScheduler +# from diffusers.schedulers import DDIMScheduler +from diffusion.scheduling_ddim import DDIMScheduler + +from diffusers.utils import deprecate, is_accelerate_available, is_accelerate_version, logging + +try: + from diffusers.utils import randn_tensor +except: + from diffusers.utils.torch_utils import randn_tensor + +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput + +from einops import rearrange + +# from datasets.data_utils import filter2D +# from datasets.degradations import random_mixed_kernels, bivariate_Gaussian + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def preprocess(image): + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64 + + image = [np.array(i.resize((w, h)))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMixin): + _optional_components = ["feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + low_res_scheduler: DDPMScheduler, + # scheduler: KarrasDiffusionSchedulers, + scheduler: DDIMScheduler, + feature_extractor: Optional[CLIPImageProcessor] = None, + max_noise_level: int = 350, + ): + super().__init__() + + if hasattr( + vae, "config" + ): # check if vae has a config attribute `scaling_factor` and if it is set to 0.08333, else set it to 0.08333 and deprecate + is_vae_scaling_factor_set_to_0_08333 = ( + hasattr(vae.config, "scaling_factor") and vae.config.scaling_factor == 0.08333 + ) + if not is_vae_scaling_factor_set_to_0_08333: + deprecation_message = ( + "The configuration file of the vae does not contain `scaling_factor` or it is set to" + f" {vae.config.scaling_factor}, which seems highly unlikely. If your checkpoint is a fine-tuned" + " version of `stabilityai/stable-diffusion-x4-upscaler` you should change 'scaling_factor' to" + " 0.08333 Please make sure to update the config accordingly, as not doing so might lead to" + " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging" + " Face Hub, it would be very nice if you could open a Pull Request for the `vae/config.json` file" + ) + deprecate("wrong scaling_factor", "1.0.0", deprecation_message, standard_warn=False) + vae.register_to_config(scaling_factor=0.08333) + # TODO: remove + print(f'=============vae.config.scaling_factor: {vae.config.scaling_factor}==================') + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + low_res_scheduler=low_res_scheduler, + scheduler=scheduler, + feature_extractor=feature_extractor, + ) + self.register_to_config(max_noise_level=max_noise_level) + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def decode_latents_vsr(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents).sample + image = image.clamp(-1, 1).cpu() + return image + + def check_inputs( + self, + prompt, + image, + noise_level, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or `list` but is {type(image)}" + ) + + # verify batch size of prompt and image are same if image is a list or tensor + if isinstance(image, list) or isinstance(image, torch.Tensor): + if isinstance(prompt, str): + batch_size = 1 + else: + batch_size = len(prompt) + if isinstance(image, list): + image_batch_size = len(image) + else: + image_batch_size = image.shape[0] + if batch_size != image_batch_size: + raise ValueError( + f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}." + " Please make sure that passed `prompt` matches the batch size of `image`." + ) + + # check noise level + if noise_level > self.config.max_noise_level: + raise ValueError(f"`noise_level` has to be <= {self.config.max_noise_level} but is {noise_level}") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + def prepare_latents_3d(self, batch_size, num_channels_latents, seq_len, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, seq_len, height, width) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + def prepare_latents_inversion(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + + image = image.to(device=device, dtype=dtype) + batch_size = batch_size * num_images_per_prompt + + b = image.shape[0] + image = rearrange(image, 'b c t h w -> (b t) c h w').contiguous() + image = F.interpolate(image, scale_factor=4, mode='bicubic') + image = image.to(dtype=torch.float32) + init_latents = self.vae.encode(image).latent_dist.sample(generator) + torch.cuda.empty_cache() + init_latents = rearrange(init_latents, '(b t) c h w -> b c t h w', b=b).contiguous() + + init_latents = self.vae.config.scaling_factor * init_latents + init_latents = init_latents.to(dtype=torch.float16) + + # add noise + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + # DEBUG + # init_latents = noise + print('timestep', timestep) + + # scale the initial noise by the standard deviation required by the scheduler + latents = init_latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, + num_inference_steps: int = 75, + guidance_scale: float = 9.0, + noise_level: int = 20, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`): + `Image`, or tensor representing an image batch which will be upscaled. * + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Examples: + ```py + >>> import requests + >>> from PIL import Image + >>> from io import BytesIO + >>> from diffusers import StableDiffusionUpscalePipeline + >>> import torch + + >>> # load model and scheduler + >>> model_id = "stabilityai/stable-diffusion-x4-upscaler" + >>> pipeline = StableDiffusionUpscalePipeline.from_pretrained( + ... model_id, revision="fp16", torch_dtype=torch.float16 + ... ) + >>> pipeline = pipeline.to("cuda") + + >>> # let's download an image + >>> url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale/low_res_cat.png" + >>> response = requests.get(url) + >>> low_res_img = Image.open(BytesIO(response.content)).convert("RGB") + >>> low_res_img = low_res_img.resize((128, 128)) + >>> prompt = "a white cat" + + >>> upscaled_image = pipeline(prompt=prompt, image=low_res_img).images[0] + >>> upscaled_image.save("upsampled_cat.png") + ``` + """ + + # 1. Check inputs + self.check_inputs( + prompt, + image, + noise_level, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Preprocess image + # image = preprocess(image) + image = image.to(dtype=prompt_embeds.dtype, device=device) + + # 5. Add noise to image + noise_level = torch.tensor([noise_level], dtype=torch.long, device=device) + noise = randn_tensor(image.shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + image = self.low_res_scheduler.add_noise(image, noise, noise_level) + # image = image.clamp(-1, 1) + + # debug + # image = rearrange(image, 'b c t h w -> (b t) c h w').contiguous().cpu() + # return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None) + + batch_multiplier = 2 if do_classifier_free_guidance else 1 + image = torch.cat([image] * batch_multiplier * num_images_per_prompt) + # TODO: + # noise_level = noise_level*0 + noise_level = torch.cat([noise_level] * image.shape[0]) + + ####################### Random Noise ######################## + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + seq_len, height, width = image.shape[2:] + # TODO: for downsample_2x + # height, width = height//2, width//2 + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents_3d( + batch_size * num_images_per_prompt, + num_channels_latents, + seq_len, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) # b c t h w + # print('latents', latents.shape) + + ####################### Random Noise + Latent ######################## + # # 5. Prepare timesteps + # self.scheduler.set_timesteps(num_inference_steps, device=device) + # timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength=1, device=device) + # # DEBUG + # # timesteps = self.scheduler.timesteps + # latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # # 6. Prepare latent variables + # # b c t h w + # b = image.shape[0] + # num_channels_latents = self.vae.config.latent_channels + # latents = self.prepare_latents_inversion( + # image[:b//2], + # latent_timestep, + # batch_size, + # num_images_per_prompt, + # prompt_embeds.dtype, + # device, + # generator, + # ) + # print('latents', latents.shape) + + # 7. Check that sizes of image and latents match + num_channels_image = image.shape[1] + if num_channels_latents + num_channels_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_image`: {num_channels_image} " + f" = {num_channels_latents+num_channels_image}. Please verify the config of" + " `pipeline.unet` or your `image` input." + ) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + torch.cuda.empty_cache() # delete for VSR + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + #latent_model_input = torch.cat([latent_model_input, image], dim=1) + # print(f'========== latent_model_input: {latent_model_input.shape} ============') + # print(f'========== image: {image.shape} ============') + noise_pred = self.unet( + latent_model_input, t, image, encoder_hidden_states=prompt_embeds, class_labels=noise_level + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + del latent_model_input, noise_pred + + + # 10. Post-processing + # make sure the VAE is in float32 mode, as it overflows in float16 + self.vae.to(dtype=torch.float32) + + # TODO(Patrick, William) - clean up when attention is refactored + use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention") + use_xformers = self.vae.decoder.mid_block.attentions[0]._use_memory_efficient_attention_xformers + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if not use_torch_2_0_attn and not use_xformers: + self.vae.post_quant_conv.to(latents.dtype) + self.vae.decoder.conv_in.to(latents.dtype) + self.vae.decoder.mid_block.to(latents.dtype) + else: + latents = latents.float() + + # 11. Convert to frames + short_seq = 4 + # b c t h w + latents = rearrange(latents, 'b c t h w -> (b t) c h w').contiguous() + if latents.shape[0] > short_seq: # for VSR + image = [] + for start_f in range(0, latents.shape[0], short_seq): + torch.cuda.empty_cache() # delete for VSR + end_f = min(latents.shape[0], start_f + short_seq) + image_ = self.decode_latents_vsr(latents[start_f:end_f]) + image.append(image_) + del image_ + image = torch.cat(image, dim=0) + else: + image = self.decode_latents_vsr(latents) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, None) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None) diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/vsr/models/resnet.py b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/models/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..490fb323dc43c67acce260a2ff9a7e80f021f76c --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/models/resnet.py @@ -0,0 +1,316 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py +import os +import sys +sys.path.append(os.path.split(sys.path[0])[0]) +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + +class Mish(torch.nn.Module): + def forward(self, hidden_states): + return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) + +class InflatedConv3d(nn.Conv2d): + def forward(self, x): + video_length = x.shape[2] + + x = rearrange(x, "b c f h w -> (b f) c h w") + x = super().forward(x) + x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) + + return x + + +class Upsample3D(nn.Module): + def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + conv = None + if use_conv_transpose: + raise NotImplementedError + elif use_conv: + conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) + + if name == "conv": + self.conv = conv + else: + self.Conv2d_0 = conv + + def forward(self, hidden_states, output_size=None): + assert hidden_states.shape[1] == self.channels + + if self.use_conv_transpose: + raise NotImplementedError + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.float32) + + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + # if `output_size` is passed we force the interpolation output + # size and do not make use of `scale_factor=2` + if output_size is None: + hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") + else: + hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(dtype) + + if self.use_conv: + if self.name == "conv": + hidden_states = self.conv(hidden_states) + else: + hidden_states = self.Conv2d_0(hidden_states) + + return hidden_states + + +class Downsample3D(nn.Module): + def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if use_conv: + conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) + else: + raise NotImplementedError + + if name == "conv": + self.Conv2d_0 = conv + self.conv = conv + elif name == "Conv2d_0": + self.conv = conv + else: + self.conv = conv + + def forward(self, hidden_states): + assert hidden_states.shape[1] == self.channels + if self.use_conv and self.padding == 0: + raise NotImplementedError + + assert hidden_states.shape[1] == self.channels + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class ResnetBlock3D(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + non_linearity="swish", + time_embedding_norm="default", + output_scale_factor=1.0, + use_in_shortcut=None, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.time_embedding_norm = time_embedding_norm + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + + self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if temb_channels is not None: + if self.time_embedding_norm == "default": + time_emb_proj_out_channels = out_channels + elif self.time_embedding_norm == "scale_shift": + time_emb_proj_out_channels = out_channels * 2 + else: + raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") + + self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) + else: + self.time_emb_proj = None + + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + + self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, input_tensor, temb): + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.conv1(hidden_states) + + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] + + if temb is not None and self.time_embedding_norm == "default": + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + + if temb is not None and self.time_embedding_norm == "scale_shift": + scale, shift = torch.chunk(temb, 2, dim=1) + hidden_states = hidden_states * (1 + scale) + shift + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + + +class ResnetBlock3DCNN(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + kernel=(3,1,1), + conv_shortcut=False, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + non_linearity="swish", + time_embedding_norm="default", + output_scale_factor=1.0, + use_in_shortcut=None, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.time_embedding_norm = time_embedding_norm + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + + padding = ((kernel[i]-1)//2 for i in range(3)) + self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=kernel, stride=(1,1,1), padding=padding) + + if temb_channels is not None: + if self.time_embedding_norm == "default": + time_emb_proj_out_channels = out_channels + elif self.time_embedding_norm == "scale_shift": + time_emb_proj_out_channels = out_channels * 2 + else: + raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") + + self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) + else: + self.time_emb_proj = None + + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=(3,1,1), stride=(1,1,1), padding=(1,0,0)) + + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + + self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=(1,1,1), stride=(1,1,1), padding=(0,0,0)) + + def forward(self, input_tensor, temb=None): + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.conv1(hidden_states) + + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] + + if temb is not None and self.time_embedding_norm == "default": + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + + if temb is not None and self.time_embedding_norm == "scale_shift": + scale, shift = torch.chunk(temb, 2, dim=1) + hidden_states = hidden_states * (1 + scale) + shift + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor \ No newline at end of file diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/vsr/models/temporal_module.py b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/models/temporal_module.py new file mode 100644 index 0000000000000000000000000000000000000000..6a8ffb7aa8da232507831bf6fb241340b0f8bc9d --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/models/temporal_module.py @@ -0,0 +1,684 @@ +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import numpy as np +import torch.nn.functional as F +from torch import nn +import torchvision +# from torch_utils.ops import grid_sample_gradfix + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import BaseOutput +from diffusers.utils.import_utils import is_xformers_available +from diffusers.models.attention import FeedForward +# from diffusers.models.attention_processor import Attention + +try: + from .diffusers_attention import CrossAttention + from .resnet import Downsample3D, Upsample3D, InflatedConv3d, ResnetBlock3D, ResnetBlock3DCNN + +except: + from diffusers_attention import CrossAttention + from resnet import Downsample3D, Upsample3D, InflatedConv3d, ResnetBlock3D, ResnetBlock3DCNN + +from einops import rearrange, repeat +import math + +import pdb + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def grid_sample_align(input, grid): + return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=True) + + +@dataclass +class TemporalTransformer3DModelOutput(BaseOutput): + sample: torch.FloatTensor + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +class EmptyTemporalModule3D(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, hidden_states, condition_video=None, encoder_hidden_states=None, timesteps=None, temb=None, attention_mask=None): + return hidden_states + + +class TemporalModule3D(nn.Module): + def __init__( + self, + in_channels=None, + out_channels=None, + + num_attention_layers=None, + num_attention_head=8, + attention_head_dim=None, + cross_attention_dim=768, + temb_channels=512, + + dropout=0., + attention_bias=False, + activation_fn="geglu", + only_cross_attention=False, + upcast_attention=False, + + norm_num_groups=8, + use_linear_projection=True, + use_scale_shift=False, # set True always produce nan loss, I don't know why + + attention_block_types: Tuple[str]=None, + cross_frame_attention_mode=None, + temporal_shift_fold_div=None, + temporal_shift_direction=None, + + use_dcn_warpping=None, + use_deformable_conv=None, + + attention_dim_div: int = None, + video_condition=False, + ): + super().__init__() + assert len(attention_block_types) == 2 + + self.use_scale_shift = use_scale_shift + self.video_condition = video_condition + + self.non_linearity = nn.SiLU() + + # 1. 3d cnn + if self.video_condition: + video_condition_dim = int(out_channels//4) + self.v_cond_conv = ResnetBlock3D(in_channels=3, out_channels=video_condition_dim, temb_channels=temb_channels, groups=3, groups_out=32) + self.resblocks_3d_t = ResnetBlock3DCNN(in_channels=in_channels+video_condition_dim, out_channels=in_channels, kernel=(5,1,1), temb_channels=temb_channels) + else: + self.resblocks_3d_t = ResnetBlock3DCNN(in_channels=in_channels, out_channels=in_channels, kernel=(5,1,1), temb_channels=temb_channels) + + self.resblocks_3d_s = ResnetBlock3D(in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, groups=32, groups_out=32) + + # 2. transformer blocks + if not (attention_block_types[0]=='' and attention_block_types[1]==''): + attentions = TemporalTransformer3DModel( + num_attention_heads=num_attention_head, + attention_head_dim=attention_head_dim if attention_head_dim is not None else in_channels // num_attention_head // attention_dim_div, + + in_channels=in_channels, + num_layers=num_attention_layers, + dropout=dropout, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attention_bias=attention_bias, + activation_fn=activation_fn, + num_embeds_ada_norm=1000, # adaptive norm for timestep embedding injection + use_linear_projection=use_linear_projection, + + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + + attention_block_types=attention_block_types, + cross_frame_attention_mode=cross_frame_attention_mode, + temporal_shift_fold_div=temporal_shift_fold_div, + temporal_shift_direction=temporal_shift_direction, + + use_dcn_warpping=use_dcn_warpping, + use_deformable_conv=use_deformable_conv, + ) + self.attentions = nn.ModuleList([attentions]) + + if use_scale_shift: + self.scale_shift_conv = zero_module(InflatedConv3d(in_channels=in_channels, out_channels=in_channels * 2, kernel_size=1, stride=1, padding=0)) + else: + self.shift_conv = zero_module(InflatedConv3d(in_channels=in_channels, out_channels=in_channels, kernel_size=1, stride=1, padding=0)) + + + def forward(self, hidden_states, condition_video=None, encoder_hidden_states=None, timesteps=None, temb=None, attention_mask=None): + input_tensor = hidden_states + + if self.video_condition: + # obtain video attention + assert condition_video is not None + if isinstance(condition_video, dict): + condition_video = condition_video[hidden_states.shape[-1]] + hidden_condition = self.v_cond_conv(condition_video, temb) + hidden_states = torch.cat([hidden_states, hidden_condition], dim=1) + + # 3DCNN + hidden_states = self.resblocks_3d_t(hidden_states, temb) + hidden_states = self.resblocks_3d_s(hidden_states, temb) + + if hasattr(self, "attentions"): + for attn in self.attentions: + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timesteps).sample + + if self.use_scale_shift: + hidden_states = self.scale_shift_conv(hidden_states) + scale, shift = torch.chunk(hidden_states, chunks=2, dim=1) + hidden_states = (1 + scale) * input_tensor + shift + else: + hidden_states = self.shift_conv(hidden_states) + hidden_states = input_tensor + hidden_states + + return hidden_states + + +class TemporalTransformer3DModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + num_attention_heads=None, + attention_head_dim=None, + in_channels=None, + num_layers=None, + dropout=None, + norm_num_groups=None, + cross_attention_dim=None, + attention_bias=None, + activation_fn=None, + num_embeds_ada_norm=None, + use_linear_projection=None, + only_cross_attention=None, + upcast_attention=None, + + attention_block_types=None, + cross_frame_attention_mode=None, + temporal_shift_fold_div=None, + temporal_shift_direction=None, + + use_dcn_warpping=None, + use_deformable_conv=None, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # Define input layers + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + + # Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + TemporalTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + + attention_block_types=attention_block_types, + cross_frame_attention_mode=cross_frame_attention_mode, + temporal_shift_fold_div=temporal_shift_fold_div, + temporal_shift_direction=temporal_shift_direction, + + use_dcn_warpping=use_dcn_warpping, + use_deformable_conv=use_deformable_conv, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + if use_linear_projection: + self.proj_out = nn.Linear(inner_dim, in_channels) + else: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): + # Input + assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + if encoder_hidden_states is not None: + encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length) + + batch, channel, height, weight = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = self.proj_in(hidden_states) + + # Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + video_length=video_length + ) + + # Output + if not self.use_linear_projection: + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + + output = hidden_states + residual + + output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) + if not return_dict: + return (output,) + + return TemporalTransformer3DModelOutput(sample=output) + + +class TemporalTransformerBlock(nn.Module): + def __init__( + self, + dim=None, + num_attention_heads=None, + attention_head_dim=None, + dropout=None, + cross_attention_dim=None, + activation_fn=None, + num_embeds_ada_norm=None, + attention_bias=None, + only_cross_attention=None, + upcast_attention=None, + + attention_block_types=None, + cross_frame_attention_mode=None, + temporal_shift_fold_div=None, + temporal_shift_direction=None, + + use_dcn_warpping=None, + use_deformable_conv=None, + ): + super().__init__() + assert len(attention_block_types) == 2 + + self.only_cross_attention = only_cross_attention + self.use_ada_layer_norm = num_embeds_ada_norm is not None + self.use_dcn_warpping = use_dcn_warpping + + # 1. Spatial-Attn (self) + if not attention_block_types[0] == '': + self.attn_spatial = VersatileSelfAttention( + attention_mode=attention_block_types[0], + + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + + cross_frame_attention_mode=cross_frame_attention_mode, + temporal_shift_fold_div=temporal_shift_fold_div, + temporal_shift_direction=temporal_shift_direction, + ) + nn.init.zeros_(self.attn_spatial.to_out[0].weight.data) + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + # 2. Temporal-Attn (self) + self.attn_temporal = VersatileSelfAttention( + attention_mode=attention_block_types[1], + + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + + cross_frame_attention_mode=cross_frame_attention_mode, + temporal_shift_fold_div=temporal_shift_fold_div, + temporal_shift_direction=temporal_shift_direction, + ) + nn.init.zeros_(self.attn_temporal.to_out[0].weight.data) + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + self.dcn_module = WarpModule( + in_channels=dim, + use_deformable_conv=use_deformable_conv, + ) if use_dcn_warpping else None + + # 3. Feed-forward + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) + self.norm3 = nn.LayerNorm(dim) + + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, attention_op: None): + if not is_xformers_available(): + print("Here is how to install it") + raise ModuleNotFoundError( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers", + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" + " available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + if hasattr(self, "attn_spatial"): + self.attn_spatial._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None): + # 1. Spatial-Attention + if hasattr(self, "attn_spatial") and hasattr(self, "norm1"): + norm_hidden_states = self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) + hidden_states = self.attn_spatial(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states + + # 2. Temporal-Attention + norm_hidden_states = self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + if not self.use_dcn_warpping: + hidden_states = self.attn_temporal(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states + else: + hidden_states = self.dcn_module( + hidden_states, + offset_hidden_states=self.attn_temporal(norm_hidden_states, attention_mask=attention_mask, video_length=video_length), + ) + + # 3. Feed-forward + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + + return hidden_states + + +class VersatileSelfAttention(CrossAttention): + def __init__( + self, + attention_mode=None, + cross_frame_attention_mode=None, + temporal_shift_fold_div=None, + temporal_shift_direction=None, + temporal_position_encoding=False, + temporal_position_encoding_max_len=24, + *args, **kwargs + ): + super().__init__(*args, **kwargs) + assert attention_mode in ("Temporal", "Spatial", "CrossFrame", "SpatialTemporalShift", None) + assert cross_frame_attention_mode in ("0_i-1", "i-1_i", "0_i-1_i", "i-1_i_i+1", None) + + self.attention_mode = attention_mode + + self.cross_frame_attention_mode = cross_frame_attention_mode + + self.temporal_shift_fold_div = temporal_shift_fold_div + self.temporal_shift_direction = temporal_shift_direction + + self.pos_encoder = PositionalEncoding( + kwargs["query_dim"], + dropout=0., + max_len=temporal_position_encoding_max_len + ) if temporal_position_encoding else None + + def temporal_token_concat(self, tensor, video_length): + # print("### temporal token concat") + current_frame_index = torch.arange(video_length) + former_frame_index = current_frame_index - 1 + former_frame_index[0] = 0 + + later_frame_index = current_frame_index + 1 + later_frame_index[-1] = -1 + + # (b f) d c + tensor = rearrange(tensor, "(b f) d c -> b f d c", f=video_length) + + if self.cross_frame_attention_mode == "0_i-1": + tensor = torch.cat([tensor[:, [0] * video_length], tensor[:, former_frame_index]], dim=2) + elif self.cross_frame_attention_mode == "i-1_i": + tensor = torch.cat([tensor[:, former_frame_index], tensor[:, current_frame_index]], dim=2) + elif self.cross_frame_attention_mode == "0_i-1_i": + tensor = torch.cat([tensor[:, [0] * video_length], tensor[:, former_frame_index], tensor[:, current_frame_index]], dim=2) + elif self.cross_frame_attention_mode == "i-1_i_i+1": + tensor = torch.cat([tensor[:, former_frame_index], tensor[:, current_frame_index], tensor[:, later_frame_index]], dim=2) + else: + raise NotImplementedError + + tensor = rearrange(tensor, "b f d c -> (b f) d c") + return tensor + + def temporal_shift(self, tensor, video_length): + # print("### temporal shift") + # (b f) d c + tensor = rearrange(tensor, "(b f) d c -> b f d c", f=video_length) + n_channels = tensor.shape[-1] + fold = n_channels // self.temporal_shift_fold_div + + if self.temporal_shift_direction != "right": + raise NotImplementedError + + tensor_out = torch.zeros_like(tensor) + tensor_out[:, 1:, :, :fold] = tensor[:, :-1, :, :fold] + tensor_out[:, :, :, fold:] = tensor[:, :, :, fold:] + + tensor_out = rearrange(tensor_out, "b f d c -> (b f) d c") + return tensor_out + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): + # pdb.set_trace() + batch_size, sequence_length, _ = hidden_states.shape + assert encoder_hidden_states is None + + # (b f) d c + if self.attention_mode == "Temporal": + # print("### temporal reshape") + d = hidden_states.shape[1] + hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) + + if self.pos_encoder is not None: + hidden_states = self.pos_encoder(hidden_states) + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) + dim = query.shape[-1] + query = self.reshape_heads_to_batch_dim(query) + + if self.added_kv_proj_dim is not None: + raise NotImplementedError + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + if self.attention_mode == "SpatialTemporalShift": + key = self.temporal_shift(key, video_length=video_length) + value = self.temporal_shift(value, video_length=video_length) + elif self.attention_mode == "CrossFrame": + key = self.temporal_token_concat(key, video_length=video_length) + value = self.temporal_token_concat(value, video_length=video_length) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + + if self.attention_mode == "Temporal": + hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) + + return hidden_states + + +class WarpModule(nn.Module): + def __init__( + self, + in_channels=None, + use_deformable_conv=None, + ): + super().__init__() + self.use_deformable_conv = use_deformable_conv + + self.conv = None + self.dcn_weight = None + if use_deformable_conv: + self.conv = nn.Conv2d(in_channels*2, 27, kernel_size=3, stride=1, padding=1) + self.dcn_weight = nn.Parameter(torch.randn(in_channels, in_channels, 3, 3) / np.sqrt(in_channels * 3 * 3)) + self.alpha = nn.Parameter(torch.zeros(1, in_channels, 1, 1)) + else: + self.conv = zero_module(nn.Conv2d(in_channels, 2, kernel_size=3, stride=1, padding=1)) + + def forward(self, hidden_states, offset_hidden_states): + # (b f) d c + spatial_dim = hidden_states.shape[1] + size = int(spatial_dim ** 0.5) + assert size ** 2 == spatial_dim + + hidden_states = rearrange(hidden_states, "b (h w) c -> b c h w", h=size) + offset_hidden_states = rearrange(offset_hidden_states, "b (h w) c -> b c h w", h=size) + + concat_hidden_states = torch.cat([hidden_states, offset_hidden_states], dim=1) + + input_tensor = hidden_states + if self.use_deformable_conv: + offset_x, offset_y, offsets_mask = torch.chunk(self.conv(concat_hidden_states), chunks=3, dim=1) + offsets_mask = offsets_mask.sigmoid() * 2 + offsets = torch.cat([offset_x, offset_y], dim=1) + hidden_states = torchvision.ops.deform_conv2d( + hidden_states, + offset=offsets, + weight=self.dcn_weight, + mask=offsets_mask, + padding=1, + ) + hidden_states = self.alpha * hidden_states + input_tensor + + else: + offsets = self.conv(concat_hidden_states) + hidden_states = self.optical_flow_warping(hidden_states, offsets) + + hidden_states = rearrange(hidden_states, "b c h w -> b (h w) c") + return hidden_states + + @staticmethod + def optical_flow_warping(x, flo): + """ + warp an image/tensor (im2) back to im1, according to the optical flow + + x: [B, C, H, W] (im2) + flo: [B, 2, H, W] flow + pad_mode (optional): ref to https://pytorch.org/docs/stable/nn.functional.html#grid-sample + "zeros": use 0 for out-of-bound grid locations, + "border": use border values for out-of-bound grid locations + """ + dtype = x.dtype + if dtype != torch.float32: + x = x.to(torch.float32) + B, C, H, W = x.size() + # mesh grid + xx = torch.arange(0, W).view(1, -1).repeat(H, 1) + yy = torch.arange(0, H).view(-1, 1).repeat(1, W) + xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1) + yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1) + grid = torch.cat((xx, yy), 1).float().to(flo.device) + + vgrid = grid + flo + + # scale grid to [-1,1] + vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0 + vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0 + + vgrid = vgrid.permute(0, 2, 3, 1) + # output = grid_sample_gradfix.grid_sample_align(x, vgrid) + output = grid_sample_align(x, vgrid) + #output = torch.nn.functional.grid_sample(x, vgrid, padding_mode='zeros', mode='bilinear', align_corners=True) + + mask = torch.ones_like(x) + # mask = grid_sample_gradfix.grid_sample_align(mask, vgrid) + mask = grid_sample_align(x, vgrid) + #mask = torch.nn.functional.grid_sample(mask, vgrid, padding_mode='zeros', mode='bilinear', align_corners=True) + + mask[mask < 0.9999] = 0 + mask[mask > 0] = 1 + results = output * mask + if dtype != torch.float32: + results = results.to(dtype) + return results + + +class AdaLayerNorm(nn.Module): + """ + Norm layer modified to incorporate timestep embeddings. + """ + def __init__(self, embedding_dim, num_embeddings): + super().__init__() + self.emb = nn.Embedding(num_embeddings, embedding_dim) + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, embedding_dim * 2) + self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False) + + def forward(self, x, timestep): + timestep = repeat(timestep, "b -> (b r)", r=x.shape[0] // timestep.shape[0]) + + emb = self.linear(self.silu(self.emb(timestep))).unsqueeze(1) # (b f) 1 2d + scale, shift = torch.chunk(emb, 2, dim=-1) + x = self.norm(x) * (1 + scale) + shift + return x + diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/vsr/models/unet.py b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/models/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..f30bc66acd049728a20fc97c2d56c6edea09df02 --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/models/unet.py @@ -0,0 +1,654 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import os +import sys +sys.path.append(os.path.split(sys.path[0])[0]) + +import json +import math +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn import functional as F +import einops + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.utils import BaseOutput, logging + +try: + from .unet_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, + ) + from .resnet import InflatedConv3d + from .temporal_module import TemporalModule3D, EmptyTemporalModule3D +except: + from unet_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, + ) + from resnet import InflatedConv3d + from temporal_module import TemporalModule3D, EmptyTemporalModule3D + +from rotary_embedding_torch import RotaryEmbedding + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + +class RelativePositionBias(nn.Module): + def __init__( + self, + heads=8, + num_buckets=32, + max_distance=128, + ): + super().__init__() + self.num_buckets = num_buckets + self.max_distance = max_distance + self.relative_attention_bias = nn.Embedding(num_buckets, heads) + + @staticmethod + def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128): + ret = 0 + n = -relative_position + + num_buckets //= 2 + ret += (n < 0).long() * num_buckets + n = torch.abs(n) + + max_exact = num_buckets // 2 + is_small = n < max_exact + + val_if_large = max_exact + ( + torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) + ).long() + val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + + ret += torch.where(is_small, n, val_if_large) + return ret + + def forward(self, n, device): + q_pos = torch.arange(n, dtype = torch.long, device = device) + k_pos = torch.arange(n, dtype = torch.long, device = device) + rel_pos = einops.rearrange(k_pos, 'j -> 1 j') - einops.rearrange(q_pos, 'i -> i 1') + rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance) + values = self.relative_attention_bias(rp_bucket) + return einops.rearrange(values, 'i j h -> h i j') # num_heads, num_frames, num_frames + +@dataclass +class UNet3DConditionOutput(BaseOutput): + sample: torch.FloatTensor + + +class UNet3DVSRModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + ### Temporal Module Additional Kwargs ### + down_temporal_idx = (0,1,2), + mid_temporal = False, + up_temporal_idx = (0,1,2), + video_condition = True, + temporal_module_config = None, + + sample_size: Optional[int] = None, # 80 + in_channels: int = 7, + out_channels: int = 4, + center_input_sample: bool = False, + max_noise_level: int = 350, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + attention_head_dim: Union[int, Tuple[int]] = 8, + block_out_channels: Tuple[int] = ( + 256, + 512, + 512, + 1024 + ), + down_block_types: Tuple[str] = ( + "DownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D" + ), + mid_block_type: str = "UNetMidBlock3DCrossAttn", + up_block_types: Tuple[str] = ( + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "UpBlock3D" + ), + only_cross_attention: Union[bool, Tuple[bool]] = ( + True, + True, + True, + False + ), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1024, + dual_cross_attention: bool = False, + use_linear_projection: bool = True, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = 1000, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + use_first_frame: bool = False, + use_relative_position: bool = False, + ): + super().__init__() + + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + + # input + self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=1) + + # time + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) # VSR for noise level + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + else: + self.class_embedding = None + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + self.video_condition = video_condition + + # Temporal Modules + self.down_temporal_blocks = nn.ModuleList([]) + self.mid_temporal_block = None + self.up_temporal_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + + self.temporal_rotary_emb = RotaryEmbedding(32) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=self.temporal_rotary_emb, + ) + self.down_blocks.append(down_block) + + # Down Sample Temporal Modules + down_temporal_block = TemporalModule3D( + in_channels=output_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + video_condition=video_condition, + **temporal_module_config, + ) if i in down_temporal_idx else EmptyTemporalModule3D() + self.down_temporal_blocks.append(down_temporal_block) + # mid + if mid_block_type == "UNetMidBlock3DCrossAttn": + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=self.temporal_rotary_emb, + ) + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + self.mid_temporal_block = TemporalModule3D( + in_channels=block_out_channels[-1], + out_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + video_condition=video_condition, + **temporal_module_config, + ) if mid_temporal else EmptyTemporalModule3D() + + # count how many layers upsample the videos + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + only_cross_attention = list(reversed(only_cross_attention)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=self.temporal_rotary_emb, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + up_temporal_block = TemporalModule3D( + in_channels=output_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + video_condition=video_condition, + **temporal_module_config, + ) if i in up_temporal_idx else EmptyTemporalModule3D() + self.up_temporal_blocks.append(up_temporal_block) + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) + self.conv_act = nn.SiLU() + self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + + num_slicable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_slicable_layers * [1] + + slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + low_res: torch.FloatTensor, + # encoder_hidden_states: torch.Tensor, + encoder_hidden_states = None, + class_labels: Optional[torch.Tensor] = 20, + low_res_clean: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): # -> Union[UNet3DConditionOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, seq_length, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + class_labels: noise level + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if self.video_condition: + low_res_dict = {} + low_res_dict[low_res.shape[-1]] = low_res + for s in [1/2., 1/4., 1/8.]: + low_res_ds = F.interpolate(low_res, scale_factor=(1, s, s), mode='area') + low_res_dict[low_res_ds.shape[-1]] = low_res_ds + else: + low_res_dict = None + + sample = torch.cat([sample, low_res], dim=1) # concat on C: 4+3=7 + + #print(f'==============={sample.shape}================') + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # time + timesteps = timestep + if not torch.is_tensor(timesteps): + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + # check noise level + if torch.any(class_labels > self.config.max_noise_level): + raise ValueError(f"`noise_level` has to be <= {self.config.max_noise_level} but is {class_labels}") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + + # pre-process + sample = self.conv_in(sample) + + # down + down_block_res_samples = (sample,) + for downsample_block, down_temporal_block in zip(self.down_blocks, self.down_temporal_blocks): + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 1. temporal modeling during down sample + sample = down_temporal_block( + hidden_states=sample, + condition_video=low_res_dict, + encoder_hidden_states=encoder_hidden_states, + timesteps=timesteps, + temb=emb, + ) + # mid + sample = self.mid_block( + sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + ) + # 2. temporal modeling at mid block + sample = self.mid_temporal_block( + hidden_states=sample, + condition_video=low_res_dict, + encoder_hidden_states=encoder_hidden_states, + timesteps=timesteps, + temb=emb, + ) + + # up + for i, (upsample_block, up_temporal_block) in enumerate(zip(self.up_blocks, self.up_temporal_blocks)): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + + # 3. temporal modeling during up sample + sample = up_temporal_block( + hidden_states=sample, + condition_video=low_res_dict, + encoder_hidden_states=encoder_hidden_states, + timesteps=timesteps, + temb=emb, + ) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + # print(sample.shape) + + if not return_dict: + return (sample,) + + return UNet3DConditionOutput(sample=sample) + + def forward_with_cfg(self, + x, + t, + low_res, + encoder_hidden_states = None, + class_labels: Optional[torch.Tensor] = 20, + cfg_scale=4.0, + use_fp16=False): + """ + Forward, but also batches the unconditional forward pass for classifier-free guidance. + """ + # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + half = x[: len(x) // 2] + combined = torch.cat([half, half], dim=0) + if use_fp16: + combined = combined.to(dtype=torch.float16) + model_out = self.forward(combined, t, low_res, encoder_hidden_states, class_labels).sample + # For exact reproducibility reasons, we apply classifier-free guidance on only + # three channels by default. The standard approach to cfg applies it to all channels. + # This can be done by uncommenting the following line and commenting-out the line following that. + eps, rest = model_out[:, :4], model_out[:, 4:] + # eps, rest = model_out[:, :3], model_out[:, 3:] # b c f h w + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + return torch.cat([eps, rest], dim=1) + + + @classmethod + def from_pretrained_2d(cls, config_path, pretrained_model_path): + if not os.path.isfile(config_path): + raise RuntimeError(f"{config_path} does not exist") + with open(config_path, "r") as f: + config = json.load(f) + config["_class_name"] = cls.__name__ + freeze_pretrained_2d_upsampler = config["freeze_pretrained_2d_upsampler"] + + model = cls.from_config(config) + model_file = os.path.join(pretrained_model_path) + if not os.path.isfile(model_file): + raise RuntimeError(f"{model_file} does not exist") + state_dict = torch.load(model_file, map_location="cpu") + for k, v in model.state_dict().items(): + if 'temporal' in k: + print(f'New layers: {k}') + state_dict.update({k: v}) + + model.load_state_dict(state_dict, strict=True) + + if freeze_pretrained_2d_upsampler: + print("Freeze pretrained 2d upsampler!") + for k, v in model.named_parameters(): + if not 'temporal' in k: + v.requires_grad = False + return model + +if __name__ == '__main__': + import torch + + device = "cuda" if torch.cuda.is_available() else "cpu" + config_path = "./configs/unet_3d_config.json" + # pretrained_model_path = "./pretrained_models/unet_diffusion_pytorch_model.bin" + # unet = UNet3DVSRModel.from_pretrained_2d(config_path, pretrained_model_path).to(device) diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/vsr/models/unet_blocks.py b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/models/unet_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..066637ff78439e61864e5893a122b32f07e58f89 --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/models/unet_blocks.py @@ -0,0 +1,629 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py +import os +import sys +sys.path.append(os.path.split(sys.path[0])[0]) + +import torch +from torch import nn + +try: + from .attention import Transformer3DModel + from .resnet import Downsample3D, ResnetBlock3D, Upsample3D +except: + from attention import Transformer3DModel + from resnet import Downsample3D, ResnetBlock3D, Upsample3D + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + use_first_frame=False, + use_relative_position=False, + rotary_emb=None, +): + # print(down_block_type) + # print(use_first_frame) + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlock3D": + return DownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "CrossAttnDownBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") + return CrossAttnDownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + use_first_frame=False, + use_relative_position=False, + rotary_emb=None, +): + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlock3D": + return UpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "CrossAttnUpBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") + return CrossAttnUpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock3DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + use_first_frame=False, + use_relative_position=False, + rotary_emb=None, + ): + super().__init__() + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(num_layers): + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + ) + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class CrossAttnDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + use_first_frame=False, + use_relative_position=False, + rotary_emb=None, + ): + super().__init__() + resnets = [] + attentions = [] + + # print(use_first_frame) + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + use_first_frame=False, + use_relative_position=False, + rotary_emb=None + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + attention_mask=None, + ): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states \ No newline at end of file diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/vsr/models/upscaling.py b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/models/upscaling.py new file mode 100644 index 0000000000000000000000000000000000000000..f0f62ecd358a839b9a7c6a2d0e403ed9150ed2d3 --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/models/upscaling.py @@ -0,0 +1,95 @@ +import torch +import torch.nn as nn +import numpy as np +from functools import partial +from inspect import isfunction + +def exists(x): + return x is not None + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + +def make_beta_schedule(n_timestep, linear_start=1e-4, linear_end=2e-2): + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + return betas.numpy() + +class AbstractLowScaleModel(nn.Module): + # for concatenating a downsampled image to the latent representation + def __init__(self, noise_schedule_config=None): + super(AbstractLowScaleModel, self).__init__() + if noise_schedule_config is not None: + self.register_schedule(**noise_schedule_config) + + def register_schedule(self, timesteps=1000, linear_start=1e-4, linear_end=2e-2): + betas = make_beta_schedule(timesteps, linear_start=linear_start, linear_end=linear_end) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def forward(self, x): + return x, None + + def decode(self, x): + return x + + +class SimpleImageConcat(AbstractLowScaleModel): + # no noise level conditioning + def __init__(self): + super(SimpleImageConcat, self).__init__(noise_schedule_config=None) + self.max_noise_level = 0 + + def forward(self, x): + # fix to constant noise level + return x, torch.zeros(x.shape[0], device=x.device).long() + + +class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel): + def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False): + super().__init__(noise_schedule_config=noise_schedule_config) + self.max_noise_level = max_noise_level + + def forward(self, x, noise_level=None): + if noise_level is None: + noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() + else: + assert isinstance(noise_level, torch.Tensor) + z = self.q_sample(x, noise_level) + return z, noise_level + + + diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/vsr/models/utils.py b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..76e205d1709568885e58b56888f62c8805ad4f91 --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/models/utils.py @@ -0,0 +1,215 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import os +import math +import torch + +import numpy as np +import torch.nn as nn + +from einops import repeat + + +################################################################################# +# Unet Utils # +################################################################################# + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim).contiguous() + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +# class HybridConditioner(nn.Module): + +# def __init__(self, c_concat_config, c_crossattn_config): +# super().__init__() +# self.concat_conditioner = instantiate_from_config(c_concat_config) +# self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + +# def forward(self, c_concat, c_crossattn): +# c_concat = self.concat_conditioner(c_concat) +# c_crossattn = self.crossattn_conditioner(c_crossattn) +# return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += torch.DoubleTensor([matmul_ops]) + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params \ No newline at end of file diff --git a/src/videogen_hub/pipelines/lavie/lavie_src/vsr/sample.py b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/sample.py new file mode 100644 index 0000000000000000000000000000000000000000..d7cbb0d79866fc75bd68fe98989c5bd690f8aba1 --- /dev/null +++ b/src/videogen_hub/pipelines/lavie/lavie_src/vsr/sample.py @@ -0,0 +1,151 @@ +import io +import os +import sys +import argparse +o_path = os.getcwd() +sys.path.append(o_path) + +import torch +import time +import json +import numpy as np +import imageio +import torchvision +from einops import rearrange + +from models.autoencoder_kl import AutoencoderKL +from models.unet import UNet3DVSRModel +from models.pipeline_stable_diffusion_upscale_video_3d import StableDiffusionUpscalePipeline +from diffusers import DDIMScheduler +from omegaconf import OmegaConf + + +def main(args): + + device = "cuda" + + # ---------------------- load models ---------------------- + pipeline = StableDiffusionUpscalePipeline.from_pretrained(args.pretrained_path + '/stable-diffusion-x4-upscaler', torch_dtype=torch.float16) + + # vae + pipeline.vae = AutoencoderKL.from_config("configs/vae_config.json") + pretrained_model = args.pretrained_path + "/stable-diffusion-x4-upscaler/vae/diffusion_pytorch_model.bin" + pipeline.vae.load_state_dict(torch.load(pretrained_model, map_location="cpu")) + + # unet + config_path = "./configs/unet_3d_config.json" + with open(config_path, "r") as f: + config = json.load(f) + config['video_condition'] = False + pipeline.unet = UNet3DVSRModel.from_config(config) + + pretrained_model = args.ckpt_path + checkpoint = torch.load(pretrained_model, map_location="cpu")['ema'] + + pipeline.unet.load_state_dict(checkpoint, True) + pipeline.unet = pipeline.unet.half() + pipeline.unet.eval() # important! + + # DDIMScheduler + with open(args.pretrained_path + '/stable-diffusion-x4-upscaler/scheduler/scheduler_config.json', "r") as f: + config = json.load(f) + config["beta_schedule"] = "linear" + pipeline.scheduler = DDIMScheduler.from_config(config) + + pipeline = pipeline.to("cuda") + + # ---------------------- load user's prompt ---------------------- + # input + video_root = args.input_path + video_list = sorted(os.listdir(video_root)) + print('video num:', len(video_list)) + + # output + save_root = args.output_path + os.makedirs(save_root, exist_ok=True) + + # inference params + noise_level = args.noise_level + guidance_scale = args.guidance_scale + num_inference_steps = args.inference_steps + + # ---------------------- start inferencing ---------------------- + for i, video_name in enumerate(video_list): + video_name = video_name.replace('.mp4', '') + print(f'[{i+1}/{len(video_list)}]: ', video_name) + + lr_path = f"{video_root}/{video_name}.mp4" + save_path = f"{save_root}/{video_name}.mp4" + + prompt = video_name + print('Prompt: ', prompt) + + negative_prompt = "blur, worst quality" + + vframes, aframes, info = torchvision.io.read_video(filename=lr_path, pts_unit='sec', output_format='TCHW') # RGB + vframes = vframes / 255. + vframes = (vframes - 0.5) * 2 # T C H W [-1, 1] + t, _, h, w = vframes.shape + vframes = vframes.unsqueeze(dim=0) # 1 T C H W + vframes = rearrange(vframes, 'b t c h w -> b c t h w').contiguous() # 1 C T H W + print('Input_shape:', vframes.shape, 'Noise_level:', noise_level, 'Guidance_scale:', guidance_scale) + + fps = info['video_fps'] + generator = torch.Generator(device=device).manual_seed(10) + + torch.cuda.synchronize() + start_time = time.time() + + with torch.no_grad(): + short_seq = 8 + vframes_seq = vframes.shape[2] + if vframes_seq > short_seq: # for VSR + upscaled_video_list = [] + for start_f in range(0, vframes_seq, short_seq): + print(f'Processing: [{start_f}-{start_f + short_seq}/{vframes_seq}]') + torch.cuda.empty_cache() # delete for VSR + end_f = min(vframes_seq, start_f + short_seq) + + upscaled_video_ = pipeline( + prompt, + image=vframes[:,:,start_f:end_f], + generator=generator, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + noise_level=noise_level, + negative_prompt=negative_prompt, + ).images # T C H W [-1, 1] + upscaled_video_list.append(upscaled_video_) + upscaled_video = torch.cat(upscaled_video_list, dim=0) + else: + upscaled_video = pipeline( + prompt, + image=vframes, + generator=generator, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + noise_level=noise_level, + negative_prompt=negative_prompt, + ).images # T C H W [-1, 1] + + torch.cuda.synchronize() + run_time = time.time() - start_time + + print('Output:', upscaled_video.shape) + + # save video + upscaled_video = (upscaled_video / 2 + 0.5).clamp(0, 1) * 255 + upscaled_video = upscaled_video.permute(0, 2, 3, 1).to(torch.uint8) + upscaled_video = upscaled_video.numpy().astype(np.uint8) + imageio.mimwrite(save_path, upscaled_video, fps=fps, quality=9) # Highest quality is 10, lowest is 0 + + print(f'Save upscaled video "{video_name}" in {save_path}, time (sec): {run_time} \n') + print(f'\nAll results are saved in {save_path}') + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, default="") + args = parser.parse_args() + + main(OmegaConf.load(args.config)) diff --git a/src/videogen_hub/pipelines/opensora/CONTRIBUTING.md b/src/videogen_hub/pipelines/opensora/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..2acbec41b910e2906e859e8fceee5e48d482fc15 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/CONTRIBUTING.md @@ -0,0 +1,100 @@ +# Contributing + +The Open-Sora project welcomes any constructive contribution from the community and the team is more than willing to work on problems you have encountered to make it a better project. + +## Development Environment Setup + +To contribute to Open-Sora, we would like to first guide you to set up a proper development environment so that you can better implement your code. You can install this library from source with the `editable` flag (`-e`, for development mode) so that your change to the source code will be reflected in runtime without re-installation. + +You can refer to the [Installation Section](./README.md#installation) and replace `pip install -v .` with `pip install -v -e .`. + +### Code Style + +We have some static checks when you commit your code change, please make sure you can pass all the tests and make sure the coding style meets our requirements. We use pre-commit hook to make sure the code is aligned with the writing standard. To set up the code style checking, you need to follow the steps below. + +```shell +# these commands are executed under the Open-Sora directory +pip install pre-commit +pre-commit install +``` + +Code format checking will be automatically executed when you commit your changes. + +## Contribution Guide + +You need to follow these steps below to make contribution to the main repository via pull request. You can learn about the details of pull request [here](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/about-pull-requests). + +### 1. Fork the Official Repository + +Firstly, you need to visit the [Open-Sora repository](https://github.com/hpcaitech/Open-Sora) and fork into your own account. The `fork` button is at the right top corner of the web page alongside with buttons such as `watch` and `star`. + +Now, you can clone your own forked repository into your local environment. + +```shell +git clone https://github.com//Open-Sora.git +``` + +### 2. Configure Git + +You need to set the official repository as your upstream so that you can synchronize with the latest update in the official repository. You can learn about upstream [here](https://www.atlassian.com/git/tutorials/git-forks-and-upstreams). + +Then add the original repository as upstream + +```shell +cd Open-Sora +git remote add upstream https://github.com/hpcaitech/Open-Sora.git +``` + +you can use the following command to verify that the remote is set. You should see both `origin` and `upstream` in the output. + +```shell +git remote -v +``` + +### 3. Synchronize with Official Repository + +Before you make changes to the codebase, it is always good to fetch the latest updates in the official repository. In order to do so, you can use the commands below. + +```shell +git fetch upstream +git checkout main +git merge upstream/main +git push origin main +``` + +### 5. Create a New Branch + +You should not make changes to the `main` branch of your forked repository as this might make upstream synchronization difficult. You can create a new branch with the appropriate name. General branch name format should start with `hotfix/` and `feature/`. `hotfix` is for bug fix and `feature` is for addition of a new feature. + +```shell +git checkout -b +``` + +### 6. Implementation and Code Commit + +Now you can implement your code change in the source code. Remember that you installed the system in development, thus you do not need to uninstall and install to make the code take effect. The code change will be reflected in every new PyThon execution. +You can commit and push the changes to your local repository. The changes should be kept logical, modular and atomic. + +```shell +git add -A +git commit -m "" +git push -u origin +``` + +### 7. Open a Pull Request + +You can now create a pull request on the GitHub webpage of your repository. The source branch is `` of your repository and the target branch should be `main` of `hpcaitech/Open-Sora`. After creating this pull request, you should be able to see it [here](https://github.com/hpcaitech/Open-Sora/pulls). + +The Open-Sora team will review your code change and merge your code if applicable. + +## FQA + +1. `pylint` cannot recognize some members: + +Add this into your `settings.json` in VSCode: + +```json +"pylint.args": [ + "--generated-members=numpy.* ,torch.*,cv2.*", +], +``` diff --git a/src/videogen_hub/pipelines/opensora/LICENSE b/src/videogen_hub/pipelines/opensora/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..553df3ce196c99c4a17839c1275772e8fbbd2154 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/LICENSE @@ -0,0 +1,679 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + ========================================================================= + This project is inspired by the listed projects and is subject to the following licenses: + + 1. Latte (https://github.com/Vchitect/Latte/blob/main/LICENSE) + + Copyright 2024 Latte + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + 2. PixArt-alpha (https://github.com/PixArt-alpha/PixArt-alpha/blob/master/LICENSE) + + Copyright (C) 2024 PixArt-alpha/PixArt-alpha + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published + by the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . + + 3. dpm-solver (https://github.com/LuChengTHU/dpm-solver/blob/main/LICENSE) + + MIT License + + Copyright (c) 2022 Cheng Lu + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + + 4. DiT (https://github.com/facebookresearch/DiT/blob/main/LICENSE.txt) + + Attribution-NonCommercial 4.0 International + + ======================================================================= + + Creative Commons Corporation ("Creative Commons") is not a law firm and + does not provide legal services or legal advice. Distribution of + Creative Commons public licenses does not create a lawyer-client or + other relationship. Creative Commons makes its licenses and related + information available on an "as-is" basis. Creative Commons gives no + warranties regarding its licenses, any material licensed under their + terms and conditions, or any related information. Creative Commons + disclaims all liability for damages resulting from their use to the + fullest extent possible. + + Using Creative Commons Public Licenses + + Creative Commons public licenses provide a standard set of terms and + conditions that creators and other rights holders may use to share + original works of authorship and other material subject to copyright + and certain other rights specified in the public license below. The + following considerations are for informational purposes only, are not + exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. More_considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + + ======================================================================= + + Creative Commons Attribution-NonCommercial 4.0 International Public + License + + By exercising the Licensed Rights (defined below), You accept and agree + to be bound by the terms and conditions of this Creative Commons + Attribution-NonCommercial 4.0 International Public License ("Public + License"). To the extent this Public License may be interpreted as a + contract, You are granted the Licensed Rights in consideration of Your + acceptance of these terms and conditions, and the Licensor grants You + such rights in consideration of benefits the Licensor receives from + making the Licensed Material available under these terms and + conditions. + + Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + d. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + e. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + f. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + g. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + h. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + i. NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + j. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + k. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + l. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + + Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. produce, reproduce, and Share Adapted Material for + NonCommercial purposes only. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + + Section 3 -- License Conditions. + + Your exercise of the Licensed Rights is expressly made subject to the + following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + 4. If You Share Adapted Material You produce, the Adapter's + License You apply must not prevent recipients of the Adapted + Material from complying with this Public License. + + Section 4 -- Sui Generis Database Rights. + + Where the Licensed Rights include Sui Generis Database Rights that + apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material; and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + + For the avoidance of doubt, this Section 4 supplements and does not + replace Your obligations under this Public License where the Licensed + Rights include other Copyright and Similar Rights. + + Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + + Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + + Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + + Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + + ======================================================================= + + Creative Commons is not a party to its public + licenses. Notwithstanding, Creative Commons may elect to apply one of + its public licenses to material it publishes and in those instances + will be considered the “Licensor.” The text of the Creative Commons + public licenses is dedicated to the public domain under the CC0 Public + Domain Dedication. Except for the limited purpose of indicating that + material is shared under a Creative Commons public license or as + otherwise permitted by the Creative Commons policies published at + creativecommons.org/policies, Creative Commons does not authorize the + use of the trademark "Creative Commons" or any other trademark or logo + of Creative Commons without its prior written consent including, + without limitation, in connection with any unauthorized modifications + to any of its public licenses or any other arrangements, + understandings, or agreements concerning use of licensed material. For + the avoidance of doubt, this paragraph does not form part of the + public licenses. + + Creative Commons may be contacted at creativecommons.org. + + 5. OpenDiT (https://github.com/NUS-HPC-AI-Lab/OpenDiT/blob/master/LICENSE) + + Copyright OpenDiT + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/src/videogen_hub/pipelines/opensora/README.md b/src/videogen_hub/pipelines/opensora/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9c96bf1e3d5cd3b84748dd522a56b2bab10fb13b --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/README.md @@ -0,0 +1,382 @@ +

+ +

+
+ + + + + + + + +
+ +## Open-Sora: Democratizing Efficient Video Production for All + +We present **Open-Sora**, an initiative dedicated to **efficiently** produce high-quality video and make the model, +tools and contents accessible to all. By embracing **open-source** principles, +Open-Sora not only democratizes access to advanced video generation techniques, but also offers a +streamlined and user-friendly platform that simplifies the complexities of video production. +With Open-Sora, we aim to inspire innovation, creativity, and inclusivity in the realm of content creation. + +[[中文文档]](/docs/zh_CN/README.md) [[潞晨云部署视频教程]](https://www.bilibili.com/video/BV141421R7Ag) + +

Open-Sora is still at an early stage and under active development.

+ +## 📰 News + +* **[2024.04.25]** 🤗 We released the [Gradio demo for Open-Sora](https://huggingface.co/spaces/hpcai-tech/open-sora) on Hugging Face Spaces. +* **[2024.04.25]** 🔥 We released **Open-Sora 1.1**, which supports **2s~15s, 144p to 720p, any aspect ratio** text-to-image, **text-to-video, image-to-video, video-to-video, infinite time** generation. In addition, a full video processing pipeline is released. [[checkpoints]]() [[report]](/docs/report_02.md) +* **[2024.03.18]** We released **Open-Sora 1.0**, a fully open-source project for video generation. + Open-Sora 1.0 supports a full pipeline of video data preprocessing, training with + + acceleration, + inference, and more. Our model can produce 2s 512x512 videos with only 3 days training. [[checkpoints]](#open-sora-10-model-weights) + [[blog]](https://hpc-ai.com/blog/open-sora-v1.0) [[report]](docs/report_01.md) +* **[2024.03.04]** Open-Sora provides training with 46% cost reduction. + [[blog]](https://hpc-ai.com/blog/open-sora) + +## 🎥 Latest Demo + +🔥 You can experinece Open-Sora on our [🤗 Gradio application on Hugging Face](https://huggingface.co/spaces/hpcai-tech/open-sora). More samples are available in our [Gallery](https://hpcaitech.github.io/Open-Sora/). + +| **2s 240×426** | **2s 240×426** | +| ----------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------- | +| [](https://github.com/hpcaitech/Open-Sora-dev/assets/99191637/c31ebc52-de39-4a4e-9b1e-9211d45e05b2) | [](https://github.com/hpcaitech/Open-Sora-dev/assets/99191637/c31ebc52-de39-4a4e-9b1e-9211d45e05b2) | +| [](https://github.com/hpcaitech/Open-Sora-dev/assets/99191637/f7ce4aaa-528f-40a8-be7a-72e61eaacbbd) | [](https://github.com/hpcaitech/Open-Sora-dev/assets/99191637/5d58d71e-1fda-4d90-9ad3-5f2f7b75c6a9) | + +| **2s 426×240** | **4s 480×854** | +| ---------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------- | +| [](https://github.com/hpcaitech/Open-Sora-dev/assets/99191637/34ecb4a0-4eef-4286-ad4c-8e3a87e5a9fd) | [](https://github.com/hpcaitech/Open-Sora-dev/assets/99191637/c1619333-25d7-42ba-a91c-18dbc1870b18) | + +| **16s 320×320** | **16s 224×448** | **2s 426×240** | +| ------------------------------------------------------------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------- | +| [](https://github.com/hpcaitech/Open-Sora/assets/99191637/3cab536e-9b43-4b33-8da8-a0f9cf842ff2) | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/9fb0b9e0-c6f4-4935-b29e-4cac10b373c4) | [](https://github.com/hpcaitech/Open-Sora-dev/assets/99191637/3e892ad2-9543-4049-b005-643a4c1bf3bf) | +<<<<<<< Updated upstream +======= + +>>>>>>> Stashed changes + +
+OpenSora 1.0 Demo + +| **2s 512×512** | **2s 512×512** | **2s 512×512** | +| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------- | +| [](https://github.com/hpcaitech/Open-Sora/assets/99191637/de1963d3-b43b-4e68-a670-bb821ebb6f80) | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/13f8338f-3d42-4b71-8142-d234fbd746cc) | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/fa6a65a6-e32a-4d64-9a9e-eabb0ebb8c16) | +| A serene night scene in a forested area. [...] The video is a time-lapse, capturing the transition from day to night, with the lake and forest serving as a constant backdrop. | A soaring drone footage captures the majestic beauty of a coastal cliff, [...] The water gently laps at the rock base and the greenery that clings to the top of the cliff. | The majestic beauty of a waterfall cascading down a cliff into a serene lake. [...] The camera angle provides a bird's eye view of the waterfall. | +| [](https://github.com/hpcaitech/Open-Sora/assets/99191637/64232f84-1b36-4750-a6c0-3e610fa9aa94) | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/983a1965-a374-41a7-a76b-c07941a6c1e9) | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/ec10c879-9767-4c31-865f-2e8d6cf11e65) | +| A bustling city street at night, filled with the glow of car headlights and the ambient light of streetlights. [...] | The vibrant beauty of a sunflower field. The sunflowers are arranged in neat rows, creating a sense of order and symmetry. [...] | A serene underwater scene featuring a sea turtle swimming through a coral reef. The turtle, with its greenish-brown shell [...] | + +Videos are downsampled to `.gif` for display. Click for original videos. Prompts are trimmed for display, +see [here](/assets/texts/t2v_samples.txt) for full prompts. + +
+ +## 🔆 New Features/Updates + +* 📍 **Open-Sora 1.1** released. Model weights are available [here](). It is trained on **0s~15s, 144p to 720p, various aspect ratios** videos. See our **[report 1.1](docs/report_02.md)** for more discussions. +* 🔧 **Data processing pipeline v1.1** is released. An automatic [processing pipeline](#data-processing) from raw videos to (text, video clip) pairs is provided, including scene cutting $\rightarrow$ filtering(aesthetic, optical flow, OCR, etc.) $\rightarrow$ captioning $\rightarrow$ managing. With this tool, you can easily build your video dataset. +* ✅ Modified ST-DiT architecture includes rope positional encoding, qk norm, longer text length, etc. +* ✅ Support training with any resolution, aspect ratio, and duration (including images). +* ✅ Support image and video conditioning and video editing, and thus support animating images, connecting videos, etc. +* 📍 **Open-Sora 1.0** released. Model weights are available [here](#model-weights). With only 400K video clips and 200 H800 + days (compared with 152M samples in Stable Video Diffusion), we are able to generate 2s 512×512 videos. See our **[report 1.0](docs/report_01.md)** for more discussions. +* ✅ Three-stage training from an image diffusion model to a video diffusion model. We provide the weights for each + stage. +* ✅ Support training acceleration including accelerated transformer, faster T5 and VAE, and sequence parallelism. + Open-Sora improve **55%** training speed when training on 64x512x512 videos. Details locates + at [acceleration.md](docs/acceleration.md). +* 🔧 **Data preprocessing pipeline v1.0**, + including [downloading](/tools/datasets/README.md), [video cutting](/tools/scenedetect/README.md), + and [captioning](/tools/caption/README.md) tools. Our data collection plan can be found + at [datasets.md](docs/datasets.md). + +
+View more + +* ✅ We find VQ-VAE from [VideoGPT](https://wilson1yan.github.io/videogpt/index.html) has a low quality and thus adopt a + better VAE from [Stability-AI](https://huggingface.co/stabilityai/sd-vae-ft-mse-original). We also find patching in + the time dimension deteriorates the quality. See our **[report](docs/report_01.md)** for more discussions. +* ✅ We investigate different architectures including DiT, Latte, and our proposed STDiT. Our **STDiT** achieves a better + trade-off between quality and speed. See our **[report](docs/report_01.md)** for more discussions. +* ✅ Support clip and T5 text conditioning. +* ✅ By viewing images as one-frame videos, our project supports training DiT on both images and videos (e.g., ImageNet & + UCF101). See [commands.md](docs/commands.md) for more instructions. +* ✅ Support inference with official weights + from [DiT](https://github.com/facebookresearch/DiT), [Latte](https://github.com/Vchitect/Latte), + and [PixArt](https://pixart-alpha.github.io/). +* ✅ Refactor the codebase. See [structure.md](docs/structure.md) to learn the project structure and how to use the + config files. + +
+ +### TODO list sorted by priority + +* [ ] Training Video-VAE and adapt our model to new VAE. **[WIP]** +* [ ] Scaling model parameters and dataset size. **[WIP]** +* [ ] Incoporate a better scheduler, e.g., rectified flow in SD3. **[WIP]** + +
+View more + +* [x] Evaluation pipeline. +* [x] Complete the data processing pipeline (including dense optical flow, aesthetics scores, text-image similarity, etc.). +* [x] Support image and video conditioning. +* [x] Support variable aspect ratios, resolutions, durations. + +
+ +## Contents + +* [Installation](#installation) +* [Model Weights](#model-weights) +* [Inference](#inference) +* [Data Processing](#data-processing) +* [Training](#training) +* [Evaluation](#evaluation) +* [Contribution](#contribution) +* [Acknowledgement](#acknowledgement) + +Other useful documents and links are listed below. + +* Report: [report 1.1](docs/report_02.md), [report 1.0](docs/report_01.md), [acceleration.md](docs/acceleration.md) +* Repo structure: [structure.md](docs/structure.md) +* Config file explanation: [config.md](docs/config.md) +* Useful commands: [commands.md](docs/commands.md) +* Data processing pipeline and dataset: [datasets.md](docs/datasets.md) +* Each data processing tool's README: [dataset conventions and management](/tools/datasets/README.md), [scene cutting](/tools/scene_cut/README.md), [scoring](/tools/scoring/README.md), [caption](/tools/caption/README.md) +* Evaluation: [eval](/eval/README.md) +* Gallery: [gallery](https://hpcaitech.github.io/Open-Sora/) + +## Installation + +```bash +# create a virtual env +conda create -n opensora python=3.10 +# activate virtual environment +conda activate opensora + +# install torch +# the command below is for CUDA 12.1, choose install commands from +# https://pytorch.org/get-started/locally/ based on your own CUDA version +pip install torch torchvision + +# install flash attention (optional) +# set enable_flashattn=False in config to avoid using flash attention +pip install packaging ninja +pip install flash-attn --no-build-isolation + +# install apex (optional) +# set enable_layernorm_kernel=False in config to avoid using apex +pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" git+https://github.com/NVIDIA/apex.git + +# install xformers +pip install -U xformers --index-url https://download.pytorch.org/whl/cu121 + +# install this project +git clone https://github.com/hpcaitech/Open-Sora +cd Open-Sora +pip install -v . +``` + +## Model Weights + +### Open-Sora 1.1 Model Weights + +| Resolution | Model Size | Data | #iterations | Batch Size | URL | +| ------------------ | ---------- | -------------------------- | ----------- | ------------------------------------------------- | -------------------------------------------------------------------- | +| mainly 144p & 240p | 700M | 10M videos + 2M images | 100k | [dynamic](/configs/opensora-v1-1/train/stage2.py) | [:link:](https://huggingface.co/hpcai-tech/OpenSora-STDiT-v2-stage2) | +| 144p to 720p | 700M | 500K HQ videos + 1M images | 4k | [dynamic](/configs/opensora-v1-1/train/stage3.py) | [:link:](https://huggingface.co/hpcai-tech/OpenSora-STDiT-v2-stage3) | + +See our **[report 1.1](docs/report_02.md)** for more infomation. + +:warning: **LIMITATION**: This version contains known issues which we are going to fix in the next version (as we save computation resource for the next release). In addition, the video generation may fail for long duration, and high resolution will have noisy results due to this problem. + +### Open-Sora 1.0 Model Weights + +
+View more + +| Resolution | Model Size | Data | #iterations | Batch Size | GPU days (H800) | URL | +| ---------- | ---------- | ------ | ----------- | ---------- | --------------- | +| 16×512×512 | 700M | 20K HQ | 20k | 2×64 | 35 | [:link:](https://huggingface.co/hpcai-tech/Open-Sora/blob/main/OpenSora-v1-HQ-16x512x512.pth) | +| 16×256×256 | 700M | 20K HQ | 24k | 8×64 | 45 | [:link:](https://huggingface.co/hpcai-tech/Open-Sora/blob/main/OpenSora-v1-HQ-16x256x256.pth) | +| 16×256×256 | 700M | 366K | 80k | 8×64 | 117 | [:link:](https://huggingface.co/hpcai-tech/Open-Sora/blob/main/OpenSora-v1-16x256x256.pth) | + +Training orders: 16x256x256 $\rightarrow$ 16x256x256 HQ $\rightarrow$ 16x512x512 HQ. + +Our model's weight is partially initialized from [PixArt-α](https://github.com/PixArt-alpha/PixArt-alpha). The number of +parameters is 724M. More information about training can be found in our **[report](/docs/report_01.md)**. More about +the dataset can be found in [datasets.md](/docs/datasets.md). HQ means high quality. + +:warning: **LIMITATION**: Our model is trained on a limited budget. The quality and text alignment is relatively poor. +The model performs badly, especially on generating human beings and cannot follow detailed instructions. We are working +on improving the quality and text alignment. + +
+ +## Inference + +### Gradio Demo + +🔥 You can experinece Open-Sora on our [🤗 Gradio application](https://huggingface.co/spaces/hpcai-tech/open-sora) on Hugging Face online. + +If you want to deploy gradio locally, we have also provided a [Gradio application](./gradio) in this repository, you can use the following the command to start an interactive web application to experience video generation with Open-Sora. + +```bash +pip install gradio spaces +python gradio/app.py +``` + +This will launch a Gradio application on your localhost. If you want to know more about the Gradio applicaiton, you can refer to the [README file](./gradio/README.md). + +### Open-Sora 1.1 Command Line Inference + +Since Open-Sora 1.1 supports inference with dynamic input size, you can pass the input size as an argument. + +```bash +# text to video +python scripts/inference.py configs/opensora-v1-1/inference/sample.py \ + --ckpt-path CKPT_PATH --prompt "A beautiful sunset over the city" --num-frames 32 --image-size 480 854 +``` + +See [here](docs/commands.md#inference-with-open-sora-11) for more instructions including text-to-image, image-to-video, video-to-video, and infinite time generation. + +### Open-Sora 1.0 Command Line Inference + +
+View more + +We have also provided an offline inference script. Run the following commands to generate samples, the required model weights will be automatically downloaded. To change sampling prompts, modify the txt file passed to `--prompt-path`. See [here](docs/structure.md#inference-config-demos) to customize the configuration. + +```bash +# Sample 16x512x512 (20s/sample, 100 time steps, 24 GB memory) +torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/opensora/inference/16x512x512.py --ckpt-path OpenSora-v1-HQ-16x512x512.pth --prompt-path ./assets/texts/t2v_samples.txt + +# Sample 16x256x256 (5s/sample, 100 time steps, 22 GB memory) +torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/opensora/inference/16x256x256.py --ckpt-path OpenSora-v1-HQ-16x256x256.pth --prompt-path ./assets/texts/t2v_samples.txt + +# Sample 64x512x512 (40s/sample, 100 time steps) +torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/opensora/inference/64x512x512.py --ckpt-path ./path/to/your/ckpt.pth --prompt-path ./assets/texts/t2v_samples.txt + +# Sample 64x512x512 with sequence parallelism (30s/sample, 100 time steps) +# sequence parallelism is enabled automatically when nproc_per_node is larger than 1 +torchrun --standalone --nproc_per_node 2 scripts/inference.py configs/opensora/inference/64x512x512.py --ckpt-path ./path/to/your/ckpt.pth --prompt-path ./assets/texts/t2v_samples.txt +``` + +The speed is tested on H800 GPUs. For inference with other models, see [here](docs/commands.md) for more instructions. +To lower the memory usage, set a smaller `vae.micro_batch_size` in the config (slightly lower sampling speed). + +
+ +## Data Processing + +High-quality data is crucial for training good generation models. +To this end, we establish a complete pipeline for data processing, which could seamlessly convert raw videos to high-quality video-text pairs. +The pipeline is shown below. For detailed information, please refer to [data processing](docs/data_processing.md). +Also check out the [datasets](docs/datasets.md) we use. + +![Data Processing Pipeline](assets/readme/report_data_pipeline.png) + +## Training + +### Open-Sora 1.1 Training + +Once you prepare the data in a `csv` file, run the following commands to launch training on a single node. + +```bash +# one node +torchrun --standalone --nproc_per_node 8 scripts/train.py \ + configs/opensora-v1-1/train/stage1.py --data-path YOUR_CSV_PATH --ckpt-path YOUR_PRETRAINED_CKPT +# multiple nodes +colossalai run --nproc_per_node 8 --hostfile hostfile scripts/train.py \ + configs/opensora-v1-1/train/stage1.py --data-path YOUR_CSV_PATH --ckpt-path YOUR_PRETRAINED_CKPT +``` + +### Open-Sora 1.0 Training + +
+View more + +Once you prepare the data in a `csv` file, run the following commands to launch training on a single node. + +```bash +# 1 GPU, 16x256x256 +torchrun --nnodes=1 --nproc_per_node=1 scripts/train.py configs/opensora/train/16x256x256.py --data-path YOUR_CSV_PATH +# 8 GPUs, 64x512x512 +torchrun --nnodes=1 --nproc_per_node=8 scripts/train.py configs/opensora/train/64x512x512.py --data-path YOUR_CSV_PATH --ckpt-path YOUR_PRETRAINED_CKPT +``` + +To launch training on multiple nodes, prepare a hostfile according +to [ColossalAI](https://colossalai.org/docs/basics/launch_colossalai/#launch-with-colossal-ai-cli), and run the +following commands. + +```bash +colossalai run --nproc_per_node 8 --hostfile hostfile scripts/train.py configs/opensora/train/64x512x512.py --data-path YOUR_CSV_PATH --ckpt-path YOUR_PRETRAINED_CKPT +``` + +For training other models and advanced usage, see [here](docs/commands.md) for more instructions. + +
+ +## Evaluation + +See [here](eval/README.md) for more instructions. + +## Contribution + +Thanks goes to these wonderful contributors ([emoji key](https://allcontributors.org/docs/en/emoji-key) +following [all-contributors](https://github.com/all-contributors/all-contributors) specification): + + + + + + + + + + + + + + + + + + + + + +
zhengzangw
zhengzangw

💻 📖 🤔 📹 🚧
ver217
ver217

💻 🤔 📖 🐛
FrankLeeeee
FrankLeeeee

💻 🚇 🔧
xyupeng
xyupeng

💻 📖 🎨
Yanjia0
Yanjia0

📖
binmakeswell
binmakeswell

📖
eltociear
eltociear

📖
ganeshkrishnan1
ganeshkrishnan1

📖
fastalgo
fastalgo

📖
powerzbt
powerzbt

📖
+ + + + + + +If you wish to contribute to this project, you can refer to the [Contribution Guideline](./CONTRIBUTING.md). + +[Zangwei Zheng](https://github.com/zhengzangw) and [Xiangyu Peng](https://github.com/xyupeng) equally contributed to +this work during their internship at [HPC-AI Tech](https://hpc-ai.com/). + +## Acknowledgement + +* [ColossalAI](https://github.com/hpcaitech/ColossalAI): A powerful large model parallel acceleration and optimization + system. +* [DiT](https://github.com/facebookresearch/DiT): Scalable Diffusion Models with Transformers. +* [OpenDiT](https://github.com/NUS-HPC-AI-Lab/OpenDiT): An acceleration for DiT training. We adopt valuable acceleration + strategies for training progress from OpenDiT. +* [PixArt](https://github.com/PixArt-alpha/PixArt-alpha): An open-source DiT-based text-to-image model. +* [Latte](https://github.com/Vchitect/Latte): An attempt to efficiently train DiT for video. +* [StabilityAI VAE](https://huggingface.co/stabilityai/sd-vae-ft-mse-original): A powerful image VAE model. +* [CLIP](https://github.com/openai/CLIP): A powerful text-image embedding model. +* [T5](https://github.com/google-research/text-to-text-transfer-transformer): A powerful text encoder. +* [LLaVA](https://github.com/haotian-liu/LLaVA): A powerful image captioning model based on [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-v0.1) and [Yi-34B](https://huggingface.co/01-ai/Yi-34B). + +We are grateful for their exceptional work and generous contribution to open source. + +## Star History + +[![Star History Chart](https://api.star-history.com/svg?repos=hpcaitech/Open-Sora&type=Date)](https://star-history.com/#hpcaitech/Open-Sora&Date) diff --git a/src/videogen_hub/pipelines/opensora/__init__.py b/src/videogen_hub/pipelines/opensora/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9843deda4254e45c42702bca6cbc13817896ca6a --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/__init__.py @@ -0,0 +1,3 @@ +import sys + +sys.path.insert(0, "./src/videogen_hub/pipelines/opensora/") diff --git a/src/videogen_hub/pipelines/opensora/configs/__init__.py b/src/videogen_hub/pipelines/opensora/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora/configs/dit/__init__.py b/src/videogen_hub/pipelines/opensora/configs/dit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora/configs/dit/inference/16x256x256.py b/src/videogen_hub/pipelines/opensora/configs/dit/inference/16x256x256.py new file mode 100644 index 0000000000000000000000000000000000000000..44818fe095f5f16f960d5e7d0c7f974076aaeaa7 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/configs/dit/inference/16x256x256.py @@ -0,0 +1,31 @@ +num_frames = 16 +fps = 8 +image_size = (256, 256) + +# Define model +model = dict( + type="DiT-XL/2", + condition="text", + from_pretrained="PRETRAINED_MODEL", +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", +) +text_encoder = dict( + type="clip", + from_pretrained="openai/clip-vit-base-patch32", + model_max_length=77, +) +scheduler = dict( + type="dpm-solver", + num_sampling_steps=20, + cfg_scale=4.0, +) +dtype = "bf16" + +# Others +batch_size = 2 +seed = 42 +prompt_path = "./assets/texts/ucf101_labels.txt" +save_dir = "./samples/samples/" diff --git a/src/videogen_hub/pipelines/opensora/configs/dit/inference/1x256x256-class.py b/src/videogen_hub/pipelines/opensora/configs/dit/inference/1x256x256-class.py new file mode 100644 index 0000000000000000000000000000000000000000..bebaa11e286db0ea7968723909482e18f28a12c3 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/configs/dit/inference/1x256x256-class.py @@ -0,0 +1,31 @@ +num_frames = 1 +fps = 1 +image_size = (256, 256) + +# Define model +model = dict( + type="DiT-XL/2", + no_temporal_pos_emb=True, + condition="label_1000", + from_pretrained="DiT-XL-2-256x256.pt", +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", +) +text_encoder = dict( + type="classes", + num_classes=1000, +) +scheduler = dict( + type="dpm-solver", + num_sampling_steps=20, + cfg_scale=4.0, +) +dtype = "bf16" + +# Others +batch_size = 2 +seed = 42 +prompt_path = "./assets/texts/imagenet_id.txt" +save_dir = "./samples/samples/" diff --git a/src/videogen_hub/pipelines/opensora/configs/dit/inference/1x256x256.py b/src/videogen_hub/pipelines/opensora/configs/dit/inference/1x256x256.py new file mode 100644 index 0000000000000000000000000000000000000000..e7cb9a2d20e6ae3a19e468f493f0e125cbb0a33f --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/configs/dit/inference/1x256x256.py @@ -0,0 +1,32 @@ +num_frames = 1 +fps = 1 +image_size = (256, 256) + +# Define model +model = dict( + type="DiT-XL/2", + no_temporal_pos_emb=True, + condition="text", + from_pretrained="PRETRAINED_MODEL", +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", +) +text_encoder = dict( + type="clip", + from_pretrained="openai/clip-vit-base-patch32", + model_max_length=77, +) +scheduler = dict( + type="dpm-solver", + num_sampling_steps=20, + cfg_scale=4.0, +) +dtype = "bf16" + +# Others +batch_size = 2 +seed = 42 +prompt_path = "./assets/texts/imagenet_labels.txt" +save_dir = "./samples/samples/" diff --git a/src/videogen_hub/pipelines/opensora/configs/dit/inference/__init__.py b/src/videogen_hub/pipelines/opensora/configs/dit/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora/configs/latte/__init__.py b/src/videogen_hub/pipelines/opensora/configs/latte/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora/configs/latte/inference/16x256x256-class.py b/src/videogen_hub/pipelines/opensora/configs/latte/inference/16x256x256-class.py new file mode 100644 index 0000000000000000000000000000000000000000..8ccf6d43604240e724f0e78f2de3aefa85449277 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/configs/latte/inference/16x256x256-class.py @@ -0,0 +1,30 @@ +num_frames = 16 +fps = 8 +image_size = (256, 256) + +# Define model +model = dict( + type="Latte-XL/2", + condition="label_101", + from_pretrained="Latte-XL-2-256x256-ucf101.pt", +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", +) +text_encoder = dict( + type="classes", + num_classes=101, +) +scheduler = dict( + type="dpm-solver", + num_sampling_steps=20, + cfg_scale=4.0, +) +dtype = "bf16" + +# Others +batch_size = 2 +seed = 42 +prompt_path = "./assets/texts/ucf101_id.txt" +save_dir = "./samples/samples/" diff --git a/src/videogen_hub/pipelines/opensora/configs/latte/inference/16x256x256.py b/src/videogen_hub/pipelines/opensora/configs/latte/inference/16x256x256.py new file mode 100644 index 0000000000000000000000000000000000000000..6bdd58fad5f81bcca29c2d975fd2dd89a4bf7c58 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/configs/latte/inference/16x256x256.py @@ -0,0 +1,31 @@ +num_frames = 16 +fps = 8 +image_size = (256, 256) + +# Define model +model = dict( + type="Latte-XL/2", + condition="text", + from_pretrained="PRETRAINED_MODEL", +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", +) +text_encoder = dict( + type="clip", + from_pretrained="openai/clip-vit-base-patch32", + model_max_length=77, +) +scheduler = dict( + type="dpm-solver", + num_sampling_steps=20, + cfg_scale=4.0, +) +dtype = "bf16" + +# Others +batch_size = 2 +seed = 42 +prompt_path = "./assets/texts/ucf101_labels.txt" +save_dir = "./samples/samples/" diff --git a/src/videogen_hub/pipelines/opensora/configs/latte/inference/__init__.py b/src/videogen_hub/pipelines/opensora/configs/latte/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora/configs/opensora-v1-1/__init__.py b/src/videogen_hub/pipelines/opensora/configs/opensora-v1-1/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora/configs/opensora-v1-1/inference/__init__.py b/src/videogen_hub/pipelines/opensora/configs/opensora-v1-1/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora/configs/opensora-v1-1/inference/sample-ref.py b/src/videogen_hub/pipelines/opensora/configs/opensora-v1-1/inference/sample-ref.py new file mode 100644 index 0000000000000000000000000000000000000000..735c01baddca52af5134f656a5f93b6b3546ab9d --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/configs/opensora-v1-1/inference/sample-ref.py @@ -0,0 +1,70 @@ +num_frames = 16 +frame_interval = 3 +fps = 24 +image_size = (240, 426) +multi_resolution = "STDiT2" + +# Condition +prompt_path = None +prompt = [ + "A car driving on the ocean.", + 'Drone view of waves crashing against the rugged cliffs along Big Sur\'s garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff\'s edge. The steep drop from the road down to the beach is a dramatic feat, with the cliff\'s edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway.{"reference_path": "assets/images/condition/cliff.png", "mask_strategy": "0"}', + "In an ornate, historical hall, a massive tidal wave peaks and begins to crash. Two surfers, seizing the moment, skillfully navigate the face of the wave.", +] + +loop = 2 +condition_frame_length = 4 +# ( +# loop id, [the loop index of the condition image or video] +# reference id, [the index of the condition image or video in the reference_path] +# reference start, [the start frame of the condition image or video] +# target start, [the location to insert] +# length, [the number of frames to insert] +# edit_ratio [the edit rate of the condition image or video] +# ) +# See https://github.com/hpcaitech/Open-Sora/blob/main/docs/config.md#advanced-inference-config for more details +# See https://github.com/hpcaitech/Open-Sora/blob/main/docs/commands.md#inference-with-open-sora-11 for more examples +mask_strategy = [ + "0,0,0,0,8,0.3", + None, + "0", +] +reference_path = [ + "https://cdn.openai.com/tmp/s/interp/d0.mp4", + None, + "assets/images/condition/wave.png", +] + +# Define model +model = dict( + type="STDiT2-XL/2", + from_pretrained=None, + input_sq_size=512, + qk_norm=True, + enable_flashattn=True, + enable_layernorm_kernel=True, +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", + cache_dir=None, # "/mnt/hdd/cached_models", + micro_batch_size=4, +) +text_encoder = dict( + type="t5", + from_pretrained="DeepFloyd/t5-v1_1-xxl", + cache_dir=None, # "/mnt/hdd/cached_models", + model_max_length=200, +) +scheduler = dict( + type="iddpm", + num_sampling_steps=100, + cfg_scale=7.0, + cfg_channel=3, # or None +) +dtype = "bf16" + +# Others +batch_size = 1 +seed = 42 +save_dir = "./samples/samples/" diff --git a/src/videogen_hub/pipelines/opensora/configs/opensora-v1-1/inference/sample.py b/src/videogen_hub/pipelines/opensora/configs/opensora-v1-1/inference/sample.py new file mode 100644 index 0000000000000000000000000000000000000000..cec80736ef2019c256148b91b1563919ca09f7b0 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/configs/opensora-v1-1/inference/sample.py @@ -0,0 +1,43 @@ +num_frames = 16 +frame_interval = 3 +fps = 24 +image_size = (240, 426) +multi_resolution = "STDiT2" + +# Define model +model = dict( + type="STDiT2-XL/2", + from_pretrained=None, + input_sq_size=512, + qk_norm=True, + enable_flashattn=True, + enable_layernorm_kernel=True, +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", + cache_dir=None, # "/mnt/hdd/cached_models", + micro_batch_size=4, +) +text_encoder = dict( + type="t5", + from_pretrained="DeepFloyd/t5-v1_1-xxl", + cache_dir=None, # "/mnt/hdd/cached_models", + model_max_length=200, +) +scheduler = dict( + type="iddpm", + num_sampling_steps=100, + cfg_scale=7.0, + cfg_channel=3, # or None +) +dtype = "bf16" + +# Condition +prompt_path = "./assets/texts/t2v_samples.txt" +prompt = None # prompt has higher priority than prompt_path + +# Others +batch_size = 1 +seed = 42 +save_dir = "./samples/samples/" diff --git a/src/videogen_hub/pipelines/opensora/configs/opensora/__init__.py b/src/videogen_hub/pipelines/opensora/configs/opensora/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora/configs/opensora/inference/16x256x256.py b/src/videogen_hub/pipelines/opensora/configs/opensora/inference/16x256x256.py new file mode 100644 index 0000000000000000000000000000000000000000..50ead832a61c481632a821b330341505776a384e --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/configs/opensora/inference/16x256x256.py @@ -0,0 +1,39 @@ +num_frames = 16 +fps = 24 // 3 +image_size = (256, 256) + +# Define model +model = dict( + type="STDiT-XL/2", + space_scale=0.5, + time_scale=1.0, + enable_flashattn=True, + enable_layernorm_kernel=True, + from_pretrained="PRETRAINED_MODEL", +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", + micro_batch_size=4, +) +text_encoder = dict( + type="t5", + from_pretrained="DeepFloyd/t5-v1_1-xxl", + model_max_length=120, +) +scheduler = dict( + type="iddpm", + num_sampling_steps=100, + cfg_scale=7.0, + cfg_channel=3, # or None +) +dtype = "bf16" + +# Condition +prompt_path = "./assets/texts/t2v_samples.txt" +prompt = None # prompt has higher priority than prompt_path + +# Others +batch_size = 1 +seed = 42 +save_dir = "./samples/samples/" diff --git a/src/videogen_hub/pipelines/opensora/configs/opensora/inference/16x512x512.py b/src/videogen_hub/pipelines/opensora/configs/opensora/inference/16x512x512.py new file mode 100644 index 0000000000000000000000000000000000000000..58d82437b762fbf0935eb97e86665f4eda5329cb --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/configs/opensora/inference/16x512x512.py @@ -0,0 +1,35 @@ +num_frames = 16 +fps = 24 // 3 +image_size = (512, 512) + +# Define model +model = dict( + type="STDiT-XL/2", + space_scale=1.0, + time_scale=1.0, + enable_flashattn=True, + enable_layernorm_kernel=True, + from_pretrained="PRETRAINED_MODEL", +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", + micro_batch_size=2, +) +text_encoder = dict( + type="t5", + from_pretrained="DeepFloyd/t5-v1_1-xxl", + model_max_length=120, +) +scheduler = dict( + type="iddpm", + num_sampling_steps=100, + cfg_scale=7.0, +) +dtype = "bf16" + +# Others +batch_size = 2 +seed = 42 +prompt_path = "./assets/texts/t2v_samples.txt" +save_dir = "./samples/samples/" diff --git a/src/videogen_hub/pipelines/opensora/configs/opensora/inference/64x512x512.py b/src/videogen_hub/pipelines/opensora/configs/opensora/inference/64x512x512.py new file mode 100644 index 0000000000000000000000000000000000000000..dbbe2409823dbdb7e8628f705e64e9847172ddf2 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/configs/opensora/inference/64x512x512.py @@ -0,0 +1,35 @@ +num_frames = 64 +fps = 24 // 2 +image_size = (512, 512) + +# Define model +model = dict( + type="STDiT-XL/2", + space_scale=1.0, + time_scale=2 / 3, + enable_flashattn=True, + enable_layernorm_kernel=True, + from_pretrained="PRETRAINED_MODEL", +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", + micro_batch_size=128, +) +text_encoder = dict( + type="t5", + from_pretrained="DeepFloyd/t5-v1_1-xxl", + model_max_length=120, +) +scheduler = dict( + type="iddpm", + num_sampling_steps=100, + cfg_scale=7.0, +) +dtype = "bf16" + +# Others +batch_size = 1 +seed = 42 +prompt_path = "./assets/texts/t2v_samples.txt" +save_dir = "./samples/samples/" diff --git a/src/videogen_hub/pipelines/opensora/configs/opensora/inference/__init__.py b/src/videogen_hub/pipelines/opensora/configs/opensora/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora/configs/pixart/__init__.py b/src/videogen_hub/pipelines/opensora/configs/pixart/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora/configs/pixart/inference/16x256x256.py b/src/videogen_hub/pipelines/opensora/configs/pixart/inference/16x256x256.py new file mode 100644 index 0000000000000000000000000000000000000000..5013c08739f54e174ab9394353f6055cca409e96 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/configs/pixart/inference/16x256x256.py @@ -0,0 +1,32 @@ +num_frames = 16 +fps = 8 +image_size = (256, 256) + +# Define model +model = dict( + type="PixArt-XL/2", + space_scale=0.5, + time_scale=1.0, + from_pretrained="outputs/098-F16S3-PixArt-XL-2/epoch7-global_step30000/model_ckpt.pt", +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", +) +text_encoder = dict( + type="t5", + from_pretrained="DeepFloyd/t5-v1_1-xxl", + model_max_length=120, +) +scheduler = dict( + type="dpm-solver", + num_sampling_steps=20, + cfg_scale=7.0, +) +dtype = "bf16" + +# Others +batch_size = 2 +seed = 42 +prompt_path = "./assets/texts/t2v_samples.txt" +save_dir = "./samples/samples/" diff --git a/src/videogen_hub/pipelines/opensora/configs/pixart/inference/1x1024MS.py b/src/videogen_hub/pipelines/opensora/configs/pixart/inference/1x1024MS.py new file mode 100644 index 0000000000000000000000000000000000000000..e6af8c6773b2dde38be7203a98bfa2f59cde8901 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/configs/pixart/inference/1x1024MS.py @@ -0,0 +1,34 @@ +num_frames = 1 +fps = 1 +image_size = (1920, 512) +multi_resolution = "PixArtMS" + +# Define model +model = dict( + type="PixArtMS-XL/2", + space_scale=2.0, + time_scale=1.0, + no_temporal_pos_emb=True, + from_pretrained="PixArt-XL-2-1024-MS.pth", +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", +) +text_encoder = dict( + type="t5", + from_pretrained="DeepFloyd/t5-v1_1-xxl", + model_max_length=120, +) +scheduler = dict( + type="dpm-solver", + num_sampling_steps=20, + cfg_scale=7.0, +) +dtype = "bf16" + +# Others +batch_size = 2 +seed = 42 +prompt_path = "./assets/texts/t2i_samples.txt" +save_dir = "./samples/samples/" diff --git a/src/videogen_hub/pipelines/opensora/configs/pixart/inference/1x256x256.py b/src/videogen_hub/pipelines/opensora/configs/pixart/inference/1x256x256.py new file mode 100644 index 0000000000000000000000000000000000000000..16f92602b6fab414726aad3a2cd3b79b0ee5abed --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/configs/pixart/inference/1x256x256.py @@ -0,0 +1,33 @@ +num_frames = 1 +fps = 1 +image_size = (256, 256) + +# Define model +model = dict( + type="PixArt-XL/2", + space_scale=1.0, + time_scale=1.0, + no_temporal_pos_emb=True, + from_pretrained="PixArt-XL-2-256x256.pth", +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", +) +text_encoder = dict( + type="t5", + from_pretrained="DeepFloyd/t5-v1_1-xxl", + model_max_length=120, +) +scheduler = dict( + type="dpm-solver", + num_sampling_steps=20, + cfg_scale=7.0, +) +dtype = "bf16" + +# Others +batch_size = 2 +seed = 42 +prompt_path = "./assets/texts/t2i_samples.txt" +save_dir = "./samples/samples/" diff --git a/src/videogen_hub/pipelines/opensora/configs/pixart/inference/1x512x512.py b/src/videogen_hub/pipelines/opensora/configs/pixart/inference/1x512x512.py new file mode 100644 index 0000000000000000000000000000000000000000..dbc90df5f51bff532b3309cb9f7140b267a00945 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/configs/pixart/inference/1x512x512.py @@ -0,0 +1,39 @@ +num_frames = 1 +fps = 1 +image_size = (512, 512) + +# Define model +model = dict( + type="PixArt-XL/2", + space_scale=1.0, + time_scale=1.0, + no_temporal_pos_emb=True, + from_pretrained="PixArt-XL-2-512x512.pth", +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", +) +text_encoder = dict( + type="t5", + from_pretrained="DeepFloyd/t5-v1_1-xxl", + model_max_length=120, +) +scheduler = dict( + type="dpm-solver", + num_sampling_steps=20, + cfg_scale=7.0, +) +dtype = "bf16" + +# prompt_path = "./assets/texts/t2i_samples.txt" +prompt = [ + "Pirate ship trapped in a cosmic maelstrom nebula.", + "A small cactus with a happy face in the Sahara desert.", + "A small cactus with a sad face in the Sahara desert.", +] + +# Others +batch_size = 2 +seed = 42 +save_dir = "./samples/samples/" diff --git a/src/videogen_hub/pipelines/opensora/configs/pixart/inference/__init__.py b/src/videogen_hub/pipelines/opensora/configs/pixart/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora/opensora/__init__.py b/src/videogen_hub/pipelines/opensora/opensora/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora/opensora/acceleration/__init__.py b/src/videogen_hub/pipelines/opensora/opensora/acceleration/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora/opensora/acceleration/checkpoint.py b/src/videogen_hub/pipelines/opensora/opensora/acceleration/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..66ba530e2bffce699384a5f4f54e481a140be5ea --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/acceleration/checkpoint.py @@ -0,0 +1,24 @@ +from collections.abc import Iterable + +import torch.nn as nn +from torch.utils.checkpoint import checkpoint, checkpoint_sequential + + +def set_grad_checkpoint(model, use_fp32_attention=False, gc_step=1): + assert isinstance(model, nn.Module) + + def set_attr(module): + module.grad_checkpointing = True + module.fp32_attention = use_fp32_attention + module.grad_checkpointing_step = gc_step + + model.apply(set_attr) + + +def auto_grad_checkpoint(module, *args, **kwargs): + if getattr(module, "grad_checkpointing", False): + if not isinstance(module, Iterable): + return checkpoint(module, *args, use_reentrant=False, **kwargs) + gc_step = module[0].grad_checkpointing_step + return checkpoint_sequential(module, gc_step, *args, use_reentrant=False, **kwargs) + return module(*args, **kwargs) diff --git a/src/videogen_hub/pipelines/opensora/opensora/acceleration/communications.py b/src/videogen_hub/pipelines/opensora/opensora/acceleration/communications.py new file mode 100644 index 0000000000000000000000000000000000000000..d0900d20841248a250b5aeb31755fac689474ff8 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/acceleration/communications.py @@ -0,0 +1,188 @@ +import torch +import torch.distributed as dist + + +# ==================== +# All-To-All +# ==================== +def _all_to_all( + input_: torch.Tensor, + world_size: int, + group: dist.ProcessGroup, + scatter_dim: int, + gather_dim: int, +): + input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)] + output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)] + dist.all_to_all(output_list, input_list, group=group) + return torch.cat(output_list, dim=gather_dim).contiguous() + + +class _AllToAll(torch.autograd.Function): + """All-to-all communication. + + Args: + input_: input matrix + process_group: communication group + scatter_dim: scatter dimension + gather_dim: gather dimension + """ + + @staticmethod + def forward(ctx, input_, process_group, scatter_dim, gather_dim): + ctx.process_group = process_group + ctx.scatter_dim = scatter_dim + ctx.gather_dim = gather_dim + ctx.world_size = dist.get_world_size(process_group) + output = _all_to_all(input_, ctx.world_size, process_group, scatter_dim, gather_dim) + return output + + @staticmethod + def backward(ctx, grad_output): + grad_output = _all_to_all( + grad_output, + ctx.world_size, + ctx.process_group, + ctx.gather_dim, + ctx.scatter_dim, + ) + return ( + grad_output, + None, + None, + None, + ) + + +def all_to_all( + input_: torch.Tensor, + process_group: dist.ProcessGroup, + scatter_dim: int = 2, + gather_dim: int = 1, +): + return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim) + + +def _gather( + input_: torch.Tensor, + world_size: int, + group: dist.ProcessGroup, + gather_dim: int, +): + if gather_list is None: + gather_list = [torch.empty_like(input_) for _ in range(world_size)] + dist.gather(input_, gather_list, group=group, gather_dim=gather_dim) + return gather_list + + +# ==================== +# Gather-Split +# ==================== + + +def _split(input_, pg: dist.ProcessGroup, dim=-1): + # skip if only one rank involved + world_size = dist.get_world_size(pg) + rank = dist.get_rank(pg) + if world_size == 1: + return input_ + + # Split along last dimension. + dim_size = input_.size(dim) + assert dim_size % world_size == 0, ( + f"The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), " + f"cannot split tensor evenly" + ) + + tensor_list = torch.split(input_, dim_size // world_size, dim=dim) + output = tensor_list[rank].contiguous() + + return output + + +def _gather(input_, pg: dist.ProcessGroup, dim=-1): + # skip if only one rank involved + input_ = input_.contiguous() + world_size = dist.get_world_size(pg) + dist.get_rank(pg) + + if world_size == 1: + return input_ + + # all gather + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + assert input_.device.type == "cuda" + torch.distributed.all_gather(tensor_list, input_, group=pg) + + # concat + output = torch.cat(tensor_list, dim=dim).contiguous() + + return output + + +class _GatherForwardSplitBackward(torch.autograd.Function): + """Gather the input from model parallel region and concatenate. + + Args: + input_: input matrix. + process_group: parallel mode. + dim: dimension + """ + + @staticmethod + def symbolic(graph, input_): + return _gather(input_) + + @staticmethod + def forward(ctx, input_, process_group, dim, grad_scale): + ctx.mode = process_group + ctx.dim = dim + ctx.grad_scale = grad_scale + return _gather(input_, process_group, dim) + + @staticmethod + def backward(ctx, grad_output): + if ctx.grad_scale == "up": + grad_output = grad_output * dist.get_world_size(ctx.mode) + elif ctx.grad_scale == "down": + grad_output = grad_output / dist.get_world_size(ctx.mode) + + return _split(grad_output, ctx.mode, ctx.dim), None, None, None + + +class _SplitForwardGatherBackward(torch.autograd.Function): + """ + Split the input and keep only the corresponding chuck to the rank. + + Args: + input_: input matrix. + process_group: parallel mode. + dim: dimension + """ + + @staticmethod + def symbolic(graph, input_): + return _split(input_) + + @staticmethod + def forward(ctx, input_, process_group, dim, grad_scale): + ctx.mode = process_group + ctx.dim = dim + ctx.grad_scale = grad_scale + return _split(input_, process_group, dim) + + @staticmethod + def backward(ctx, grad_output): + if ctx.grad_scale == "up": + grad_output = grad_output * dist.get_world_size(ctx.mode) + elif ctx.grad_scale == "down": + grad_output = grad_output / dist.get_world_size(ctx.mode) + return _gather(grad_output, ctx.mode, ctx.dim), None, None, None + + +def split_forward_gather_backward(input_, process_group, dim, grad_scale=1.0): + return _SplitForwardGatherBackward.apply(input_, process_group, dim, grad_scale) + + +def gather_forward_split_backward(input_, process_group, dim, grad_scale=None): + return _GatherForwardSplitBackward.apply(input_, process_group, dim, grad_scale) diff --git a/src/videogen_hub/pipelines/opensora/opensora/acceleration/parallel_states.py b/src/videogen_hub/pipelines/opensora/opensora/acceleration/parallel_states.py new file mode 100644 index 0000000000000000000000000000000000000000..3c05cf137045690d47d350d7a4e33fe724b4071c --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/acceleration/parallel_states.py @@ -0,0 +1,19 @@ +import torch.distributed as dist + +_GLOBAL_PARALLEL_GROUPS = dict() + + +def set_data_parallel_group(group: dist.ProcessGroup): + _GLOBAL_PARALLEL_GROUPS["data"] = group + + +def get_data_parallel_group(): + return _GLOBAL_PARALLEL_GROUPS.get("data", dist.group.WORLD) + + +def set_sequence_parallel_group(group: dist.ProcessGroup): + _GLOBAL_PARALLEL_GROUPS["sequence"] = group + + +def get_sequence_parallel_group(): + return _GLOBAL_PARALLEL_GROUPS.get("sequence", None) diff --git a/src/videogen_hub/pipelines/opensora/opensora/acceleration/plugin.py b/src/videogen_hub/pipelines/opensora/opensora/acceleration/plugin.py new file mode 100644 index 0000000000000000000000000000000000000000..c657a9539d8fb1f0d65e8f452777a4bb73a84d4d --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/acceleration/plugin.py @@ -0,0 +1,100 @@ +import random +from typing import Optional + +import numpy as np +import torch +from colossalai.booster.plugin import LowLevelZeroPlugin +from colossalai.cluster import ProcessGroupMesh +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +DP_AXIS, SP_AXIS = 0, 1 + + +class ZeroSeqParallelPlugin(LowLevelZeroPlugin): + def __init__( + self, + sp_size: int = 1, + stage: int = 2, + precision: str = "fp16", + initial_scale: float = 2**32, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0.0, + norm_type: float = 2.0, + reduce_bucket_size_in_m: int = 12, + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = True, + cpu_offload: bool = False, + master_weights: bool = True, + verbose: bool = False, + ) -> None: + super().__init__( + stage=stage, + precision=precision, + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + max_norm=max_norm, + norm_type=norm_type, + reduce_bucket_size_in_m=reduce_bucket_size_in_m, + communication_dtype=communication_dtype, + overlap_communication=overlap_communication, + cpu_offload=cpu_offload, + master_weights=master_weights, + verbose=verbose, + ) + self.sp_size = sp_size + assert self.world_size % sp_size == 0, "world_size must be divisible by sp_size" + self.dp_size = self.world_size // sp_size + self.pg_mesh = ProcessGroupMesh(self.dp_size, self.sp_size) + self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) + self.sp_group = self.pg_mesh.get_group_along_axis(SP_AXIS) + self.dp_rank = self.pg_mesh.coordinate(DP_AXIS) + self.sp_rank = self.pg_mesh.coordinate(SP_AXIS) + + def __del__(self): + """Destroy the prcess groups in ProcessGroupMesh""" + self.pg_mesh.destroy_mesh_process_groups() + + def prepare_dataloader( + self, + dataset, + batch_size, + shuffle=False, + seed=1024, + drop_last=False, + pin_memory=False, + num_workers=0, + distributed_sampler_cls=None, + **kwargs, + ): + _kwargs = kwargs.copy() + distributed_sampler_cls = distributed_sampler_cls or DistributedSampler + sampler = distributed_sampler_cls(dataset, num_replicas=self.dp_size, rank=self.dp_rank, shuffle=shuffle) + + # Deterministic dataloader + def seed_worker(worker_id): + worker_seed = seed + np.random.seed(worker_seed) + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + return DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs, + ) diff --git a/src/videogen_hub/pipelines/opensora/opensora/acceleration/shardformer/__init__.py b/src/videogen_hub/pipelines/opensora/opensora/acceleration/shardformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora/opensora/acceleration/shardformer/modeling/__init__.py b/src/videogen_hub/pipelines/opensora/opensora/acceleration/shardformer/modeling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora/opensora/acceleration/shardformer/modeling/t5.py b/src/videogen_hub/pipelines/opensora/opensora/acceleration/shardformer/modeling/t5.py new file mode 100644 index 0000000000000000000000000000000000000000..9cfb80841c92a57628fba81425627053afc76a3b --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/acceleration/shardformer/modeling/t5.py @@ -0,0 +1,39 @@ +import torch +import torch.nn as nn + + +class T5LayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the T5 style. No bias and no subtraction of mean. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + @staticmethod + def from_native_module(module, *args, **kwargs): + assert module.__class__.__name__ == "FusedRMSNorm", ( + "Recovering T5LayerNorm requires the original layer to be apex's Fused RMS Norm." + "Apex's fused norm is automatically used by Hugging Face Transformers https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L265C5-L265C48" + ) + + layer_norm = T5LayerNorm(module.normalized_shape, eps=module.eps) + layer_norm.weight.data.copy_(module.weight.data) + layer_norm = layer_norm.to(module.weight.device) + return layer_norm diff --git a/src/videogen_hub/pipelines/opensora/opensora/acceleration/shardformer/policy/__init__.py b/src/videogen_hub/pipelines/opensora/opensora/acceleration/shardformer/policy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora/opensora/acceleration/shardformer/policy/t5_encoder.py b/src/videogen_hub/pipelines/opensora/opensora/acceleration/shardformer/policy/t5_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..c602b828ee0850362a7c478e7204208ce5857137 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/acceleration/shardformer/policy/t5_encoder.py @@ -0,0 +1,75 @@ +try: + from colossalai.shardformer.modeling.jit import get_jit_fused_dropout_add_func + from colossalai.shardformer.modeling.t5 import get_jit_fused_T5_layer_ff_forward, \ + get_T5_layer_self_attention_forward + from colossalai.shardformer.policies.base_policy import Policy, SubModuleReplacementDescription +except ImportError: + get_jit_fused_dropout_add_func = None + get_jit_fused_T5_layer_ff_forward = None + get_T5_layer_self_attention_forward = None + Policy = object + SubModuleReplacementDescription = object + + +class T5EncoderPolicy(Policy): + def config_sanity_check(self): + assert not self.shard_config.enable_tensor_parallelism + assert not self.shard_config.enable_flash_attention + + def preprocess(self): + return self.model + + def module_policy(self): + from transformers.models.t5.modeling_t5 import T5LayerFF, T5LayerSelfAttention, T5Stack + + policy = {} + + # check whether apex is installed + try: + from colossalai.acceleration.shardformer.modeling.t5 import T5LayerNorm + + # recover hf from fused rms norm to T5 norm which is faster + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="layer_norm", + target_module=T5LayerNorm, + ), + policy=policy, + target_key=T5LayerFF, + ) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription(suffix="layer_norm", target_module=T5LayerNorm), + policy=policy, + target_key=T5LayerSelfAttention, + ) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription(suffix="final_layer_norm", target_module=T5LayerNorm), + policy=policy, + target_key=T5Stack, + ) + except (ImportError, ModuleNotFoundError): + pass + + # use jit operator + if self.shard_config.enable_jit_fused and get_jit_fused_T5_layer_ff_forward and get_jit_fused_dropout_add_func: + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_T5_layer_ff_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=T5LayerFF, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_T5_layer_self_attention_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=T5LayerSelfAttention, + ) + + return policy + + def postprocess(self): + return self.model diff --git a/src/videogen_hub/pipelines/opensora/opensora/datasets/__init__.py b/src/videogen_hub/pipelines/opensora/opensora/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..32f210ff4e2032a502ec1c2706c2c6b47a77186f --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/datasets/__init__.py @@ -0,0 +1,2 @@ +from videogen_hub.pipelines.opensora.opensora.datasets.datasets import IMG_FPS, BatchFeatureDataset, VariableVideoTextDataset, VideoTextDataset +from videogen_hub.pipelines.opensora.opensora.datasets.utils import get_transforms_image, get_transforms_video, is_img, is_vid, save_sample diff --git a/src/videogen_hub/pipelines/opensora/opensora/datasets/aspect.py b/src/videogen_hub/pipelines/opensora/opensora/datasets/aspect.py new file mode 100644 index 0000000000000000000000000000000000000000..44c7839e243a545360b6318305482f09eb81afec --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/datasets/aspect.py @@ -0,0 +1,492 @@ +import math + + +# computation +def get_h_w(a, ts, eps=1e-4): + h = (ts * a) ** 0.5 + h = h + eps + h = math.ceil(h) if math.ceil(h) % 2 == 0 else math.floor(h) + w = h / a + w = w + eps + w = math.ceil(w) if math.ceil(w) % 2 == 0 else math.floor(w) + return h, w + + +def get_aspect_ratios_dict(ars, ts=360 * 640): + est = {f"{a:.2f}": get_h_w(a, ts) for a in ars} + return est + + +def get_ar(ratio): + h, w = ratio.split(":") + return int(h) / int(w) + + +# H:W +ASPECT_RATIO_MAP = { + "3:8": "0.38", + "9:21": "0.43", + "12:25": "0.48", + "1:2": "0.50", + "9:17": "0.53", + "27:50": "0.54", + "9:16": "0.56", + "5:8": "0.62", + "2:3": "0.67", + "3:4": "0.75", + "1:1": "1.00", + "4:3": "1.33", + "3:2": "1.50", + "16:9": "1.78", + "17:9": "1.89", + "2:1": "2.00", + "50:27": "2.08", +} + + +AR = [get_ar(ratio) for ratio in ASPECT_RATIO_MAP.keys()] + +# computed from above code +# S = 8294400 +ASPECT_RATIO_4K = { + "0.38": (1764, 4704), + "0.43": (1886, 4400), + "0.48": (1996, 4158), + "0.50": (2036, 4072), + "0.53": (2096, 3960), + "0.54": (2118, 3918), + "0.62": (2276, 3642), + "0.56": (2160, 3840), # base + "0.67": (2352, 3528), + "0.75": (2494, 3326), + "1.00": (2880, 2880), + "1.33": (3326, 2494), + "1.50": (3528, 2352), + "1.78": (3840, 2160), + "1.89": (3958, 2096), + "2.00": (4072, 2036), + "2.08": (4156, 1994), +} + +# S = 3686400 +ASPECT_RATIO_2K = { + "0.38": (1176, 3136), + "0.43": (1256, 2930), + "0.48": (1330, 2770), + "0.50": (1358, 2716), + "0.53": (1398, 2640), + "0.54": (1412, 2612), + "0.56": (1440, 2560), # base + "0.62": (1518, 2428), + "0.67": (1568, 2352), + "0.75": (1662, 2216), + "1.00": (1920, 1920), + "1.33": (2218, 1664), + "1.50": (2352, 1568), + "1.78": (2560, 1440), + "1.89": (2638, 1396), + "2.00": (2716, 1358), + "2.08": (2772, 1330), +} + +# S = 2073600 +ASPECT_RATIO_1080P = { + "0.38": (882, 2352), + "0.43": (942, 2198), + "0.48": (998, 2080), + "0.50": (1018, 2036), + "0.53": (1048, 1980), + "0.54": (1058, 1958), + "0.56": (1080, 1920), # base + "0.62": (1138, 1820), + "0.67": (1176, 1764), + "0.75": (1248, 1664), + "1.00": (1440, 1440), + "1.33": (1662, 1246), + "1.50": (1764, 1176), + "1.78": (1920, 1080), + "1.89": (1980, 1048), + "2.00": (2036, 1018), + "2.08": (2078, 998), +} + +# S = 921600 +ASPECT_RATIO_720P = { + "0.38": (588, 1568), + "0.43": (628, 1466), + "0.48": (666, 1388), + "0.50": (678, 1356), + "0.53": (698, 1318), + "0.54": (706, 1306), + "0.56": (720, 1280), # base + "0.62": (758, 1212), + "0.67": (784, 1176), + "0.75": (832, 1110), + "1.00": (960, 960), + "1.33": (1108, 832), + "1.50": (1176, 784), + "1.78": (1280, 720), + "1.89": (1320, 698), + "2.00": (1358, 680), + "2.08": (1386, 666), +} + +# S = 409920 +ASPECT_RATIO_480P = { + "0.38": (392, 1046), + "0.43": (420, 980), + "0.48": (444, 925), + "0.50": (452, 904), + "0.53": (466, 880), + "0.54": (470, 870), + "0.56": (480, 854), # base + "0.62": (506, 810), + "0.67": (522, 784), + "0.75": (554, 738), + "1.00": (640, 640), + "1.33": (740, 555), + "1.50": (784, 522), + "1.78": (854, 480), + "1.89": (880, 466), + "2.00": (906, 454), + "2.08": (924, 444), +} + +# S = 230400 +ASPECT_RATIO_360P = { + "0.38": (294, 784), + "0.43": (314, 732), + "0.48": (332, 692), + "0.50": (340, 680), + "0.53": (350, 662), + "0.54": (352, 652), + "0.56": (360, 640), # base + "0.62": (380, 608), + "0.67": (392, 588), + "0.75": (416, 554), + "1.00": (480, 480), + "1.33": (554, 416), + "1.50": (588, 392), + "1.78": (640, 360), + "1.89": (660, 350), + "2.00": (678, 340), + "2.08": (692, 332), +} + +# S = 102240 +ASPECT_RATIO_240P = { + "0.38": (196, 522), + "0.43": (210, 490), + "0.48": (222, 462), + "0.50": (226, 452), + "0.53": (232, 438), + "0.54": (236, 436), + "0.56": (240, 426), # base + "0.62": (252, 404), + "0.67": (262, 393), + "0.75": (276, 368), + "1.00": (320, 320), + "1.33": (370, 278), + "1.50": (392, 262), + "1.78": (426, 240), + "1.89": (440, 232), + "2.00": (452, 226), + "2.08": (462, 222), +} + +# S = 36864 +ASPECT_RATIO_144P = { + "0.38": (117, 312), + "0.43": (125, 291), + "0.48": (133, 277), + "0.50": (135, 270), + "0.53": (139, 262), + "0.54": (141, 260), + "0.56": (144, 256), # base + "0.62": (151, 241), + "0.67": (156, 234), + "0.75": (166, 221), + "1.00": (192, 192), + "1.33": (221, 165), + "1.50": (235, 156), + "1.78": (256, 144), + "1.89": (263, 139), + "2.00": (271, 135), + "2.08": (277, 132), +} + +# from PixArt +# S = 8294400 +ASPECT_RATIO_2880 = { + "0.25": (1408, 5760), + "0.26": (1408, 5568), + "0.27": (1408, 5376), + "0.28": (1408, 5184), + "0.32": (1600, 4992), + "0.33": (1600, 4800), + "0.34": (1600, 4672), + "0.4": (1792, 4480), + "0.42": (1792, 4288), + "0.47": (1920, 4096), + "0.49": (1920, 3904), + "0.51": (1920, 3776), + "0.55": (2112, 3840), + "0.59": (2112, 3584), + "0.68": (2304, 3392), + "0.72": (2304, 3200), + "0.78": (2496, 3200), + "0.83": (2496, 3008), + "0.89": (2688, 3008), + "0.93": (2688, 2880), + "1.0": (2880, 2880), + "1.07": (2880, 2688), + "1.12": (3008, 2688), + "1.21": (3008, 2496), + "1.28": (3200, 2496), + "1.39": (3200, 2304), + "1.47": (3392, 2304), + "1.7": (3584, 2112), + "1.82": (3840, 2112), + "2.03": (3904, 1920), + "2.13": (4096, 1920), + "2.39": (4288, 1792), + "2.5": (4480, 1792), + "2.92": (4672, 1600), + "3.0": (4800, 1600), + "3.12": (4992, 1600), + "3.68": (5184, 1408), + "3.82": (5376, 1408), + "3.95": (5568, 1408), + "4.0": (5760, 1408), +} + +# S = 4194304 +ASPECT_RATIO_2048 = { + "0.25": (1024, 4096), + "0.26": (1024, 3968), + "0.27": (1024, 3840), + "0.28": (1024, 3712), + "0.32": (1152, 3584), + "0.33": (1152, 3456), + "0.35": (1152, 3328), + "0.4": (1280, 3200), + "0.42": (1280, 3072), + "0.48": (1408, 2944), + "0.5": (1408, 2816), + "0.52": (1408, 2688), + "0.57": (1536, 2688), + "0.6": (1536, 2560), + "0.68": (1664, 2432), + "0.72": (1664, 2304), + "0.78": (1792, 2304), + "0.82": (1792, 2176), + "0.88": (1920, 2176), + "0.94": (1920, 2048), + "1.0": (2048, 2048), + "1.07": (2048, 1920), + "1.13": (2176, 1920), + "1.21": (2176, 1792), + "1.29": (2304, 1792), + "1.38": (2304, 1664), + "1.46": (2432, 1664), + "1.67": (2560, 1536), + "1.75": (2688, 1536), + "2.0": (2816, 1408), + "2.09": (2944, 1408), + "2.4": (3072, 1280), + "2.5": (3200, 1280), + "2.89": (3328, 1152), + "3.0": (3456, 1152), + "3.11": (3584, 1152), + "3.62": (3712, 1024), + "3.75": (3840, 1024), + "3.88": (3968, 1024), + "4.0": (4096, 1024), +} + +# S = 1048576 +ASPECT_RATIO_1024 = { + "0.25": (512, 2048), + "0.26": (512, 1984), + "0.27": (512, 1920), + "0.28": (512, 1856), + "0.32": (576, 1792), + "0.33": (576, 1728), + "0.35": (576, 1664), + "0.4": (640, 1600), + "0.42": (640, 1536), + "0.48": (704, 1472), + "0.5": (704, 1408), + "0.52": (704, 1344), + "0.57": (768, 1344), + "0.6": (768, 1280), + "0.68": (832, 1216), + "0.72": (832, 1152), + "0.78": (896, 1152), + "0.82": (896, 1088), + "0.88": (960, 1088), + "0.94": (960, 1024), + "1.0": (1024, 1024), + "1.07": (1024, 960), + "1.13": (1088, 960), + "1.21": (1088, 896), + "1.29": (1152, 896), + "1.38": (1152, 832), + "1.46": (1216, 832), + "1.67": (1280, 768), + "1.75": (1344, 768), + "2.0": (1408, 704), + "2.09": (1472, 704), + "2.4": (1536, 640), + "2.5": (1600, 640), + "2.89": (1664, 576), + "3.0": (1728, 576), + "3.11": (1792, 576), + "3.62": (1856, 512), + "3.75": (1920, 512), + "3.88": (1984, 512), + "4.0": (2048, 512), +} + +# S = 262144 +ASPECT_RATIO_512 = { + "0.25": (256, 1024), + "0.26": (256, 992), + "0.27": (256, 960), + "0.28": (256, 928), + "0.32": (288, 896), + "0.33": (288, 864), + "0.35": (288, 832), + "0.4": (320, 800), + "0.42": (320, 768), + "0.48": (352, 736), + "0.5": (352, 704), + "0.52": (352, 672), + "0.57": (384, 672), + "0.6": (384, 640), + "0.68": (416, 608), + "0.72": (416, 576), + "0.78": (448, 576), + "0.82": (448, 544), + "0.88": (480, 544), + "0.94": (480, 512), + "1.0": (512, 512), + "1.07": (512, 480), + "1.13": (544, 480), + "1.21": (544, 448), + "1.29": (576, 448), + "1.38": (576, 416), + "1.46": (608, 416), + "1.67": (640, 384), + "1.75": (672, 384), + "2.0": (704, 352), + "2.09": (736, 352), + "2.4": (768, 320), + "2.5": (800, 320), + "2.89": (832, 288), + "3.0": (864, 288), + "3.11": (896, 288), + "3.62": (928, 256), + "3.75": (960, 256), + "3.88": (992, 256), + "4.0": (1024, 256), +} + +# S = 65536 +ASPECT_RATIO_256 = { + "0.25": (128, 512), + "0.26": (128, 496), + "0.27": (128, 480), + "0.28": (128, 464), + "0.32": (144, 448), + "0.33": (144, 432), + "0.35": (144, 416), + "0.4": (160, 400), + "0.42": (160, 384), + "0.48": (176, 368), + "0.5": (176, 352), + "0.52": (176, 336), + "0.57": (192, 336), + "0.6": (192, 320), + "0.68": (208, 304), + "0.72": (208, 288), + "0.78": (224, 288), + "0.82": (224, 272), + "0.88": (240, 272), + "0.94": (240, 256), + "1.0": (256, 256), + "1.07": (256, 240), + "1.13": (272, 240), + "1.21": (272, 224), + "1.29": (288, 224), + "1.38": (288, 208), + "1.46": (304, 208), + "1.67": (320, 192), + "1.75": (336, 192), + "2.0": (352, 176), + "2.09": (368, 176), + "2.4": (384, 160), + "2.5": (400, 160), + "2.89": (416, 144), + "3.0": (432, 144), + "3.11": (448, 144), + "3.62": (464, 128), + "3.75": (480, 128), + "3.88": (496, 128), + "4.0": (512, 128), +} + + +def get_closest_ratio(height: float, width: float, ratios: dict): + aspect_ratio = height / width + closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio)) + return closest_ratio + + +ASPECT_RATIOS = { + "144p": (36864, ASPECT_RATIO_144P), + "256": (65536, ASPECT_RATIO_256), + "240p": (102240, ASPECT_RATIO_240P), + "360p": (230400, ASPECT_RATIO_360P), + "512": (262144, ASPECT_RATIO_512), + "480p": (409920, ASPECT_RATIO_480P), + "720p": (921600, ASPECT_RATIO_720P), + "1024": (1048576, ASPECT_RATIO_1024), + "1080p": (2073600, ASPECT_RATIO_1080P), + "2k": (3686400, ASPECT_RATIO_2K), + "2048": (4194304, ASPECT_RATIO_2048), + "2880": (8294400, ASPECT_RATIO_2880), + "4k": (8294400, ASPECT_RATIO_4K), +} + + +def get_num_pixels(name): + return ASPECT_RATIOS[name][0] + + +def get_image_size(resolution, ar_ratio): + ar_key = ASPECT_RATIO_MAP[ar_ratio] + rs_dict = ASPECT_RATIOS[resolution][1] + assert ar_key in rs_dict, f"Aspect ratio {ar_ratio} not found for resolution {resolution}" + return rs_dict[ar_key] + + +NUM_FRAMES_MAP = { + "1x": 51, + "2x": 102, + "4x": 204, + "8x": 408, + "16x": 816, + "2s": 51, + "4s": 102, + "8s": 204, + "16s": 408, + "32s": 816, +} + + +def get_num_frames(num_frames): + if num_frames in NUM_FRAMES_MAP: + return NUM_FRAMES_MAP[num_frames] + else: + return int(num_frames) diff --git a/src/videogen_hub/pipelines/opensora/opensora/datasets/bucket.py b/src/videogen_hub/pipelines/opensora/opensora/datasets/bucket.py new file mode 100644 index 0000000000000000000000000000000000000000..e5b5ec5dfb2965afe6bb1b9c059ec008d1194625 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/datasets/bucket.py @@ -0,0 +1,142 @@ +from collections import OrderedDict + +import numpy as np + + +from .aspect import ASPECT_RATIOS, get_closest_ratio +from ..utils.misc import get_logger + + +def find_approximate_hw(hw, hw_dict, approx=0.8): + for k, v in hw_dict.items(): + if hw >= v * approx: + return k + return None + + +def find_closet_smaller_bucket(t, t_dict, frame_interval): + # process image + if t == 1: + if 1 in t_dict: + return 1 + else: + return None + # process video + for k, v in t_dict.items(): + if t >= v * frame_interval and v != 1: + return k + return None + + +class Bucket: + def __init__(self, bucket_config): + for key in bucket_config: + assert key in ASPECT_RATIOS, f"Aspect ratio {key} not found." + # wrap config with OrderedDict + bucket_probs = OrderedDict() + bucket_bs = OrderedDict() + bucket_names = sorted(bucket_config.keys(), key=lambda x: ASPECT_RATIOS[x][0], reverse=True) + for key in bucket_names: + bucket_time_names = sorted(bucket_config[key].keys(), key=lambda x: x, reverse=True) + bucket_probs[key] = OrderedDict({k: bucket_config[key][k][0] for k in bucket_time_names}) + bucket_bs[key] = OrderedDict({k: bucket_config[key][k][1] for k in bucket_time_names}) + + # first level: HW + num_bucket = 0 + hw_criteria = dict() + t_criteria = dict() + ar_criteria = dict() + bucket_id = OrderedDict() + bucket_id_cnt = 0 + for k1, v1 in bucket_probs.items(): + hw_criteria[k1] = ASPECT_RATIOS[k1][0] + t_criteria[k1] = dict() + ar_criteria[k1] = dict() + bucket_id[k1] = dict() + for k2, _ in v1.items(): + t_criteria[k1][k2] = k2 + bucket_id[k1][k2] = bucket_id_cnt + bucket_id_cnt += 1 + ar_criteria[k1][k2] = dict() + for k3, v3 in ASPECT_RATIOS[k1][1].items(): + ar_criteria[k1][k2][k3] = v3 + num_bucket += 1 + + self.bucket_probs = bucket_probs + self.bucket_bs = bucket_bs + self.bucket_id = bucket_id + self.hw_criteria = hw_criteria + self.t_criteria = t_criteria + self.ar_criteria = ar_criteria + self.num_bucket = num_bucket + get_logger().info("Number of buckets: %s", num_bucket) + + def get_bucket_id(self, T, H, W, frame_interval=1, seed=None): + resolution = H * W + approx = 0.8 + + fail = True + for hw_id, t_criteria in self.bucket_probs.items(): + if resolution < self.hw_criteria[hw_id] * approx: + continue + + # if sample is an image + if T == 1: + if 1 in t_criteria: + rng = np.random.default_rng(seed + self.bucket_id[hw_id][1]) + if rng.random() < t_criteria[1]: + fail = False + t_id = 1 + break + else: + continue + + # otherwise, find suitable t_id for video + t_fail = True + for t_id, prob in t_criteria.items(): + rng = np.random.default_rng(seed + self.bucket_id[hw_id][t_id]) + if isinstance(prob, tuple): + prob_t = prob[1] + if rng.random() > prob_t: + continue + if T > t_id * frame_interval and t_id != 1: + t_fail = False + break + if t_fail: + continue + + # leave the loop if prob is high enough + if isinstance(prob, tuple): + prob = prob[0] + if prob >= 1 or rng.random() < prob: + fail = False + break + if fail: + return None + + # get aspect ratio id + ar_criteria = self.ar_criteria[hw_id][t_id] + ar_id = get_closest_ratio(H, W, ar_criteria) + return hw_id, t_id, ar_id + + def get_thw(self, bucket_id): + assert len(bucket_id) == 3 + T = self.t_criteria[bucket_id[0]][bucket_id[1]] + H, W = self.ar_criteria[bucket_id[0]][bucket_id[1]][bucket_id[2]] + return T, H, W + + def get_prob(self, bucket_id): + return self.bucket_probs[bucket_id[0]][bucket_id[1]] + + def get_batch_size(self, bucket_id): + return self.bucket_bs[bucket_id[0]][bucket_id[1]] + + def __len__(self): + return self.num_bucket + + +def closet_smaller_bucket(value, bucket): + for i in range(1, len(bucket)): + if value < bucket[i]: + return bucket[i - 1] + return bucket[-1] diff --git a/src/videogen_hub/pipelines/opensora/opensora/datasets/dataloader.py b/src/videogen_hub/pipelines/opensora/opensora/datasets/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..a8fa1f9f6458b51362bd88db08e269ad15996841 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/datasets/dataloader.py @@ -0,0 +1,145 @@ +import collections +import random +from typing import Optional + +import numpy as np +import torch +from torch.distributed import ProcessGroup +from torch.distributed.distributed_c10d import _get_default_group +from torch.utils.data import DataLoader + +from videogen_hub.pipelines.opensora.opensora.datasets.datasets import BatchFeatureDataset, VariableVideoTextDataset, VideoTextDataset +from videogen_hub.pipelines.opensora.opensora.datasets.sampler import BatchDistributedSampler, StatefulDistributedSampler, VariableVideoBatchSampler + + +# Deterministic dataloader +def get_seed_worker(seed): + def seed_worker(worker_id): + worker_seed = seed + np.random.seed(worker_seed) + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + return seed_worker + + +def prepare_dataloader( + dataset, + batch_size=None, + shuffle=False, + seed=1024, + drop_last=False, + pin_memory=False, + num_workers=0, + process_group: Optional[ProcessGroup] = None, + bucket_config=None, + num_bucket_build_workers=1, + **kwargs, +): + _kwargs = kwargs.copy() + if isinstance(dataset, VariableVideoTextDataset): + batch_sampler = VariableVideoBatchSampler( + dataset, + bucket_config, + num_replicas=process_group.size(), + rank=process_group.rank(), + shuffle=shuffle, + seed=seed, + drop_last=drop_last, + verbose=True, + num_bucket_build_workers=num_bucket_build_workers, + ) + return ( + DataLoader( + dataset, + batch_sampler=batch_sampler, + worker_init_fn=get_seed_worker(seed), + pin_memory=pin_memory, + num_workers=num_workers, + collate_fn=collate_fn_default, + **_kwargs, + ), + batch_sampler, + ) + elif isinstance(dataset, VideoTextDataset): + process_group = process_group or _get_default_group() + sampler = StatefulDistributedSampler( + dataset, + num_replicas=process_group.size(), + rank=process_group.rank(), + shuffle=shuffle, + ) + return ( + DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + worker_init_fn=get_seed_worker(seed), + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + collate_fn=collate_fn_default, + **_kwargs, + ), + sampler, + ) + elif isinstance(dataset, BatchFeatureDataset): + sampler = BatchDistributedSampler( + dataset, + num_replicas=process_group.size(), + rank=process_group.rank(), + ) + return ( + DataLoader( + dataset, + batch_size=1, + sampler=sampler, + worker_init_fn=get_seed_worker(seed), + pin_memory=pin_memory, + num_workers=num_workers, + collate_fn=collate_fn_batch, + **_kwargs, + ), + sampler, + ) + else: + raise ValueError(f"Unsupported dataset type: {type(dataset)}") + + +def collate_fn_default(batch): + # HACK: for loading text features + use_mask = False + if "mask" in batch[0] and isinstance(batch[0]["mask"], int): + masks = [x.pop("mask") for x in batch] + + texts = [x.pop("text") for x in batch] + texts = torch.cat(texts, dim=1) + use_mask = True + + ret = torch.utils.data.default_collate(batch) + + if use_mask: + ret["mask"] = masks + ret["text"] = texts + return ret + + +def collate_fn_batch(batch): + """ + Used only with BatchDistributedSampler + """ + res = torch.utils.data.default_collate(batch) + + # squeeze the first dimension, which is due to torch.stack() in default_collate() + if isinstance(res, collections.abc.Mapping): + for k, v in res.items(): + if isinstance(v, torch.Tensor): + res[k] = v.squeeze(0) + elif isinstance(res, collections.abc.Sequence): + res = [x.squeeze(0) if isinstance(x, torch.Tensor) else x for x in res] + elif isinstance(res, torch.Tensor): + res = res.squeeze(0) + else: + raise TypeError + + return res diff --git a/src/videogen_hub/pipelines/opensora/opensora/datasets/datasets.py b/src/videogen_hub/pipelines/opensora/opensora/datasets/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..47ca62243c728e0d8d1eddcffface3ede84319f3 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/datasets/datasets.py @@ -0,0 +1,245 @@ +import os +from glob import glob + +import numpy as np +import torch +from PIL import ImageFile +from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader + + +from videogen_hub.pipelines.opensora.opensora.datasets.read_video import read_video +from videogen_hub.pipelines.opensora.opensora.datasets.utils import VID_EXTENSIONS, get_transforms_image, get_transforms_video, read_file, temporal_random_crop +from videogen_hub.pipelines.opensora.opensora.registry import DATASETS + +ImageFile.LOAD_TRUNCATED_IMAGES = True +IMG_FPS = 120 + + +@DATASETS.register_module() +class VideoTextDataset(torch.utils.data.Dataset): + """load video according to the csv file. + + Args: + target_video_len (int): the number of video frames will be load. + align_transform (callable): Align different videos in a specified size. + temporal_sample (callable): Sample the target length of a video. + """ + + def __init__( + self, + data_path=None, + num_frames=16, + frame_interval=1, + image_size=(256, 256), + transform_name="center", + ): + self.data_path = data_path + self.data = read_file(data_path) + self.get_text = "text" in self.data.columns + self.num_frames = num_frames + self.frame_interval = frame_interval + self.image_size = image_size + self.transforms = { + "image": get_transforms_image(transform_name, image_size), + "video": get_transforms_video(transform_name, image_size), + } + + def _print_data_number(self): + num_videos = 0 + num_images = 0 + for path in self.data["path"]: + if self.get_type(path) == "video": + num_videos += 1 + else: + num_images += 1 + print(f"Dataset contains {num_videos} videos and {num_images} images.") + + def get_type(self, path): + ext = os.path.splitext(path)[-1].lower() + if ext.lower() in VID_EXTENSIONS: + return "video" + else: + assert ext.lower() in IMG_EXTENSIONS, f"Unsupported file format: {ext}" + return "image" + + def getitem(self, index): + sample = self.data.iloc[index] + path = sample["path"] + file_type = self.get_type(path) + + if file_type == "video": + # loading + vframes, vinfo = read_video(path, backend="av") + video_fps = vinfo["video_fps"] if "video_fps" in vinfo else 24 + + # Sampling video frames + video = temporal_random_crop(vframes, self.num_frames, self.frame_interval) + + # transform + transform = self.transforms["video"] + video = transform(video) # T C H W + else: + # loading + image = pil_loader(path) + video_fps = IMG_FPS + + # transform + transform = self.transforms["image"] + image = transform(image) + + # repeat + video = image.unsqueeze(0).repeat(self.num_frames, 1, 1, 1) + + # TCHW -> CTHW + video = video.permute(1, 0, 2, 3) + + ret = {"video": video, "fps": video_fps} + if self.get_text: + ret["text"] = sample["text"] + return ret + + def __getitem__(self, index): + for _ in range(10): + try: + return self.getitem(index) + except Exception as e: + path = self.data.iloc[index]["path"] + print(f"data {path}: {e}") + index = np.random.randint(len(self)) + raise RuntimeError("Too many bad data.") + + def __len__(self): + return len(self.data) + + +@DATASETS.register_module() +class VariableVideoTextDataset(VideoTextDataset): + def __init__( + self, + data_path=None, + num_frames=None, + frame_interval=1, + image_size=(None, None), + transform_name=None, + dummy_text_feature=False, + ): + super().__init__(data_path, num_frames, frame_interval, image_size, transform_name=None) + self.transform_name = transform_name + self.data["id"] = np.arange(len(self.data)) + self.dummy_text_feature = dummy_text_feature + + def get_data_info(self, index): + T = self.data.iloc[index]["num_frames"] + H = self.data.iloc[index]["height"] + W = self.data.iloc[index]["width"] + return T, H, W + + def getitem(self, index): + # a hack to pass in the (time, height, width) info from sampler + index, num_frames, height, width = [int(val) for val in index.split("-")] + + sample = self.data.iloc[index] + path = sample["path"] + file_type = self.get_type(path) + ar = height / width + + video_fps = 24 # default fps + if file_type == "video": + # loading + vframes, vinfo = read_video(path, backend="av") + video_fps = vinfo["video_fps"] if "video_fps" in vinfo else 24 + + # Sampling video frames + video = temporal_random_crop(vframes, num_frames, self.frame_interval) + + video_fps = video_fps // self.frame_interval + + # transform + transform = get_transforms_video(self.transform_name, (height, width)) + video = transform(video) # T C H W + else: + # loading + image = pil_loader(path) + video_fps = IMG_FPS + + # transform + transform = get_transforms_image(self.transform_name, (height, width)) + image = transform(image) + + # repeat + video = image.unsqueeze(0) + + # TCHW -> CTHW + video = video.permute(1, 0, 2, 3) + ret = { + "video": video, + "num_frames": num_frames, + "height": height, + "width": width, + "ar": ar, + "fps": video_fps, + } + if self.get_text: + ret["text"] = sample["text"] + if self.dummy_text_feature: + text_len = 50 + ret["text"] = torch.zeros((1, text_len, 1152)) + ret["mask"] = text_len + return ret + + def __getitem__(self, index): + return self.getitem(index) + + +@DATASETS.register_module() +class BatchFeatureDataset(torch.utils.data.Dataset): + """ + The dataset is composed of multiple .bin files. + Each .bin file is a list of batch data (like a buffer). All .bin files have the same length. + In each training iteration, one batch is fetched from the current buffer. + Once a buffer is consumed, load another one. + Avoid loading the same .bin on two difference GPUs, i.e., one .bin is assigned to one GPU only. + """ + + def __init__(self, data_path=None): + self.path_list = sorted(glob(data_path + "/**/*.bin")) + + self._len_buffer = len(torch.load(self.path_list[0])) + self._num_buffers = len(self.path_list) + self.num_samples = self.len_buffer * len(self.path_list) + + self.cur_file_idx = -1 + self.cur_buffer = None + + @property + def num_buffers(self): + return self._num_buffers + + @property + def len_buffer(self): + return self._len_buffer + + def _load_buffer(self, idx): + file_idx = idx // self.len_buffer + if file_idx != self.cur_file_idx: + self.cur_file_idx = file_idx + self.cur_buffer = torch.load(self.path_list[file_idx]) + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + self._load_buffer(idx) + + batch = self.cur_buffer[idx % self.len_buffer] # dict; keys are {'x', 'fps'} and text related + + ret = { + "video": batch["x"], + "text": batch["y"], + "mask": batch["mask"], + "fps": batch["fps"], + "height": batch["height"], + "width": batch["width"], + "num_frames": batch["num_frames"], + } + return ret diff --git a/src/videogen_hub/pipelines/opensora/opensora/datasets/read_video.py b/src/videogen_hub/pipelines/opensora/opensora/datasets/read_video.py new file mode 100644 index 0000000000000000000000000000000000000000..f988c30622614da6ff0b52f7317f3b26f45b21a5 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/datasets/read_video.py @@ -0,0 +1,188 @@ +import gc +import math +import os +from fractions import Fraction +from typing import Any, Dict, Optional, Tuple, Union + +import av +import cv2 +import numpy as np +import torch +from torchvision.io.video import ( + _align_audio_frames, + _check_av_available, + _log_api_usage_once, + _read_from_stream, + _video_opt, +) + + +def read_video_av( + filename: str, + start_pts: Union[float, Fraction] = 0, + end_pts: Optional[Union[float, Fraction]] = None, + pts_unit: str = "pts", + output_format: str = "THWC", +) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]: + """ + Reads a video from a file, returning both the video frames and the audio frames + + Args: + filename (str): path to the video file + start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional): + The start presentation time of the video + end_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional): + The end presentation time + pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted, + either 'pts' or 'sec'. Defaults to 'pts'. + output_format (str, optional): The format of the output video tensors. Can be either "THWC" (default) or "TCHW". + + Returns: + vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames + aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points + info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int) + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(read_video) + + output_format = output_format.upper() + if output_format not in ("THWC", "TCHW"): + raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.") + + from torchvision import get_video_backend + + if not os.path.exists(filename): + raise RuntimeError(f"File not found: {filename}") + + if get_video_backend() != "pyav": + vframes, aframes, info = _video_opt._read_video(filename, start_pts, end_pts, pts_unit) + else: + _check_av_available() + + if end_pts is None: + end_pts = float("inf") + + if end_pts < start_pts: + raise ValueError( + f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}" + ) + + info = {} + video_frames = [] + audio_frames = [] + audio_timebase = _video_opt.default_timebase + + container = av.open(filename, metadata_errors="ignore") + try: + if container.streams.audio: + audio_timebase = container.streams.audio[0].time_base + if container.streams.video: + video_frames = _read_from_stream( + container, + start_pts, + end_pts, + pts_unit, + container.streams.video[0], + {"video": 0}, + ) + video_fps = container.streams.video[0].average_rate + # guard against potentially corrupted files + if video_fps is not None: + info["video_fps"] = float(video_fps) + + if container.streams.audio: + audio_frames = _read_from_stream( + container, + start_pts, + end_pts, + pts_unit, + container.streams.audio[0], + {"audio": 0}, + ) + info["audio_fps"] = container.streams.audio[0].rate + except av.AVError: + # TODO raise a warning? + pass + finally: + container.close() + del container + # NOTE: manually garbage collect to close pyav threads + gc.collect() + + vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames] + aframes_list = [frame.to_ndarray() for frame in audio_frames] + + if vframes_list: + vframes = torch.as_tensor(np.stack(vframes_list)) + else: + vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8) + + if aframes_list: + aframes = np.concatenate(aframes_list, 1) + aframes = torch.as_tensor(aframes) + if pts_unit == "sec": + start_pts = int(math.floor(start_pts * (1 / audio_timebase))) + if end_pts != float("inf"): + end_pts = int(math.ceil(end_pts * (1 / audio_timebase))) + aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts) + else: + aframes = torch.empty((1, 0), dtype=torch.float32) + + if output_format == "TCHW": + # [T,H,W,C] --> [T,C,H,W] + vframes = vframes.permute(0, 3, 1, 2) + + return vframes, aframes, info + + +def read_video_cv2(video_path): + cap = cv2.VideoCapture(video_path) + + if not cap.isOpened(): + # print("Error: Unable to open video") + raise ValueError + else: + fps = cap.get(cv2.CAP_PROP_FPS) + vinfo = { + "video_fps": fps, + } + + frames = [] + while True: + # Read a frame from the video + ret, frame = cap.read() + + # If frame is not read correctly, break the loop + if not ret: + break + + frames.append(frame[:, :, ::-1]) # BGR to RGB + + # Exit if 'q' is pressed + if cv2.waitKey(25) & 0xFF == ord("q"): + break + + # Release the video capture object and close all windows + cap.release() + cv2.destroyAllWindows() + + frames = np.stack(frames) + frames = torch.from_numpy(frames) # [T, H, W, C=3] + frames = frames.permute(0, 3, 1, 2) + return frames, vinfo + + +def read_video(video_path, backend="av"): + if backend == "cv2": + vframes, vinfo = read_video_cv2(video_path) + elif backend == "av": + vframes, _, vinfo = read_video_av(filename=video_path, pts_unit="sec", output_format="TCHW") + else: + raise ValueError + + return vframes, vinfo + + +if __name__ == "__main__": + vframes, vinfo = read_video("./data/colors/9.mp4", backend="cv2") + x = 0 diff --git a/src/videogen_hub/pipelines/opensora/opensora/datasets/sampler.py b/src/videogen_hub/pipelines/opensora/opensora/datasets/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..93888c811a549c6f9bb5f7c61f5ab6b9e5ec4257 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/datasets/sampler.py @@ -0,0 +1,322 @@ +from collections import OrderedDict, defaultdict +from pprint import pformat +from typing import Iterator, List, Optional + +import numpy as np +import torch +import torch.distributed as dist +from torch.utils.data import Dataset, DistributedSampler + + +from videogen_hub.pipelines.opensora.opensora.datasets.aspect import get_num_pixels +from videogen_hub.pipelines.opensora.opensora.datasets.bucket import Bucket +from videogen_hub.pipelines.opensora.opensora.datasets.datasets import VariableVideoTextDataset +from videogen_hub.pipelines.opensora.opensora.utils.misc import get_logger, format_numel_str + + +# use pandarallel to accelerate bucket processing +# NOTE: pandarallel should only access local variables +def apply(data, method=None, frame_interval=None, seed=None, num_bucket=None): + return method( + data["num_frames"], + data["height"], + data["width"], + frame_interval, + seed + data["id"] * num_bucket, + ) + + +class StatefulDistributedSampler(DistributedSampler): + def __init__( + self, + dataset: Dataset, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + ) -> None: + super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last) + self.start_index: int = 0 + + def __iter__(self) -> Iterator: + iterator = super().__iter__() + indices = list(iterator) + indices = indices[self.start_index :] + return iter(indices) + + def __len__(self) -> int: + return self.num_samples - self.start_index + + def reset(self) -> None: + self.start_index = 0 + + def state_dict(self, step) -> dict: + return {"start_index": step} + + def load_state_dict(self, state_dict: dict) -> None: + self.__dict__.update(state_dict) + + +class VariableVideoBatchSampler(DistributedSampler): + def __init__( + self, + dataset: VariableVideoTextDataset, + bucket_config: dict, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + verbose: bool = False, + num_bucket_build_workers: int = 1, + ) -> None: + super().__init__( + dataset=dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed, drop_last=drop_last + ) + self.dataset = dataset + self.bucket = Bucket(bucket_config) + self.verbose = verbose + self.last_micro_batch_access_index = 0 + self.approximate_num_batch = None + + self._get_num_batch_cached_bucket_sample_dict = None + self.num_bucket_build_workers = num_bucket_build_workers + + def __iter__(self) -> Iterator[List[int]]: + if self._get_num_batch_cached_bucket_sample_dict is not None: + bucket_sample_dict = self._get_num_batch_cached_bucket_sample_dict + self._get_num_batch_cached_bucket_sample_dict = None + else: + bucket_sample_dict = self.group_by_bucket() + if self.verbose: + self._print_bucket_info(bucket_sample_dict) + + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + bucket_micro_batch_count = OrderedDict() + bucket_last_consumed = OrderedDict() + + # process the samples + for bucket_id, data_list in bucket_sample_dict.items(): + # handle droplast + bs_per_gpu = self.bucket.get_batch_size(bucket_id) + remainder = len(data_list) % bs_per_gpu + + if remainder > 0: + if not self.drop_last: + # if there is remainder, we pad to make it divisible + data_list += data_list[: bs_per_gpu - remainder] + else: + # we just drop the remainder to make it divisible + data_list = data_list[:-remainder] + bucket_sample_dict[bucket_id] = data_list + + # handle shuffle + if self.shuffle: + data_indices = torch.randperm(len(data_list), generator=g).tolist() + data_list = [data_list[i] for i in data_indices] + bucket_sample_dict[bucket_id] = data_list + + # compute how many micro-batches each bucket has + num_micro_batches = len(data_list) // bs_per_gpu + bucket_micro_batch_count[bucket_id] = num_micro_batches + + # compute the bucket access order + # each bucket may have more than one batch of data + # thus bucket_id may appear more than 1 time + bucket_id_access_order = [] + for bucket_id, num_micro_batch in bucket_micro_batch_count.items(): + bucket_id_access_order.extend([bucket_id] * num_micro_batch) + + # randomize the access order + if self.shuffle: + bucket_id_access_order_indices = torch.randperm(len(bucket_id_access_order), generator=g).tolist() + bucket_id_access_order = [bucket_id_access_order[i] for i in bucket_id_access_order_indices] + + # make the number of bucket accesses divisible by dp size + remainder = len(bucket_id_access_order) % self.num_replicas + if remainder > 0: + if self.drop_last: + bucket_id_access_order = bucket_id_access_order[: len(bucket_id_access_order) - remainder] + else: + bucket_id_access_order += bucket_id_access_order[: self.num_replicas - remainder] + + # prepare each batch from its bucket + # according to the predefined bucket access order + num_iters = len(bucket_id_access_order) // self.num_replicas + start_iter_idx = self.last_micro_batch_access_index // self.num_replicas + + # re-compute the micro-batch consumption + # this is useful when resuming from a state dict with a different number of GPUs + self.last_micro_batch_access_index = start_iter_idx * self.num_replicas + for i in range(self.last_micro_batch_access_index): + bucket_id = bucket_id_access_order[i] + bucket_bs = self.bucket.get_batch_size(bucket_id) + if bucket_id in bucket_last_consumed: + bucket_last_consumed[bucket_id] += bucket_bs + else: + bucket_last_consumed[bucket_id] = bucket_bs + + for i in range(start_iter_idx, num_iters): + bucket_access_list = bucket_id_access_order[i * self.num_replicas : (i + 1) * self.num_replicas] + self.last_micro_batch_access_index += self.num_replicas + + # compute the data samples consumed by each access + bucket_access_boundaries = [] + for bucket_id in bucket_access_list: + bucket_bs = self.bucket.get_batch_size(bucket_id) + last_consumed_index = bucket_last_consumed.get(bucket_id, 0) + bucket_access_boundaries.append([last_consumed_index, last_consumed_index + bucket_bs]) + + # update consumption + if bucket_id in bucket_last_consumed: + bucket_last_consumed[bucket_id] += bucket_bs + else: + bucket_last_consumed[bucket_id] = bucket_bs + + # compute the range of data accessed by each GPU + bucket_id = bucket_access_list[self.rank] + boundary = bucket_access_boundaries[self.rank] + cur_micro_batch = bucket_sample_dict[bucket_id][boundary[0] : boundary[1]] + + # encode t, h, w into the sample index + real_t, real_h, real_w = self.bucket.get_thw(bucket_id) + cur_micro_batch = [f"{idx}-{real_t}-{real_h}-{real_w}" for idx in cur_micro_batch] + yield cur_micro_batch + + self.reset() + + def __len__(self) -> int: + return self.get_num_batch() // dist.get_world_size() + + def group_by_bucket(self) -> dict: + bucket_sample_dict = OrderedDict() + + from pandarallel import pandarallel + + pandarallel.initialize(nb_workers=self.num_bucket_build_workers, progress_bar=False) + get_logger().info("Building buckets...") + bucket_ids = self.dataset.data.parallel_apply( + apply, + axis=1, + method=self.bucket.get_bucket_id, + frame_interval=self.dataset.frame_interval, + seed=self.seed + self.epoch, + num_bucket=self.bucket.num_bucket, + ) + + # group by bucket + # each data sample is put into a bucket with a similar image/video size + for i in range(len(self.dataset)): + bucket_id = bucket_ids[i] + if bucket_id is None: + continue + if bucket_id not in bucket_sample_dict: + bucket_sample_dict[bucket_id] = [] + bucket_sample_dict[bucket_id].append(i) + return bucket_sample_dict + + def get_num_batch(self) -> int: + bucket_sample_dict = self.group_by_bucket() + self._get_num_batch_cached_bucket_sample_dict = bucket_sample_dict + + # calculate the number of batches + if self.verbose: + self._print_bucket_info(bucket_sample_dict) + return self.approximate_num_batch + + def _print_bucket_info(self, bucket_sample_dict: dict) -> None: + # collect statistics + total_samples = 0 + total_batch = 0 + num_aspect_dict = defaultdict(lambda: [0, 0]) + num_hwt_dict = defaultdict(lambda: [0, 0]) + for k, v in bucket_sample_dict.items(): + size = len(v) + num_batch = size // self.bucket.get_batch_size(k[:-1]) + + total_samples += size + total_batch += num_batch + + num_aspect_dict[k[-1]][0] += size + num_aspect_dict[k[-1]][1] += num_batch + num_hwt_dict[k[:-1]][0] += size + num_hwt_dict[k[:-1]][1] += num_batch + + # sort + num_aspect_dict = dict(sorted(num_aspect_dict.items(), key=lambda x: x[0])) + num_hwt_dict = dict( + sorted(num_hwt_dict.items(), key=lambda x: (get_num_pixels(x[0][0]), x[0][1]), reverse=True) + ) + num_hwt_img_dict = {k: v for k, v in num_hwt_dict.items() if k[1] == 1} + num_hwt_vid_dict = {k: v for k, v in num_hwt_dict.items() if k[1] > 1} + + # log + if dist.get_rank() == 0 and self.verbose: + get_logger().info("Bucket Info:") + get_logger().info( + "Bucket [#sample, #batch] by aspect ratio:\n%s", pformat(num_aspect_dict, sort_dicts=False) + ) + get_logger().info( + "Image Bucket [#sample, #batch] by HxWxT:\n%s", pformat(num_hwt_img_dict, sort_dicts=False) + ) + get_logger().info( + "Video Bucket [#sample, #batch] by HxWxT:\n%s", pformat(num_hwt_vid_dict, sort_dicts=False) + ) + get_logger().info( + "#training batch: %s, #training sample: %s, #non empty bucket: %s", + format_numel_str(total_batch), + format_numel_str(total_samples), + len(bucket_sample_dict), + ) + self.approximate_num_batch = total_batch + + def reset(self): + self.last_micro_batch_access_index = 0 + + def state_dict(self, num_steps: int) -> dict: + # the last_micro_batch_access_index in the __iter__ is often + # not accurate during multi-workers and data prefetching + # thus, we need the user to pass the actual steps which have been executed + # to calculate the correct last_micro_batch_access_index + return {"seed": self.seed, "epoch": self.epoch, "last_micro_batch_access_index": num_steps * self.num_replicas} + + def load_state_dict(self, state_dict: dict) -> None: + self.__dict__.update(state_dict) + + +class BatchDistributedSampler(DistributedSampler): + """ + Used with BatchDataset; + Suppose len_buffer == 5, num_buffers == 6, #GPUs == 3, then + | buffer {i} | buffer {i+1} + ------ | ------------------- | ------------------- + rank 0 | 0, 1, 2, 3, 4, | 5, 6, 7, 8, 9 + rank 1 | 10, 11, 12, 13, 14, | 15, 16, 17, 18, 19 + rank 2 | 20, 21, 22, 23, 24, | 25, 26, 27, 28, 29 + """ + + def __init__(self, dataset: Dataset, **kwargs): + super().__init__(dataset, **kwargs) + self.start_index = 0 + + def __iter__(self): + num_buffers = self.dataset.num_buffers + len_buffer = self.dataset.len_buffer + num_buffers_i = num_buffers // self.num_replicas + num_samples_i = len_buffer * num_buffers_i + + indices_i = np.arange(self.start_index, num_samples_i) + self.rank * num_samples_i + indices_i = indices_i.tolist() + + return iter(indices_i) + + def reset(self): + self.start_index = 0 + + def state_dict(self, step) -> dict: + return {"start_index": step} + + def load_state_dict(self, state_dict: dict): + self.start_index = state_dict["start_index"] + 1 diff --git a/src/videogen_hub/pipelines/opensora/opensora/datasets/utils.py b/src/videogen_hub/pipelines/opensora/opensora/datasets/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b7083a3dd6c684d014e45d53293e51a8bd152d73 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/datasets/utils.py @@ -0,0 +1,217 @@ +import os +import re + +import numpy as np +import pandas as pd +import requests +import torch +import torchvision +import torchvision.transforms as transforms +from PIL import Image +from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader +from torchvision.io import write_video +from torchvision.utils import save_image + +from . import video_transforms + +VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv") + +regex = re.compile( + r"^(?:http|ftp)s?://" # http:// or https:// + r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|" # domain... + r"localhost|" # localhost... + r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})" # ...or ip + r"(?::\d+)?" # optional port + r"(?:/?|[/?]\S+)$", + re.IGNORECASE, +) + + +def is_img(path): + ext = os.path.splitext(path)[-1].lower() + return ext in IMG_EXTENSIONS + + +def is_vid(path): + ext = os.path.splitext(path)[-1].lower() + return ext in VID_EXTENSIONS + + +def is_url(url): + return re.match(regex, url) is not None + + +def read_file(input_path): + if input_path.endswith(".csv"): + return pd.read_csv(input_path) + elif input_path.endswith(".parquet"): + return pd.read_parquet(input_path) + else: + raise NotImplementedError(f"Unsupported file format: {input_path}") + + +def download_url(input_path): + output_dir = "cache" + os.makedirs(output_dir, exist_ok=True) + base_name = os.path.basename(input_path) + output_path = os.path.join(output_dir, base_name) + img_data = requests.get(input_path).content + with open(output_path, "wb") as handler: + handler.write(img_data) + print(f"URL {input_path} downloaded to {output_path}") + return output_path + + +def temporal_random_crop(vframes, num_frames, frame_interval): + temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval) + total_frames = len(vframes) + start_frame_ind, end_frame_ind = temporal_sample(total_frames) + assert ( + end_frame_ind - start_frame_ind >= num_frames + ), f"Not enough frames to sample, {end_frame_ind} - {start_frame_ind} < {num_frames}" + frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, num_frames, dtype=int) + video = vframes[frame_indice] + return video + + +def get_transforms_video(name="center", image_size=(256, 256)): + if name is None: + return None + elif name == "center": + assert image_size[0] == image_size[1], "image_size must be square for center crop" + transform_video = transforms.Compose( + [ + video_transforms.ToTensorVideo(), # TCHW + # video_transforms.RandomHorizontalFlipVideo(), + video_transforms.UCFCenterCropVideo(image_size[0]), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) + elif name == "resize_crop": + transform_video = transforms.Compose( + [ + video_transforms.ToTensorVideo(), # TCHW + video_transforms.ResizeCrop(image_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) + else: + raise NotImplementedError(f"Transform {name} not implemented") + return transform_video + + +def get_transforms_image(name="center", image_size=(256, 256)): + if name is None: + return None + elif name == "center": + assert image_size[0] == image_size[1], "Image size must be square for center crop" + transform = transforms.Compose( + [ + transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size[0])), + # transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) + elif name == "resize_crop": + transform = transforms.Compose( + [ + transforms.Lambda(lambda pil_image: resize_crop_to_fill(pil_image, image_size)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) + else: + raise NotImplementedError(f"Transform {name} not implemented") + return transform + + +def read_image_from_path(path, transform=None, transform_name="center", num_frames=1, image_size=(256, 256)): + image = pil_loader(path) + if transform is None: + transform = get_transforms_image(image_size=image_size, name=transform_name) + image = transform(image) + video = image.unsqueeze(0).repeat(num_frames, 1, 1, 1) + video = video.permute(1, 0, 2, 3) + return video + + +def read_video_from_path(path, transform=None, transform_name="center", image_size=(256, 256)): + vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW") + if transform is None: + transform = get_transforms_video(image_size=image_size, name=transform_name) + video = transform(vframes) # T C H W + video = video.permute(1, 0, 2, 3) + return video + + +def read_from_path(path, image_size, transform_name="center"): + if is_url(path): + path = download_url(path) + ext = os.path.splitext(path)[-1].lower() + if ext.lower() in VID_EXTENSIONS: + return read_video_from_path(path, image_size=image_size, transform_name=transform_name) + else: + assert ext.lower() in IMG_EXTENSIONS, f"Unsupported file format: {ext}" + return read_image_from_path(path, image_size=image_size, transform_name=transform_name) + + +def save_sample(x, save_path=None, fps=8, normalize=True, value_range=(-1, 1), force_video=False, verbose=True): + """ + Args: + x (Tensor): shape [C, T, H, W] + """ + assert x.ndim == 4 + + if not force_video and x.shape[1] == 1: # T = 1: save as image + save_path += ".png" + x = x.squeeze(1) + save_image([x], save_path, normalize=normalize, value_range=value_range) + else: + save_path += ".mp4" + if normalize: + low, high = value_range + x.clamp_(min=low, max=high) + x.sub_(low).div_(max(high - low, 1e-5)) + + x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 3, 0).to("cpu", torch.uint8) + write_video(save_path, x, fps=fps, video_codec="h264") + if verbose: + print(f"Saved to {save_path}") + return save_path + + +def center_crop_arr(pil_image, image_size): + """ + Center cropping implementation from ADM. + https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 + """ + while min(*pil_image.size) >= 2 * image_size: + pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX) + + scale = image_size / min(*pil_image.size) + pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC) + + arr = np.array(pil_image) + crop_y = (arr.shape[0] - image_size) // 2 + crop_x = (arr.shape[1] - image_size) // 2 + return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]) + + +def resize_crop_to_fill(pil_image, image_size): + w, h = pil_image.size # PIL is (W, H) + th, tw = image_size + rh, rw = th / h, tw / w + if rh > rw: + sh, sw = th, round(w * rh) + image = pil_image.resize((sw, sh), Image.BICUBIC) + i = 0 + j = int(round((sw - tw) / 2.0)) + else: + sh, sw = round(h * rw), tw + image = pil_image.resize((sw, sh), Image.BICUBIC) + i = int(round((sh - th) / 2.0)) + j = 0 + arr = np.array(image) + assert i + th <= arr.shape[0] and j + tw <= arr.shape[1] + return Image.fromarray(arr[i : i + th, j : j + tw]) diff --git a/src/videogen_hub/pipelines/opensora/opensora/datasets/video_transforms.py b/src/videogen_hub/pipelines/opensora/opensora/datasets/video_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..8cf50468ee96e339c0dfb6401e83a8f34c29b900 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/datasets/video_transforms.py @@ -0,0 +1,520 @@ +# Copyright 2024 Vchitect/Latte + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.# Modified from Latte + +# - This file is adapted from https://github.com/Vchitect/Latte/blob/main/datasets/video_transforms.py + + +import numbers +import random + +import numpy as np +import torch + + +def _is_tensor_video_clip(clip): + if not torch.is_tensor(clip): + raise TypeError("clip should be Tensor. Got %s" % type(clip)) + + if not clip.ndimension() == 4: + raise ValueError("clip should be 4D. Got %dD" % clip.dim()) + + return True + + +def crop(clip, i, j, h, w): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + """ + if len(clip.size()) != 4: + raise ValueError("clip should be a 4D tensor") + return clip[..., i : i + h, j : j + w] + + +def resize(clip, target_size, interpolation_mode): + if len(target_size) != 2: + raise ValueError(f"target size should be tuple (height, width), instead got {target_size}") + return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False) + + +def resize_scale(clip, target_size, interpolation_mode): + if len(target_size) != 2: + raise ValueError(f"target size should be tuple (height, width), instead got {target_size}") + H, W = clip.size(-2), clip.size(-1) + scale_ = target_size[0] / min(H, W) + return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False) + + +def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"): + """ + Do spatial cropping and resizing to the video clip + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + i (int): i in (i,j) i.e coordinates of the upper left corner. + j (int): j in (i,j) i.e coordinates of the upper left corner. + h (int): Height of the cropped region. + w (int): Width of the cropped region. + size (tuple(int, int)): height and width of resized clip + Returns: + clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W) + """ + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + clip = crop(clip, i, j, h, w) + clip = resize(clip, size, interpolation_mode) + return clip + + +def center_crop(clip, crop_size): + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + h, w = clip.size(-2), clip.size(-1) + th, tw = crop_size + if h < th or w < tw: + raise ValueError("height and width must be no smaller than crop_size") + + i = int(round((h - th) / 2.0)) + j = int(round((w - tw) / 2.0)) + return crop(clip, i, j, th, tw) + + +def center_crop_using_short_edge(clip): + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + h, w = clip.size(-2), clip.size(-1) + if h < w: + th, tw = h, h + i = 0 + j = int(round((w - tw) / 2.0)) + else: + th, tw = w, w + i = int(round((h - th) / 2.0)) + j = 0 + return crop(clip, i, j, th, tw) + + +def resize_crop_to_fill(clip, target_size): + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + h, w = clip.size(-2), clip.size(-1) + th, tw = target_size[0], target_size[1] + rh, rw = th / h, tw / w + if rh > rw: + sh, sw = th, round(w * rh) + clip = resize(clip, (sh, sw), "bilinear") + i = 0 + j = int(round(sw - tw) / 2.0) + else: + sh, sw = round(h * rw), tw + clip = resize(clip, (sh, sw), "bilinear") + i = int(round(sh - th) / 2.0) + j = 0 + assert i + th <= clip.size(-2) and j + tw <= clip.size(-1) + return crop(clip, i, j, th, tw) + + +def random_shift_crop(clip): + """ + Slide along the long edge, with the short edge as crop size + """ + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + h, w = clip.size(-2), clip.size(-1) + + if h <= w: + short_edge = h + else: + short_edge = w + + th, tw = short_edge, short_edge + + i = torch.randint(0, h - th + 1, size=(1,)).item() + j = torch.randint(0, w - tw + 1, size=(1,)).item() + return crop(clip, i, j, th, tw) + + +def to_tensor(clip): + """ + Convert tensor data type from uint8 to float, divide value by 255.0 and + permute the dimensions of clip tensor + Args: + clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) + Return: + clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) + """ + _is_tensor_video_clip(clip) + if not clip.dtype == torch.uint8: + raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype)) + # return clip.float().permute(3, 0, 1, 2) / 255.0 + return clip.float() / 255.0 + + +def normalize(clip, mean, std, inplace=False): + """ + Args: + clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W) + mean (tuple): pixel RGB mean. Size is (3) + std (tuple): pixel standard deviation. Size is (3) + Returns: + normalized clip (torch.tensor): Size is (T, C, H, W) + """ + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + if not inplace: + clip = clip.clone() + mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device) + # print(mean) + std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device) + clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) + return clip + + +def hflip(clip): + """ + Args: + clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W) + Returns: + flipped clip (torch.tensor): Size is (T, C, H, W) + """ + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + return clip.flip(-1) + + +class ResizeCrop: + def __init__(self, size): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + + def __call__(self, clip): + clip = resize_crop_to_fill(clip, self.size) + return clip + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size})" + + +class RandomCropVideo: + def __init__(self, size): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + Returns: + torch.tensor: randomly cropped video clip. + size is (T, C, OH, OW) + """ + i, j, h, w = self.get_params(clip) + return crop(clip, i, j, h, w) + + def get_params(self, clip): + h, w = clip.shape[-2:] + th, tw = self.size + + if h < th or w < tw: + raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}") + + if w == tw and h == th: + return 0, 0, h, w + + i = torch.randint(0, h - th + 1, size=(1,)).item() + j = torch.randint(0, w - tw + 1, size=(1,)).item() + + return i, j, th, tw + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size})" + + +class CenterCropResizeVideo: + """ + First use the short side for cropping length, + center crop video, then resize to the specified size + """ + + def __init__( + self, + size, + interpolation_mode="bilinear", + ): + if isinstance(size, tuple): + if len(size) != 2: + raise ValueError(f"size should be tuple (height, width), instead got {size}") + self.size = size + else: + self.size = (size, size) + + self.interpolation_mode = interpolation_mode + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + Returns: + torch.tensor: scale resized / center cropped video clip. + size is (T, C, crop_size, crop_size) + """ + clip_center_crop = center_crop_using_short_edge(clip) + clip_center_crop_resize = resize( + clip_center_crop, target_size=self.size, interpolation_mode=self.interpolation_mode + ) + return clip_center_crop_resize + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" + + +class UCFCenterCropVideo: + """ + First scale to the specified size in equal proportion to the short edge, + then center cropping + """ + + def __init__( + self, + size, + interpolation_mode="bilinear", + ): + if isinstance(size, tuple): + if len(size) != 2: + raise ValueError(f"size should be tuple (height, width), instead got {size}") + self.size = size + else: + self.size = (size, size) + + self.interpolation_mode = interpolation_mode + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + Returns: + torch.tensor: scale resized / center cropped video clip. + size is (T, C, crop_size, crop_size) + """ + clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode) + clip_center_crop = center_crop(clip_resize, self.size) + return clip_center_crop + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" + + +class KineticsRandomCropResizeVideo: + """ + Slide along the long edge, with the short edge as crop size. And resie to the desired size. + """ + + def __init__( + self, + size, + interpolation_mode="bilinear", + ): + if isinstance(size, tuple): + if len(size) != 2: + raise ValueError(f"size should be tuple (height, width), instead got {size}") + self.size = size + else: + self.size = (size, size) + + self.interpolation_mode = interpolation_mode + + def __call__(self, clip): + clip_random_crop = random_shift_crop(clip) + clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode) + return clip_resize + + +class CenterCropVideo: + def __init__( + self, + size, + interpolation_mode="bilinear", + ): + if isinstance(size, tuple): + if len(size) != 2: + raise ValueError(f"size should be tuple (height, width), instead got {size}") + self.size = size + else: + self.size = (size, size) + + self.interpolation_mode = interpolation_mode + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + Returns: + torch.tensor: center cropped video clip. + size is (T, C, crop_size, crop_size) + """ + clip_center_crop = center_crop(clip, self.size) + return clip_center_crop + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" + + +class NormalizeVideo: + """ + Normalize the video clip by mean subtraction and division by standard deviation + Args: + mean (3-tuple): pixel RGB mean + std (3-tuple): pixel RGB standard deviation + inplace (boolean): whether do in-place normalization + """ + + def __init__(self, mean, std, inplace=False): + self.mean = mean + self.std = std + self.inplace = inplace + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): video clip must be normalized. Size is (C, T, H, W) + """ + return normalize(clip, self.mean, self.std, self.inplace) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})" + + +class ToTensorVideo: + """ + Convert tensor data type from uint8 to float, divide value by 255.0 and + permute the dimensions of clip tensor + """ + + def __init__(self): + pass + + def __call__(self, clip): + """ + Args: + clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) + Return: + clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) + """ + return to_tensor(clip) + + def __repr__(self) -> str: + return self.__class__.__name__ + + +class RandomHorizontalFlipVideo: + """ + Flip the video clip along the horizontal direction with a given probability + Args: + p (float): probability of the clip being flipped. Default value is 0.5 + """ + + def __init__(self, p=0.5): + self.p = p + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Size is (T, C, H, W) + Return: + clip (torch.tensor): Size is (T, C, H, W) + """ + if random.random() < self.p: + clip = hflip(clip) + return clip + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(p={self.p})" + + +# ------------------------------------------------------------ +# --------------------- Sampling --------------------------- +# ------------------------------------------------------------ +class TemporalRandomCrop(object): + """Temporally crop the given frame indices at a random location. + + Args: + size (int): Desired length of frames will be seen in the model. + """ + + def __init__(self, size): + self.size = size + + def __call__(self, total_frames): + rand_end = max(0, total_frames - self.size - 1) + begin_index = random.randint(0, rand_end) + end_index = min(begin_index + self.size, total_frames) + return begin_index, end_index + + +if __name__ == "__main__": + import os + + import numpy as np + import torchvision.io as io + from torchvision import transforms + from torchvision.utils import save_image + + vframes, aframes, info = io.read_video(filename="./v_Archery_g01_c03.avi", pts_unit="sec", output_format="TCHW") + + trans = transforms.Compose( + [ + ToTensorVideo(), + RandomHorizontalFlipVideo(), + UCFCenterCropVideo(512), + # NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) + + target_video_len = 32 + frame_interval = 1 + total_frames = len(vframes) + print(total_frames) + + temporal_sample = TemporalRandomCrop(target_video_len * frame_interval) + + # Sampling video frames + start_frame_ind, end_frame_ind = temporal_sample(total_frames) + # print(start_frame_ind) + # print(end_frame_ind) + assert end_frame_ind - start_frame_ind >= target_video_len + frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int) + print(frame_indice) + + select_vframes = vframes[frame_indice] + print(select_vframes.shape) + print(select_vframes.dtype) + + select_vframes_trans = trans(select_vframes) + print(select_vframes_trans.shape) + print(select_vframes_trans.dtype) + + select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8) + print(select_vframes_trans_int.dtype) + print(select_vframes_trans_int.permute(0, 2, 3, 1).shape) + + io.write_video("./test.avi", select_vframes_trans_int.permute(0, 2, 3, 1), fps=8) + + for i in range(target_video_len): + save_image( + select_vframes_trans[i], os.path.join("./test000", "%04d.png" % i), normalize=True, value_range=(-1, 1) + ) diff --git a/src/videogen_hub/pipelines/opensora/opensora/models/__init__.py b/src/videogen_hub/pipelines/opensora/opensora/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ca80c956d3cc38800b2b2931afbcba0f0c2d6dd8 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/models/__init__.py @@ -0,0 +1,6 @@ +from videogen_hub.pipelines.opensora.opensora.models.dit import * +from videogen_hub.pipelines.opensora.opensora.models.latte import * +from videogen_hub.pipelines.opensora.opensora.models.pixart import * +from videogen_hub.pipelines.opensora.opensora.models.stdit import * +from videogen_hub.pipelines.opensora.opensora.models.text_encoder import * +from videogen_hub.pipelines.opensora.opensora.models.vae import * diff --git a/src/videogen_hub/pipelines/opensora/opensora/models/dit/__init__.py b/src/videogen_hub/pipelines/opensora/opensora/models/dit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cf43e0ce8808ac0fda0227fa967b7c32c39ec93a --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/models/dit/__init__.py @@ -0,0 +1 @@ +from videogen_hub.pipelines.opensora.opensora.models.dit.dit import DiT, DiT_XL_2, DiT_XL_2x2 diff --git a/src/videogen_hub/pipelines/opensora/opensora/models/dit/dit.py b/src/videogen_hub/pipelines/opensora/opensora/models/dit/dit.py new file mode 100644 index 0000000000000000000000000000000000000000..ccbd722c6e1ee83c51f27e9a247b4a4d4e1f7fce --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/models/dit/dit.py @@ -0,0 +1,288 @@ +# Modified from Meta DiT + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# DiT: https://github.com/facebookresearch/DiT/tree/main +# GLIDE: https://github.com/openai/glide-text2im +# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py +# -------------------------------------------------------- + +import numpy as np +import torch +import torch.nn as nn +import torch.utils.checkpoint +from einops import rearrange +from timm.models.vision_transformer import Mlp + +from videogen_hub.pipelines.opensora.opensora.models.layers.blocks import ( + Attention, + CaptionEmbedder, + FinalLayer, + LabelEmbedder, + PatchEmbed3D, + TimestepEmbedder, + approx_gelu, + get_1d_sincos_pos_embed, + get_2d_sincos_pos_embed, + get_layernorm, + modulate, +) +from videogen_hub.pipelines.opensora.opensora.acceleration.checkpoint import auto_grad_checkpoint +from videogen_hub.pipelines.opensora.opensora.registry import MODELS +from videogen_hub.pipelines.opensora.opensora.utils.ckpt_utils import load_checkpoint + + +class DiTBlock(nn.Module): + """ + A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. + """ + + def __init__( + self, + hidden_size, + num_heads, + mlp_ratio=4.0, + enable_flash_attn=False, + enable_layernorm_kernel=False, + ): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.enable_flash_attn = enable_flash_attn + mlp_hidden_dim = int(hidden_size * mlp_ratio) + + self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel) + self.attn = Attention( + hidden_size, + num_heads=num_heads, + qkv_bias=True, + enable_flash_attn=enable_flash_attn, + ) + self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel) + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)) + + def forward(self, x, c): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) + x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1, x, shift_msa, scale_msa)) + x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2, x, shift_mlp, scale_mlp)) + return x + + +@MODELS.register_module() +class DiT(nn.Module): + """ + Diffusion model with a Transformer backbone. + """ + + def __init__( + self, + input_size=(16, 32, 32), + in_channels=4, + patch_size=(1, 2, 2), + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + class_dropout_prob=0.1, + learn_sigma=True, + condition="text", + no_temporal_pos_emb=False, + caption_channels=512, + model_max_length=77, + dtype=torch.float32, + enable_flash_attn=False, + enable_layernorm_kernel=False, + enable_sequence_parallelism=False, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.hidden_size = hidden_size + self.patch_size = patch_size + self.input_size = input_size + num_patches = np.prod([input_size[i] // patch_size[i] for i in range(3)]) + self.num_patches = num_patches + self.num_temporal = input_size[0] // patch_size[0] + self.num_spatial = num_patches // self.num_temporal + self.num_heads = num_heads + self.dtype = dtype + self.use_text_encoder = not condition.startswith("label") + if enable_flash_attn: + assert dtype in [ + torch.float16, + torch.bfloat16, + ], f"Flash attention only supports float16 and bfloat16, but got {self.dtype}" + self.no_temporal_pos_emb = no_temporal_pos_emb + self.mlp_ratio = mlp_ratio + self.depth = depth + assert enable_sequence_parallelism is False, "Sequence parallelism is not supported in DiT" + + self.register_buffer("pos_embed_spatial", self.get_spatial_pos_embed()) + self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed()) + + self.x_embedder = PatchEmbed3D(patch_size, in_channels, embed_dim=hidden_size) + if not self.use_text_encoder: + num_classes = int(condition.split("_")[-1]) + self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) + else: + self.y_embedder = CaptionEmbedder( + in_channels=caption_channels, + hidden_size=hidden_size, + uncond_prob=class_dropout_prob, + act_layer=approx_gelu, + token_num=1, # pooled token + ) + self.t_embedder = TimestepEmbedder(hidden_size) + self.blocks = nn.ModuleList( + [ + DiTBlock( + hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + enable_flash_attn=enable_flash_attn, + enable_layernorm_kernel=enable_layernorm_kernel, + ) + for _ in range(depth) + ] + ) + self.final_layer = FinalLayer(hidden_size, np.prod(self.patch_size), self.out_channels) + + self.initialize_weights() + self.enable_flash_attn = enable_flash_attn + self.enable_layernorm_kernel = enable_layernorm_kernel + + def get_spatial_pos_embed(self): + pos_embed = get_2d_sincos_pos_embed( + self.hidden_size, + self.input_size[1] // self.patch_size[1], + ) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False) + return pos_embed + + def get_temporal_pos_embed(self): + pos_embed = get_1d_sincos_pos_embed( + self.hidden_size, + self.input_size[0] // self.patch_size[0], + ) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False) + return pos_embed + + def unpatchify(self, x): + c = self.out_channels + t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)] + pt, ph, pw = self.patch_size + + x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c)) + x = rearrange(x, "n t h w r p q c -> n c t r h p w q") + imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw)) + return imgs + + def forward(self, x, t, y): + """ + Forward pass of DiT. + x: (B, C, T, H, W) tensor of inputs + t: (B,) tensor of diffusion timesteps + y: list of text + """ + # origin inputs should be float32, cast to specified dtype + x = x.to(self.dtype) + if self.use_text_encoder: + y = y.to(self.dtype) + + # embedding + x = self.x_embedder(x) # (B, N, D) + x = rearrange(x, "b (t s) d -> b t s d", t=self.num_temporal, s=self.num_spatial) + x = x + self.pos_embed_spatial + if not self.no_temporal_pos_emb: + x = rearrange(x, "b t s d -> b s t d") + x = x + self.pos_embed_temporal + x = rearrange(x, "b s t d -> b (t s) d") + else: + x = rearrange(x, "b t s d -> b (t s) d") + + t = self.t_embedder(t, dtype=x.dtype) # (N, D) + y = self.y_embedder(y, self.training) # (N, D) + if self.use_text_encoder: + y = y.squeeze(1).squeeze(1) + condition = t + y + + # blocks + for _, block in enumerate(self.blocks): + c = condition + x = auto_grad_checkpoint(block, x, c) # (B, N, D) + + # final process + x = self.final_layer(x, condition) # (B, N, num_patches * out_channels) + x = self.unpatchify(x) # (B, out_channels, T, H, W) + + # cast to float32 for better accuracy + x = x.to(torch.float32) + return x + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + if module.weight.requires_grad_: + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + # Zero-out text embedding layers: + if self.use_text_encoder: + nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02) + nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02) + + +@MODELS.register_module("DiT-XL/2") +def DiT_XL_2(from_pretrained=None, **kwargs): + model = DiT( + depth=28, + hidden_size=1152, + patch_size=(1, 2, 2), + num_heads=16, + **kwargs, + ) + if from_pretrained is not None: + load_checkpoint(model, from_pretrained) + return model + + +@MODELS.register_module("DiT-XL/2x2") +def DiT_XL_2x2(from_pretrained=None, **kwargs): + model = DiT( + depth=28, + hidden_size=1152, + patch_size=(2, 2, 2), + num_heads=16, + **kwargs, + ) + if from_pretrained is not None: + load_checkpoint(model, from_pretrained) + return model diff --git a/src/videogen_hub/pipelines/opensora/opensora/models/latte/__init__.py b/src/videogen_hub/pipelines/opensora/opensora/models/latte/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..11a4e77a5aa5eec1afd5a65948e2601b6a26c476 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/models/latte/__init__.py @@ -0,0 +1 @@ +from videogen_hub.pipelines.opensora.opensora.models.latte.latte import Latte, Latte_XL_2, Latte_XL_2x2 diff --git a/src/videogen_hub/pipelines/opensora/opensora/models/latte/latte.py b/src/videogen_hub/pipelines/opensora/opensora/models/latte/latte.py new file mode 100644 index 0000000000000000000000000000000000000000..abc3359f028774b29f800c5899978924565f2f40 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/models/latte/latte.py @@ -0,0 +1,112 @@ +# Copyright 2024 Vchitect/Latte +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.# Modified from Latte +# +# +# This file is mofied from https://github.com/Vchitect/Latte/blob/main/models/latte.py +# +# With references to: +# Latte: https://github.com/Vchitect/Latte +# DiT: https://github.com/facebookresearch/DiT/tree/main + + +import torch +from einops import rearrange, repeat + +from videogen_hub.pipelines.opensora.opensora.acceleration.checkpoint import auto_grad_checkpoint +from videogen_hub.pipelines.opensora.opensora.models import DiT +from videogen_hub.pipelines.opensora.opensora.registry import MODELS +from videogen_hub.pipelines.opensora.opensora.utils.ckpt_utils import load_checkpoint + + +@MODELS.register_module() +class Latte(DiT): + def forward(self, x, t, y): + """ + Forward pass of DiT. + x: (B, C, T, H, W) tensor of inputs + t: (B,) tensor of diffusion timesteps + y: list of text + """ + # origin inputs should be float32, cast to specified dtype + x = x.to(self.dtype) + + # embedding + x = self.x_embedder(x) # (B, N, D) + x = rearrange(x, "b (t s) d -> b t s d", t=self.num_temporal, s=self.num_spatial) + x = x + self.pos_embed_spatial + x = rearrange(x, "b t s d -> b (t s) d") + + t = self.t_embedder(t, dtype=x.dtype) # (N, D) + y = self.y_embedder(y, self.training) # (N, D) + if self.use_text_encoder: + y = y.squeeze(1).squeeze(1) + condition = t + y + condition_spatial = repeat(condition, "b d -> (b t) d", t=self.num_temporal) + condition_temporal = repeat(condition, "b d -> (b s) d", s=self.num_spatial) + + # blocks + for i, block in enumerate(self.blocks): + if i % 2 == 0: + # spatial + x = rearrange(x, "b (t s) d -> (b t) s d", t=self.num_temporal, s=self.num_spatial) + c = condition_spatial + else: + # temporal + x = rearrange(x, "b (t s) d -> (b s) t d", t=self.num_temporal, s=self.num_spatial) + c = condition_temporal + if i == 1: + x = x + self.pos_embed_temporal + + x = auto_grad_checkpoint(block, x, c) # (B, N, D) + + if i % 2 == 0: + x = rearrange(x, "(b t) s d -> b (t s) d", t=self.num_temporal, s=self.num_spatial) + else: + x = rearrange(x, "(b s) t d -> b (t s) d", t=self.num_temporal, s=self.num_spatial) + + # final process + x = self.final_layer(x, condition) # (B, N, num_patches * out_channels) + x = self.unpatchify(x) # (B, out_channels, T, H, W) + + # cast to float32 for better accuracy + x = x.to(torch.float32) + return x + + +@MODELS.register_module("Latte-XL/2") +def Latte_XL_2(from_pretrained=None, **kwargs): + model = Latte( + depth=28, + hidden_size=1152, + patch_size=(1, 2, 2), + num_heads=16, + **kwargs, + ) + if from_pretrained is not None: + load_checkpoint(model, from_pretrained) + return model + + +@MODELS.register_module("Latte-XL/2x2") +def Latte_XL_2x2(from_pretrained=None, **kwargs): + model = Latte( + depth=28, + hidden_size=1152, + patch_size=(2, 2, 2), + num_heads=16, + **kwargs, + ) + if from_pretrained is not None: + load_checkpoint(model, from_pretrained) + return model diff --git a/src/videogen_hub/pipelines/opensora/opensora/models/layers/__init__.py b/src/videogen_hub/pipelines/opensora/opensora/models/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora/opensora/models/layers/blocks.py b/src/videogen_hub/pipelines/opensora/opensora/models/layers/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..d91210742c64aa61e318bcc0dd98b5a9729c29ff --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/models/layers/blocks.py @@ -0,0 +1,864 @@ +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# PixArt: https://github.com/PixArt-alpha/PixArt-alpha +# Latte: https://github.com/Vchitect/Latte +# DiT: https://github.com/facebookresearch/DiT/tree/main +# GLIDE: https://github.com/openai/glide-text2im +# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py +# -------------------------------------------------------- + +import functools +import math +from typing import Optional + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +import xformers.ops +from einops import rearrange +from timm.models.vision_transformer import Mlp + +from videogen_hub.pipelines.opensora.opensora.acceleration.communications import split_forward_gather_backward +from videogen_hub.pipelines.opensora.opensora.acceleration.communications import all_to_all +from videogen_hub.pipelines.opensora.opensora.acceleration.parallel_states import get_sequence_parallel_group + +approx_gelu = lambda: nn.GELU(approximate="tanh") + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +def get_layernorm(hidden_size: torch.Tensor, eps: float, affine: bool, use_kernel: bool): + if use_kernel: + try: + from apex.normalization import FusedLayerNorm + + return FusedLayerNorm(hidden_size, elementwise_affine=affine, eps=eps) + except ImportError: + raise RuntimeError("FusedLayerNorm not available. Please install apex.") + else: + return nn.LayerNorm(hidden_size, eps, elementwise_affine=affine) + + +def modulate(norm_func, x, shift, scale): + # Suppose x is (B, N, D), shift is (B, D), scale is (B, D) + dtype = x.dtype + x = norm_func(x.to(torch.float32)).to(dtype) + x = x * (scale.unsqueeze(1) + 1) + shift.unsqueeze(1) + x = x.to(dtype) + return x + + +def t2i_modulate(x, shift, scale): + return x * (1 + scale) + shift + + +# =============================================== +# General-purpose Layers +# =============================================== + + +class PatchEmbed3D(nn.Module): + """Video to Patch Embedding. + + Args: + patch_size (int): Patch token size. Default: (2,4,4). + in_chans (int): Number of input video channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__( + self, + patch_size=(2, 4, 4), + in_chans=3, + embed_dim=96, + norm_layer=None, + flatten=True, + ): + super().__init__() + self.patch_size = patch_size + self.flatten = flatten + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, D, H, W = x.size() + if W % self.patch_size[2] != 0: + x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2])) + if H % self.patch_size[1] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1])) + if D % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0])) + + x = self.proj(x) # (B C T H W) + if self.norm is not None: + D, Wh, Ww = x.size(2), x.size(3), x.size(4) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC + return x + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = LlamaRMSNorm, + enable_flash_attn: bool = False, + rope=None, + qk_norm_legacy: bool = False, + ) -> None: + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + self.enable_flash_attn = enable_flash_attn + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.qk_norm_legacy = qk_norm_legacy + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.rope = False + if rope is not None: + self.rope = True + self.rotary_emb = rope + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, N, C = x.shape + # flash attn is not memory efficient for small sequences, this is empirical + enable_flash_attn = self.enable_flash_attn and (N > B) + qkv = self.qkv(x) + qkv_shape = (B, N, 3, self.num_heads, self.head_dim) + + qkv = qkv.view(qkv_shape).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + if self.qk_norm_legacy: + # WARNING: this may be a bug + if self.rope: + q = self.rotary_emb(q) + k = self.rotary_emb(k) + q, k = self.q_norm(q), self.k_norm(k) + else: + q, k = self.q_norm(q), self.k_norm(k) + if self.rope: + q = self.rotary_emb(q) + k = self.rotary_emb(k) + + if enable_flash_attn: + from flash_attn import flash_attn_func + + # (B, #heads, N, #dim) -> (B, N, #heads, #dim) + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + x = flash_attn_func( + q, + k, + v, + dropout_p=self.attn_drop.p if self.training else 0.0, + softmax_scale=self.scale, + ) + else: + dtype = q.dtype + q = q * self.scale + attn = q @ k.transpose(-2, -1) # translate attn to float32 + attn = attn.to(torch.float32) + attn = attn.softmax(dim=-1) + attn = attn.to(dtype) # cast back attn to original dtype + attn = self.attn_drop(attn) + x = attn @ v + + x_output_shape = (B, N, C) + if not enable_flash_attn: + x = x.transpose(1, 2) + x = x.reshape(x_output_shape) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class KVCompressAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = LlamaRMSNorm, + enable_flash_attn: bool = False, + sampling="conv", + sr_ratio=1, + mem_eff_attention=False, + attn_half=False, + ) -> None: + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + self.enable_flash_attn = enable_flash_attn + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + + self.sr_ratio = sr_ratio + self.sampling = sampling + if sr_ratio > 1 and sampling == "conv": + # Avg Conv Init. + self.sr = nn.Conv2d(dim, dim, groups=dim, kernel_size=sr_ratio, stride=sr_ratio) + self.sr.weight.data.fill_(1 / sr_ratio**2) + self.sr.bias.data.zero_() + self.norm = nn.LayerNorm(dim) + + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.mem_eff_attention = mem_eff_attention + self.attn_half = attn_half + + def downsample_2d(self, tensor, H, W, scale_factor, sampling=None): + if sampling is None or scale_factor == 1: + return tensor + B, N, C = tensor.shape + + if sampling == "uniform_every": + return tensor[:, ::scale_factor], int(N // scale_factor) + + tensor = tensor.reshape(B, H, W, C).permute(0, 3, 1, 2) + new_H, new_W = int(H / scale_factor), int(W / scale_factor) + new_N = new_H * new_W + + if sampling == "ave": + tensor = F.interpolate(tensor, scale_factor=1 / scale_factor, mode="nearest").permute(0, 2, 3, 1) + elif sampling == "uniform": + tensor = tensor[:, :, ::scale_factor, ::scale_factor].permute(0, 2, 3, 1) + elif sampling == "conv": + tensor = self.sr(tensor).reshape(B, C, -1).permute(0, 2, 1) + tensor = self.norm(tensor) + else: + raise ValueError + + return tensor.reshape(B, new_N, C).contiguous(), new_N + + def forward(self, x: torch.Tensor, mask=None, HW=None, block_id=None, **kwargs) -> torch.Tensor: + B, N, C = x.shape + new_N = N + H, W = HW + # flash attn is not memory efficient for small sequences, this is empirical + enable_flash_attn = self.enable_flash_attn and (N > B) + + qkv = self.qkv(x).reshape(B, N, 3, C) + q, k, v = qkv.unbind(2) + dtype = q.dtype + # KV compression + if self.sr_ratio > 1: + k, new_N = self.downsample_2d(k, H, W, self.sr_ratio, sampling=self.sampling) + v, new_N = self.downsample_2d(v, H, W, self.sr_ratio, sampling=self.sampling) + + q = q.reshape(B, N, self.num_heads, C // self.num_heads).to(dtype) + k = k.reshape(B, new_N, self.num_heads, C // self.num_heads).to(dtype) + v = v.reshape(B, new_N, self.num_heads, C // self.num_heads).to(dtype) + + q, k = self.q_norm(q), self.k_norm(k) + + if enable_flash_attn: + from flash_attn import flash_attn_func + + x = flash_attn_func( + q, + k, + v, + dropout_p=self.attn_drop.p if self.training else 0.0, + softmax_scale=self.scale, + ) + + elif self.mem_eff_attention: + attn_bias = None + if mask is not None: + attn_bias = torch.zeros([B * self.num_heads, q.shape[1], k.shape[1]], dtype=q.dtype, device=q.device) + attn_bias.masked_fill_(mask.squeeze(1).repeat(self.num_heads, 1, 1) == 0, float("-inf")) + x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias) + else: + # (B, N, #heads, #dim) -> (B, #heads, N, #dim) + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + dtype = q.dtype + q = q * self.scale + attn = q @ k.transpose(-2, -1) # translate attn to float32 + if not self.attn_half: + attn = attn.to(torch.float32) + attn = attn.softmax(dim=-1) + attn = attn.to(dtype) # cast back attn to original dtype + attn = self.attn_drop(attn) + x = attn @ v + + x_output_shape = (B, N, C) + if not enable_flash_attn: + x = x.transpose(1, 2) + x = x.reshape(x_output_shape) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SeqParallelAttention(Attention): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = LlamaRMSNorm, + enable_flash_attn: bool = False, + rope=None, + ) -> None: + assert rope is None, "Rope is not supported in SeqParallelAttention" + super().__init__( + dim=dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + attn_drop=attn_drop, + proj_drop=proj_drop, + norm_layer=norm_layer, + enable_flash_attn=enable_flash_attn, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, N, C = x.shape # for sequence parallel here, the N is a local sequence length + qkv = self.qkv(x) + qkv_shape = (B, N, 3, self.num_heads, self.head_dim) + qkv = qkv.view(qkv_shape) + + sp_group = get_sequence_parallel_group() + + # apply all_to_all to gather sequence and split attention heads + # [B, SUB_N, 3, NUM_HEAD, HEAD_DIM] -> [B, N, 3, NUM_HEAD_PER_DEVICE, HEAD_DIM] + qkv = all_to_all(qkv, sp_group, scatter_dim=3, gather_dim=1) + + if self.enable_flash_attn: + qkv_permute_shape = ( + 2, + 0, + 1, + 3, + 4, + ) # [3, B, N, NUM_HEAD_PER_DEVICE, HEAD_DIM] + else: + qkv_permute_shape = ( + 2, + 0, + 3, + 1, + 4, + ) # [3, B, NUM_HEAD_PER_DEVICE, N, HEAD_DIM] + qkv = qkv.permute(qkv_permute_shape) + + # ERROR: Should qk_norm first + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + if self.enable_flash_attn: + from flash_attn import flash_attn_func + + x = flash_attn_func( + q, + k, + v, + dropout_p=self.attn_drop.p if self.training else 0.0, + softmax_scale=self.scale, + ) + else: + dtype = q.dtype + q = q * self.scale + attn = q @ k.transpose(-2, -1) # translate attn to float32 + attn = attn.to(torch.float32) + attn = attn.softmax(dim=-1) + attn = attn.to(dtype) # cast back attn to original dtype + attn = self.attn_drop(attn) + x = attn @ v + + if not self.enable_flash_attn: + x = x.transpose(1, 2) + + # apply all to all to gather back attention heads and split sequence + # [B, N, NUM_HEAD_PER_DEVICE, HEAD_DIM] -> [B, SUB_N, NUM_HEAD, HEAD_DIM] + x = all_to_all(x, sp_group, scatter_dim=1, gather_dim=2) + + # reshape outputs back to [B, N, C] + x_output_shape = (B, N, C) + x = x.reshape(x_output_shape) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MultiHeadCrossAttention(nn.Module): + def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0): + super(MultiHeadCrossAttention, self).__init__() + assert d_model % num_heads == 0, "d_model must be divisible by num_heads" + + self.d_model = d_model + self.num_heads = num_heads + self.head_dim = d_model // num_heads + + self.q_linear = nn.Linear(d_model, d_model) + self.kv_linear = nn.Linear(d_model, d_model * 2) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(d_model, d_model) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, cond, mask=None): + # query/value: img tokens; key: condition; mask: if padding tokens + B, N, C = x.shape + + q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim) + kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim) + k, v = kv.unbind(2) + + attn_bias = None + if mask is not None: + attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask) + x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias) + + x = x.view(B, -1, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SeqParallelMultiHeadCrossAttention(MultiHeadCrossAttention): + def __init__( + self, + d_model, + num_heads, + attn_drop=0.0, + proj_drop=0.0, + ): + super().__init__( + d_model=d_model, + num_heads=num_heads, + attn_drop=attn_drop, + proj_drop=proj_drop, + ) + + def forward(self, x, cond, mask=None): + # query/value: img tokens; key: condition; mask: if padding tokens + sp_group = get_sequence_parallel_group() + sp_size = dist.get_world_size(sp_group) + B, SUB_N, C = x.shape # [B, TS/p, C] + N = SUB_N * sp_size + + # shape: + # q, k, v: [B, SUB_N, NUM_HEADS, HEAD_DIM] + q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim) + kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim) + kv = split_forward_gather_backward(kv, get_sequence_parallel_group(), dim=3, grad_scale="down") + k, v = kv.unbind(2) + + # apply all_to_all to gather sequence and split attention heads + q = all_to_all(q, sp_group, scatter_dim=2, gather_dim=1) + + q = q.view(1, -1, self.num_heads // sp_size, self.head_dim) + k = k.view(1, -1, self.num_heads // sp_size, self.head_dim) + v = v.view(1, -1, self.num_heads // sp_size, self.head_dim) + + # compute attention + attn_bias = None + if mask is not None: + attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask) + x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias) + + # apply all to all to gather back attention heads and scatter sequence + x = x.view(B, -1, self.num_heads // sp_size, self.head_dim) + x = all_to_all(x, sp_group, scatter_dim=1, gather_dim=2) + + # apply output projection + x = x.view(B, -1, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class FinalLayer(nn.Module): + """ + The final layer of DiT. + """ + + def __init__(self, hidden_size, num_patch, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final, x, shift, scale) + x = self.linear(x) + return x + + +class T2IFinalLayer(nn.Module): + """ + The final layer of PixArt. + """ + + def __init__(self, hidden_size, num_patch, out_channels, d_t=None, d_s=None): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True) + self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size**0.5) + self.out_channels = out_channels + self.d_t = d_t + self.d_s = d_s + + def t_mask_select(self, x_mask, x, masked_x, T, S): + # x: [B, (T, S), C] + # mased_x: [B, (T, S), C] + # x_mask: [B, T] + x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S) + masked_x = rearrange(masked_x, "B (T S) C -> B T S C", T=T, S=S) + x = torch.where(x_mask[:, :, None, None], x, masked_x) + x = rearrange(x, "B T S C -> B (T S) C") + return x + + def forward(self, x, t, x_mask=None, t0=None, T=None, S=None): + if T is None: + T = self.d_t + if S is None: + S = self.d_s + shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1) + x = t2i_modulate(self.norm_final(x), shift, scale) + if x_mask is not None: + shift_zero, scale_zero = (self.scale_shift_table[None] + t0[:, None]).chunk(2, dim=1) + x_zero = t2i_modulate(self.norm_final(x), shift_zero, scale_zero) + x = self.t_mask_select(x_mask, x, x_zero, T, S) + x = self.linear(x) + return x + + +# =============================================== +# Embedding Layers for Timesteps and Class Labels +# =============================================== + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half) + freqs = freqs.to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t, dtype): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + if t_freq.dtype != dtype: + t_freq = t_freq.to(dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class LabelEmbedder(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + """ + + def __init__(self, num_classes, hidden_size, dropout_prob): + super().__init__() + use_cfg_embedding = dropout_prob > 0 + self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) + self.num_classes = num_classes + self.dropout_prob = dropout_prob + + def token_drop(self, labels, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = torch.rand(labels.shape[0]).cuda() < self.dropout_prob + else: + drop_ids = force_drop_ids == 1 + labels = torch.where(drop_ids, self.num_classes, labels) + return labels + + def forward(self, labels, train, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + labels = self.token_drop(labels, force_drop_ids) + return self.embedding_table(labels) + + +class SizeEmbedder(TimestepEmbedder): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size) + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + self.outdim = hidden_size + + def forward(self, s, bs): + if s.ndim == 1: + s = s[:, None] + assert s.ndim == 2 + if s.shape[0] != bs: + s = s.repeat(bs // s.shape[0], 1) + assert s.shape[0] == bs + b, dims = s.shape[0], s.shape[1] + s = rearrange(s, "b d -> (b d)") + s_freq = self.timestep_embedding(s, self.frequency_embedding_size).to(self.dtype) + s_emb = self.mlp(s_freq) + s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim) + return s_emb + + @property + def dtype(self): + return next(self.parameters()).dtype + + +class CaptionEmbedder(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + """ + + def __init__( + self, + in_channels, + hidden_size, + uncond_prob, + act_layer=nn.GELU(approximate="tanh"), + token_num=120, + ): + super().__init__() + self.y_proj = Mlp( + in_features=in_channels, + hidden_features=hidden_size, + out_features=hidden_size, + act_layer=act_layer, + drop=0, + ) + self.register_buffer( + "y_embedding", + torch.randn(token_num, in_channels) / in_channels**0.5, + ) + self.uncond_prob = uncond_prob + + def token_drop(self, caption, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob + else: + drop_ids = force_drop_ids == 1 + caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption) + return caption + + def forward(self, caption, train, force_drop_ids=None): + if train: + assert caption.shape[2:] == self.y_embedding.shape + use_dropout = self.uncond_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + caption = self.token_drop(caption, force_drop_ids) + caption = self.y_proj(caption) + return caption + + +class PositionEmbedding2D(nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + self.dim = dim + assert dim % 4 == 0, "dim must be divisible by 4" + half_dim = dim // 2 + inv_freq = 1.0 / (10000 ** (torch.arange(0, half_dim, 2).float() / half_dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def _get_sin_cos_emb(self, t: torch.Tensor): + out = torch.einsum("i,d->id", t, self.inv_freq) + emb_cos = torch.cos(out) + emb_sin = torch.sin(out) + return torch.cat((emb_sin, emb_cos), dim=-1) + + @functools.lru_cache(maxsize=512) + def _get_cached_emb( + self, + device: torch.device, + dtype: torch.dtype, + h: int, + w: int, + scale: float = 1.0, + base_size: Optional[int] = None, + ): + grid_h = torch.arange(h, device=device) / scale + grid_w = torch.arange(w, device=device) / scale + if base_size is not None: + grid_h *= base_size / h + grid_w *= base_size / w + grid_h, grid_w = torch.meshgrid( + grid_w, + grid_h, + indexing="ij", + ) # here w goes first + grid_h = grid_h.t().reshape(-1) + grid_w = grid_w.t().reshape(-1) + emb_h = self._get_sin_cos_emb(grid_h) + emb_w = self._get_sin_cos_emb(grid_w) + return torch.concat([emb_h, emb_w], dim=-1).unsqueeze(0).to(dtype) + + def forward( + self, + x: torch.Tensor, + h: int, + w: int, + scale: Optional[float] = 1.0, + base_size: Optional[int] = None, + ) -> torch.Tensor: + return self._get_cached_emb(x.device, x.dtype, h, w, scale, base_size) + + +# =============================================== +# Sine/Cosine Positional Embedding Functions +# =============================================== +# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, scale=1.0, base_size=None): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + if not isinstance(grid_size, tuple): + grid_size = (grid_size, grid_size) + + grid_h = np.arange(grid_size[0], dtype=np.float32) / scale + grid_w = np.arange(grid_size[1], dtype=np.float32) / scale + if base_size is not None: + grid_h *= base_size / grid_size[0] + grid_w *= base_size / grid_size[1] + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed(embed_dim, length, scale=1.0): + pos = np.arange(0, length)[..., None] / scale + return get_1d_sincos_pos_embed_from_grid(embed_dim, pos) + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb diff --git a/src/videogen_hub/pipelines/opensora/opensora/models/pixart/__init__.py b/src/videogen_hub/pipelines/opensora/opensora/models/pixart/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9088f2edbdbf818a6810b412805fa0c2c8760466 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/models/pixart/__init__.py @@ -0,0 +1,2 @@ +from videogen_hub.pipelines.opensora.opensora.models.pixart.pixart import PixArt, PixArt_1B_2, PixArt_XL_2 +from videogen_hub.pipelines.opensora.opensora.models.pixart.pixart_sigma import PixArt_Sigma_XL_2 diff --git a/src/videogen_hub/pipelines/opensora/opensora/models/pixart/pixart.py b/src/videogen_hub/pipelines/opensora/opensora/models/pixart/pixart.py new file mode 100644 index 0000000000000000000000000000000000000000..3e844c9c1f52033e28fdacbf4f01ea62b0e695e4 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/models/pixart/pixart.py @@ -0,0 +1,403 @@ +# Adapted from PixArt +# +# Copyright (C) 2023 PixArt-alpha/PixArt-alpha +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# PixArt: https://github.com/PixArt-alpha/PixArt-alpha +# DiT: https://github.com/facebookresearch/DiT/tree/main +# -------------------------------------------------------- + +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange +from timm.models.layers import DropPath +from timm.models.vision_transformer import Mlp + +# from .builder import MODELS +from videogen_hub.pipelines.opensora.opensora.acceleration.checkpoint import auto_grad_checkpoint +from videogen_hub.pipelines.opensora.opensora.models.layers.blocks import ( + Attention, + CaptionEmbedder, + MultiHeadCrossAttention, + PatchEmbed3D, + SeqParallelAttention, + SeqParallelMultiHeadCrossAttention, + SizeEmbedder, + T2IFinalLayer, + TimestepEmbedder, + approx_gelu, + get_1d_sincos_pos_embed, + get_2d_sincos_pos_embed, + get_layernorm, + t2i_modulate, +) +from videogen_hub.pipelines.opensora.opensora.registry import MODELS +from videogen_hub.pipelines.opensora.opensora.utils.ckpt_utils import load_checkpoint + + +class PixArtBlock(nn.Module): + """ + A PixArt block with adaptive layer norm (adaLN-single) conditioning. + """ + + def __init__( + self, + hidden_size, + num_heads, + mlp_ratio=4.0, + drop_path=0.0, + enable_flash_attn=False, + enable_layernorm_kernel=False, + enable_sequence_parallelism=False, + ): + super().__init__() + self.hidden_size = hidden_size + self.enable_flash_attn = enable_flash_attn + self._enable_sequence_parallelism = enable_sequence_parallelism + + if enable_sequence_parallelism: + self.attn_cls = SeqParallelAttention + self.mha_cls = SeqParallelMultiHeadCrossAttention + else: + self.attn_cls = Attention + self.mha_cls = MultiHeadCrossAttention + + self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel) + self.attn = self.attn_cls( + hidden_size, + num_heads=num_heads, + qkv_bias=True, + enable_flash_attn=enable_flash_attn, + ) + self.cross_attn = self.mha_cls(hidden_size, num_heads) + self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel) + self.mlp = Mlp( + in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0 + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5) + + def forward(self, x, y, t, mask=None): + B, N, C = x.shape + + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + t.reshape(B, 6, -1) + ).chunk(6, dim=1) + x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C)) + x = x + self.cross_attn(x, y, mask) + x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))) + + return x + + +@MODELS.register_module() +class PixArt(nn.Module): + """ + Diffusion model with a Transformer backbone. + """ + + def __init__( + self, + input_size=(1, 32, 32), + in_channels=4, + patch_size=(1, 2, 2), + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + class_dropout_prob=0.1, + pred_sigma=True, + drop_path: float = 0.0, + no_temporal_pos_emb=False, + caption_channels=4096, + model_max_length=120, + dtype=torch.float32, + freeze=None, + space_scale=1.0, + time_scale=1.0, + enable_flash_attn=False, + enable_layernorm_kernel=False, + enable_sequence_parallelism=False, + base_size=None, + ): + super().__init__() + assert enable_sequence_parallelism is False, "Sequence parallelism is not supported in this version." + self.pred_sigma = pred_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if pred_sigma else in_channels + self.hidden_size = hidden_size + self.patch_size = patch_size + self.input_size = input_size + num_patches = np.prod([input_size[i] // patch_size[i] for i in range(3)]) + self.num_patches = num_patches + self.num_temporal = input_size[0] // patch_size[0] + self.num_spatial = num_patches // self.num_temporal + if base_size is None: + self.base_size = int(np.sqrt(self.num_spatial)) + else: + self.base_size = base_size // patch_size[1] + self.num_heads = num_heads + self.dtype = dtype + self.no_temporal_pos_emb = no_temporal_pos_emb + self.depth = depth + self.mlp_ratio = mlp_ratio + self.enable_flash_attn = enable_flash_attn + self.enable_layernorm_kernel = enable_layernorm_kernel + self.space_scale = space_scale + self.time_scale = time_scale + + self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size) + self.t_embedder = TimestepEmbedder(hidden_size) + self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)) + self.y_embedder = CaptionEmbedder( + in_channels=caption_channels, + hidden_size=hidden_size, + uncond_prob=class_dropout_prob, + act_layer=approx_gelu, + token_num=model_max_length, + ) + + self.register_buffer("pos_embed", self.get_spatial_pos_embed()) + self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed()) + + drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList( + [ + PixArtBlock( + hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + drop_path=drop_path[i], + enable_flash_attn=enable_flash_attn, + enable_layernorm_kernel=enable_layernorm_kernel, + ) + for i in range(depth) + ] + ) + self.final_layer = T2IFinalLayer(hidden_size, np.prod(self.patch_size), self.out_channels) + + self.initialize_weights() + if freeze is not None: + assert freeze in ["text"] + if freeze == "text": + self.freeze_text() + + def forward(self, x, timestep, y, mask=None): + """ + Forward pass of PixArt. + x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + y: (N, 1, 120, C) tensor of class labels + """ + x = x.to(self.dtype) + timestep = timestep.to(self.dtype) + y = y.to(self.dtype) + + # embedding + x = self.x_embedder(x) # (B, N, D) + x = rearrange(x, "b (t s) d -> b t s d", t=self.num_temporal, s=self.num_spatial) + x = x + self.pos_embed + if not self.no_temporal_pos_emb: + x = rearrange(x, "b t s d -> b s t d") + x = x + self.pos_embed_temporal + x = rearrange(x, "b s t d -> b (t s) d") + else: + x = rearrange(x, "b t s d -> b (t s) d") + + t = self.t_embedder(timestep, dtype=x.dtype) # (N, D) + t0 = self.t_block(t) + y = self.y_embedder(y, self.training) # (N, 1, L, D) + if mask is not None: + if mask.shape[0] != y.shape[0]: + mask = mask.repeat(y.shape[0] // mask.shape[0], 1) + mask = mask.squeeze(1).squeeze(1) + y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) + y_lens = mask.sum(dim=1).tolist() + else: + y_lens = [y.shape[2]] * y.shape[0] + y = y.squeeze(1).view(1, -1, x.shape[-1]) + + # blocks + for block in self.blocks: + x = auto_grad_checkpoint(block, x, y, t0, y_lens) + + # final process + x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels) + x = self.unpatchify(x) # (N, out_channels, H, W) + + # cast to float32 for better accuracy + x = x.to(torch.float32) + return x + + def unpatchify(self, x): + c = self.out_channels + t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)] + pt, ph, pw = self.patch_size + + x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c)) + x = rearrange(x, "n t h w r p q c -> n c t r h p w q") + imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw)) + return imgs + + def get_spatial_pos_embed(self, grid_size=None): + if grid_size is None: + grid_size = self.input_size[1:] + pos_embed = get_2d_sincos_pos_embed( + self.hidden_size, + (grid_size[0] // self.patch_size[1], grid_size[1] // self.patch_size[2]), + scale=self.space_scale, + base_size=self.base_size, + ) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False) + return pos_embed + + def get_temporal_pos_embed(self): + pos_embed = get_1d_sincos_pos_embed( + self.hidden_size, + self.input_size[0] // self.patch_size[0], + scale=self.time_scale, + ) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False) + return pos_embed + + def freeze_text(self): + for n, p in self.named_parameters(): + if "cross_attn" in n: + p.requires_grad = False + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + nn.init.normal_(self.t_block[1].weight, std=0.02) + + # Initialize caption embedding MLP: + nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02) + nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02) + + # Zero-out adaLN modulation layers in PixArt blocks: + for block in self.blocks: + nn.init.constant_(block.cross_attn.proj.weight, 0) + nn.init.constant_(block.cross_attn.proj.bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + +@MODELS.register_module() +class PixArtMS(PixArt): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + assert self.hidden_size % 3 == 0, "hidden_size must be divisible by 3" + self.csize_embedder = SizeEmbedder(self.hidden_size // 3) + self.ar_embedder = SizeEmbedder(self.hidden_size // 3) + + def forward(self, x, timestep, y, mask=None, data_info=None): + """ + Forward pass of PixArt. + x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + y: (N, 1, 120, C) tensor of class labels + """ + x = x.to(self.dtype) + timestep = timestep.to(self.dtype) + y = y.to(self.dtype) + + c_size = data_info["hw"] + ar = data_info["ar"] + pos_embed = self.get_spatial_pos_embed((x.shape[-2], x.shape[-1])).to(x.dtype) + + # embedding + x = self.x_embedder(x) # (B, N, D) + x = rearrange(x, "b (t s) d -> b t s d", t=self.num_temporal, s=self.num_spatial) + x = x + pos_embed.to(x.device) + if not self.no_temporal_pos_emb: + x = rearrange(x, "b t s d -> b s t d") + x = x + self.pos_embed_temporal + x = rearrange(x, "b s t d -> b (t s) d") + else: + x = rearrange(x, "b t s d -> b (t s) d") + + t = self.t_embedder(timestep, dtype=x.dtype) # (N, D) + B = x.shape[0] + csize = self.csize_embedder(c_size, B) + ar = self.ar_embedder(ar, B) + t = t + torch.cat([csize, ar], dim=1) + + t0 = self.t_block(t) + y = self.y_embedder(y, self.training) # (N, 1, L, D) + if mask is not None: + if mask.shape[0] != y.shape[0]: + mask = mask.repeat(y.shape[0] // mask.shape[0], 1) + mask = mask.squeeze(1).squeeze(1) + y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) + y_lens = mask.sum(dim=1).tolist() + else: + y_lens = [y.shape[2]] * y.shape[0] + y = y.squeeze(1).view(1, -1, x.shape[-1]) + + # blocks + for block in self.blocks: + x = block(x, y, t0, y_lens) + + # final process + x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels) + x = self.unpatchify(x) # (N, out_channels, H, W) + + # cast to float32 for better accuracy + x = x.to(torch.float32) + return x + + +@MODELS.register_module("PixArt-XL/2") +def PixArt_XL_2(from_pretrained=None, **kwargs): + model = PixArt(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs) + if from_pretrained is not None: + load_checkpoint(model, from_pretrained) + return model + + +@MODELS.register_module("PixArt-1B/2") +def PixArt_1B_2(from_pretrained=None, **kwargs): + model = PixArt(depth=28, hidden_size=1872, patch_size=(1, 2, 2), num_heads=26, **kwargs) + if from_pretrained is not None: + load_checkpoint(model, from_pretrained) + return model + + +@MODELS.register_module("PixArtMS-XL/2") +def PixArtMS_XL_2(from_pretrained=None, **kwargs): + model = PixArtMS(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs) + if from_pretrained is not None: + load_checkpoint(model, from_pretrained) + return model diff --git a/src/videogen_hub/pipelines/opensora/opensora/models/pixart/pixart_sigma.py b/src/videogen_hub/pipelines/opensora/opensora/models/pixart/pixart_sigma.py new file mode 100644 index 0000000000000000000000000000000000000000..ae53d1671bac4fc1ef9038161c1451492ecf6c97 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/models/pixart/pixart_sigma.py @@ -0,0 +1,342 @@ +# Adapted from PixArt +# +# Copyright (C) 2023 PixArt-alpha/PixArt-alpha +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# PixArt: https://github.com/PixArt-alpha/PixArt-alpha +# DiT: https://github.com/facebookresearch/DiT/tree/main +# -------------------------------------------------------- + +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange +from timm.models.layers import DropPath +from timm.models.vision_transformer import Mlp + +# from .builder import MODELS +from videogen_hub.pipelines.opensora.opensora.acceleration.checkpoint import auto_grad_checkpoint +from videogen_hub.pipelines.opensora.opensora.models.layers.blocks import ( + CaptionEmbedder, + KVCompressAttention, + MultiHeadCrossAttention, + PatchEmbed3D, + T2IFinalLayer, + TimestepEmbedder, + approx_gelu, + get_1d_sincos_pos_embed, + get_2d_sincos_pos_embed, + get_layernorm, + t2i_modulate, +) +from videogen_hub.pipelines.opensora.opensora.registry import MODELS +from videogen_hub.pipelines.opensora.opensora.utils.ckpt_utils import load_checkpoint + + +class PixArtBlock(nn.Module): + """ + A PixArt block with adaptive layer norm (adaLN-single) conditioning. + """ + + def __init__( + self, + hidden_size, + num_heads, + mlp_ratio=4.0, + drop_path=0.0, + enable_flash_attn=False, + enable_layernorm_kernel=False, + enable_sequence_parallelism=False, + qk_norm=False, + sampling="conv", + sr_ratio=1, + ): + super().__init__() + self.hidden_size = hidden_size + self.enable_flash_attn = enable_flash_attn + self._enable_sequence_parallelism = enable_sequence_parallelism + assert not enable_sequence_parallelism, "Sequence parallelism is not supported in this version." + + self.attn_cls = KVCompressAttention + self.mha_cls = MultiHeadCrossAttention + + self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel) + self.attn = self.attn_cls( + hidden_size, + num_heads=num_heads, + qkv_bias=True, + enable_flash_attn=enable_flash_attn, + qk_norm=qk_norm, + sr_ratio=sr_ratio, + sampling=sampling, + attn_half=True, + ) + self.cross_attn = self.mha_cls(hidden_size, num_heads) + self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel) + self.mlp = Mlp( + in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0 + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5) + self.sampling = sampling + self.sr_ratio = sr_ratio + + def forward(self, x, y, t, hw, mask=None): + B, N, C = x.shape + + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + t.reshape(B, 6, -1) + ).chunk(6, dim=1) + x = x + self.drop_path( + gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=hw).reshape(B, N, C) + ) + x = x + self.cross_attn(x, y, mask) + x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))) + + return x + + +@MODELS.register_module() +class PixArt_Sigma(nn.Module): + """ + Diffusion model with a Transformer backbone. + """ + + def __init__( + self, + input_size=(1, 32, 32), + in_channels=4, + patch_size=(1, 2, 2), + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + class_dropout_prob=0.1, + pred_sigma=True, + drop_path: float = 0.0, + no_temporal_pos_emb=False, + caption_channels=4096, + model_max_length=120, + dtype=torch.float32, + freeze=None, + qk_norm=False, + space_scale=1.0, + time_scale=1.0, + enable_flash_attn=False, + enable_layernorm_kernel=False, + enable_sequence_parallelism=False, + kv_compress_config=None, + ): + super().__init__() + assert enable_sequence_parallelism is False, "Sequence parallelism is not supported in this version." + self.pred_sigma = pred_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if pred_sigma else in_channels + self.hidden_size = hidden_size + self.patch_size = patch_size + self.input_size = input_size + num_patches = np.prod([input_size[i] // patch_size[i] for i in range(3)]) + self.num_patches = num_patches + self.num_temporal = input_size[0] // patch_size[0] + self.num_spatial = num_patches // self.num_temporal + self.base_size = int(np.sqrt(self.num_spatial)) + self.num_heads = num_heads + self.dtype = dtype + self.no_temporal_pos_emb = no_temporal_pos_emb + self.depth = depth + self.mlp_ratio = mlp_ratio + self.enable_flash_attn = enable_flash_attn + self.enable_layernorm_kernel = enable_layernorm_kernel + self.space_scale = space_scale + self.time_scale = time_scale + + self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size) + self.t_embedder = TimestepEmbedder(hidden_size) + self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)) + self.y_embedder = CaptionEmbedder( + in_channels=caption_channels, + hidden_size=hidden_size, + uncond_prob=class_dropout_prob, + act_layer=approx_gelu, + token_num=model_max_length, + ) + + self.register_buffer("pos_embed", self.get_spatial_pos_embed()) + self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed()) + + drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule + + self.kv_compress_config = kv_compress_config + if kv_compress_config is None: + self.kv_compress_config = { + "sampling": None, + "scale_factor": 1, + "kv_compress_layer": [], + } + + self.blocks = nn.ModuleList( + [ + PixArtBlock( + hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + drop_path=drop_path[i], + enable_flash_attn=enable_flash_attn, + enable_layernorm_kernel=enable_layernorm_kernel, + qk_norm=qk_norm, + sr_ratio=( + int(self.kv_compress_config["scale_factor"]) + if i in self.kv_compress_config["kv_compress_layer"] + else 1 + ), + sampling=self.kv_compress_config["sampling"], + ) + for i in range(depth) + ] + ) + self.final_layer = T2IFinalLayer(hidden_size, np.prod(self.patch_size), self.out_channels) + + self.initialize_weights() + if freeze is not None: + assert freeze in ["text"] + if freeze == "text": + self.freeze_text() + + def forward(self, x, timestep, y, mask=None): + """ + Forward pass of PixArt. + x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + y: (N, 1, 120, C) tensor of class labels + """ + x = x.to(self.dtype) + timestep = timestep.to(self.dtype) + y = y.to(self.dtype) + pos_embed = self.get_spatial_pos_embed((x.shape[-2], x.shape[-1])).to(x.dtype) + hw = (x.shape[-2] // self.patch_size[-2], x.shape[-1] // self.patch_size[-1]) + + # embedding + x = self.x_embedder(x) # (B, N, D) + x = rearrange(x, "b (t s) d -> b t s d", t=self.num_temporal, s=self.num_spatial) + x = x + pos_embed.to(x.device) + if not self.no_temporal_pos_emb: + x = rearrange(x, "b t s d -> b s t d") + x = x + self.pos_embed_temporal + x = rearrange(x, "b s t d -> b (t s) d") + else: + x = rearrange(x, "b t s d -> b (t s) d") + + t = self.t_embedder(timestep, dtype=x.dtype) # (N, D) + t0 = self.t_block(t) + y = self.y_embedder(y, self.training) # (N, 1, L, D) + if mask is not None: + if mask.shape[0] != y.shape[0]: + mask = mask.repeat(y.shape[0] // mask.shape[0], 1) + mask = mask.squeeze(1).squeeze(1) + y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) + y_lens = mask.sum(dim=1).tolist() + else: + y_lens = [y.shape[2]] * y.shape[0] + y = y.squeeze(1).view(1, -1, x.shape[-1]) + + # blocks + for block in self.blocks: + x = auto_grad_checkpoint(block, x, y, t0, hw, y_lens) + + # final process + x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels) + x = self.unpatchify(x) # (N, out_channels, H, W) + + # cast to float32 for better accuracy + x = x.to(torch.float32) + return x + + def unpatchify(self, x): + c = self.out_channels + t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)] + pt, ph, pw = self.patch_size + + x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c)) + x = rearrange(x, "n t h w r p q c -> n c t r h p w q") + imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw)) + return imgs + + def get_spatial_pos_embed(self, grid_size=None): + if grid_size is None: + grid_size = self.input_size[1:] + pos_embed = get_2d_sincos_pos_embed( + self.hidden_size, + (grid_size[0] // self.patch_size[1], grid_size[1] // self.patch_size[2]), + scale=self.space_scale, + base_size=self.base_size, + ) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False) + return pos_embed + + def get_temporal_pos_embed(self): + pos_embed = get_1d_sincos_pos_embed( + self.hidden_size, + self.input_size[0] // self.patch_size[0], + scale=self.time_scale, + ) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False) + return pos_embed + + def freeze_text(self): + for n, p in self.named_parameters(): + if "cross_attn" in n: + p.requires_grad = False + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + nn.init.normal_(self.t_block[1].weight, std=0.02) + + # Initialize caption embedding MLP: + nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02) + nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02) + + # Zero-out adaLN modulation layers in PixArt blocks: + for block in self.blocks: + nn.init.constant_(block.cross_attn.proj.weight, 0) + nn.init.constant_(block.cross_attn.proj.bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + +@MODELS.register_module("PixArt-Sigma-XL/2") +def PixArt_Sigma_XL_2(from_pretrained=None, **kwargs): + model = PixArt_Sigma(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs) + if from_pretrained is not None: + load_checkpoint(model, from_pretrained) + return model diff --git a/src/videogen_hub/pipelines/opensora/opensora/models/stdit/__init__.py b/src/videogen_hub/pipelines/opensora/opensora/models/stdit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6c503428e86857bb9b4448c3ac7edddd4e05761f --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/models/stdit/__init__.py @@ -0,0 +1,3 @@ +from videogen_hub.pipelines.opensora.opensora.models.stdit.stdit import STDiT +from videogen_hub.pipelines.opensora.opensora.models.stdit.stdit2 import STDiT2 +from videogen_hub.pipelines.opensora.opensora.models.stdit.stdit3 import STDiT3 diff --git a/src/videogen_hub/pipelines/opensora/opensora/models/stdit/stdit.py b/src/videogen_hub/pipelines/opensora/opensora/models/stdit/stdit.py new file mode 100644 index 0000000000000000000000000000000000000000..907249befa34a1de320b2bf48cacef09f5c85757 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/models/stdit/stdit.py @@ -0,0 +1,438 @@ +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +from einops import rearrange +from timm.models.layers import DropPath +from timm.models.vision_transformer import Mlp + +from videogen_hub.pipelines.opensora.opensora.acceleration.checkpoint import auto_grad_checkpoint +from videogen_hub.pipelines.opensora.opensora.acceleration.communications import gather_forward_split_backward, split_forward_gather_backward +from videogen_hub.pipelines.opensora.opensora.acceleration.parallel_states import get_sequence_parallel_group +from videogen_hub.pipelines.opensora.opensora.models.layers.blocks import ( + Attention, + CaptionEmbedder, + MultiHeadCrossAttention, + PatchEmbed3D, + SeqParallelAttention, + SeqParallelMultiHeadCrossAttention, + T2IFinalLayer, + TimestepEmbedder, + approx_gelu, + get_1d_sincos_pos_embed, + get_2d_sincos_pos_embed, + get_layernorm, + t2i_modulate, +) +from videogen_hub.pipelines.opensora.opensora.registry import MODELS +from videogen_hub.pipelines.opensora.opensora.utils.ckpt_utils import load_checkpoint + + +class STDiTBlock(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + d_s=None, + d_t=None, + mlp_ratio=4.0, + drop_path=0.0, + enable_flash_attn=False, + enable_layernorm_kernel=False, + enable_sequence_parallelism=False, + ): + super().__init__() + self.hidden_size = hidden_size + self.enable_flash_attn = enable_flash_attn + self._enable_sequence_parallelism = enable_sequence_parallelism + + if enable_sequence_parallelism: + self.attn_cls = SeqParallelAttention + self.mha_cls = SeqParallelMultiHeadCrossAttention + else: + self.attn_cls = Attention + self.mha_cls = MultiHeadCrossAttention + + self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel) + self.attn = self.attn_cls( + hidden_size, + num_heads=num_heads, + qkv_bias=True, + enable_flash_attn=enable_flash_attn, + ) + self.cross_attn = self.mha_cls(hidden_size, num_heads) + self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel) + self.mlp = Mlp( + in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0 + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5) + + # temporal attention + self.d_s = d_s + self.d_t = d_t + + if self._enable_sequence_parallelism: + sp_size = dist.get_world_size(get_sequence_parallel_group()) + # make sure d_t is divisible by sp_size + assert d_t % sp_size == 0 + self.d_t = d_t // sp_size + + self.attn_temp = self.attn_cls( + hidden_size, + num_heads=num_heads, + qkv_bias=True, + enable_flash_attn=self.enable_flash_attn, + ) + + def t_mask_select(self, x, masked_x, x_mask): + # x: [B, (T, S), C] + # mased_x: [B, (T, S), C] + # x_mask: [B, T] + x = rearrange(x, "B (T S) C -> B T S C", T=self.d_t, S=self.d_s) + masked_x = rearrange(masked_x, "B (T S) C -> B T S C", T=self.d_t, S=self.d_s) + x = torch.where(x_mask[:, :, None, None], x, masked_x) + x = rearrange(x, "B T S C -> B (T S) C") + return x + + def forward(self, x, y, t, mask=None, tpe=None, x_mask=None, t0=None): + B, N, C = x.shape + + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + t.reshape(B, 6, -1) + ).chunk(6, dim=1) + x_m = t2i_modulate(self.norm1(x), shift_msa, scale_msa) + if x_mask is not None: + shift_msa_zero, scale_msa_zero, gate_msa_zero, shift_mlp_zero, scale_mlp_zero, gate_mlp_zero = ( + self.scale_shift_table[None] + t0.reshape(B, 6, -1) + ).chunk(6, dim=1) + x_m_zero = t2i_modulate(self.norm1(x), shift_msa_zero, scale_msa_zero) + x_m = self.t_mask_select(x_m, x_m_zero, x_mask) + + # spatial branch + x_s = rearrange(x_m, "B (T S) C -> (B T) S C", T=self.d_t, S=self.d_s) + x_s = self.attn(x_s) + x_s = rearrange(x_s, "(B T) S C -> B (T S) C", T=self.d_t, S=self.d_s) + + if x_mask is not None: + x_s_zero = gate_msa_zero * x_s + x_s = gate_msa * x_s + x_s = self.t_mask_select(x_s, x_s_zero, x_mask) + else: + x_s = gate_msa * x_s + + x = x + self.drop_path(x_s) + + # temporal branch + x_t = rearrange(x, "B (T S) C -> (B S) T C", T=self.d_t, S=self.d_s) + if tpe is not None: + x_t = x_t + tpe + x_t = self.attn_temp(x_t) + x_t = rearrange(x_t, "(B S) T C -> B (T S) C", T=self.d_t, S=self.d_s) + x = x + self.drop_path(gate_msa * x_t) + + # cross attn + x = x + self.cross_attn(x, y, mask) + + # mlp + x_m = t2i_modulate(self.norm2(x), shift_mlp, scale_mlp) + if x_mask is not None: + x_m_zero = t2i_modulate(self.norm2(x), shift_mlp_zero, scale_mlp_zero) + x_m = self.t_mask_select(x_m, x_m_zero, x_mask) + + x_mlp = self.mlp(x_m) + if x_mask is not None: + x_mlp_zero = gate_mlp_zero * x_mlp + x_mlp = gate_mlp * x_mlp + x_mlp = self.t_mask_select(x_mlp, x_mlp_zero, x_mask) + else: + x_mlp = gate_mlp * x_mlp + + x = x + self.drop_path(x_mlp) + + return x + + +@MODELS.register_module() +class STDiT(nn.Module): + def __init__( + self, + input_size=(1, 32, 32), + in_channels=4, + patch_size=(1, 2, 2), + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + class_dropout_prob=0.1, + pred_sigma=True, + drop_path=0.0, + no_temporal_pos_emb=False, + caption_channels=4096, + model_max_length=120, + dtype=torch.float32, + space_scale=1.0, + time_scale=1.0, + freeze=None, + enable_flash_attn=False, + enable_layernorm_kernel=False, + enable_sequence_parallelism=False, + ): + super().__init__() + self.pred_sigma = pred_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if pred_sigma else in_channels + self.hidden_size = hidden_size + self.patch_size = patch_size + self.input_size = input_size + num_patches = np.prod([input_size[i] // patch_size[i] for i in range(3)]) + self.num_patches = num_patches + self.num_temporal = input_size[0] // patch_size[0] + self.num_spatial = num_patches // self.num_temporal + self.num_heads = num_heads + self.dtype = dtype + self.no_temporal_pos_emb = no_temporal_pos_emb + self.depth = depth + self.mlp_ratio = mlp_ratio + self.enable_flash_attn = enable_flash_attn + self.enable_layernorm_kernel = enable_layernorm_kernel + self.space_scale = space_scale + self.time_scale = time_scale + + self.register_buffer("pos_embed", self.get_spatial_pos_embed()) + self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed()) + + self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size) + self.t_embedder = TimestepEmbedder(hidden_size) + self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)) + self.y_embedder = CaptionEmbedder( + in_channels=caption_channels, + hidden_size=hidden_size, + uncond_prob=class_dropout_prob, + act_layer=approx_gelu, + token_num=model_max_length, + ) + + drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] + self.blocks = nn.ModuleList( + [ + STDiTBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=self.mlp_ratio, + drop_path=drop_path[i], + enable_flash_attn=self.enable_flash_attn, + enable_layernorm_kernel=self.enable_layernorm_kernel, + enable_sequence_parallelism=enable_sequence_parallelism, + d_t=self.num_temporal, + d_s=self.num_spatial, + ) + for i in range(self.depth) + ] + ) + self.final_layer = T2IFinalLayer( + hidden_size, + np.prod(self.patch_size), + self.out_channels, + d_t=self.num_temporal, + d_s=self.num_spatial, + ) + + # init model + self.initialize_weights() + self.initialize_temporal() + if freeze is not None: + assert freeze in ["not_temporal", "text"] + if freeze == "not_temporal": + self.freeze_not_temporal() + elif freeze == "text": + self.freeze_text() + + # sequence parallel related configs + self.enable_sequence_parallelism = enable_sequence_parallelism + if enable_sequence_parallelism: + self.sp_rank = dist.get_rank(get_sequence_parallel_group()) + else: + self.sp_rank = None + + def forward(self, x, timestep, y, mask=None, x_mask=None): + """ + Forward pass of STDiT. + Args: + x (torch.Tensor): latent representation of video; of shape [B, C, T, H, W] + timestep (torch.Tensor): diffusion time steps; of shape [B] + y (torch.Tensor): representation of prompts; of shape [B, 1, N_token, C] + mask (torch.Tensor): mask for selecting prompt tokens; of shape [B, N_token] + + Returns: + x (torch.Tensor): output latent representation; of shape [B, C, T, H, W] + """ + dtype = self.x_embedder.proj.weight.dtype + x = x.to(dtype) + timestep = timestep.to(dtype) + y = y.to(dtype) + + # embedding + x = self.x_embedder(x) # [B, N, C] + x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial) + x = x + self.pos_embed + x = rearrange(x, "B T S C -> B (T S) C") + + # shard over the sequence dim if sp is enabled + if self.enable_sequence_parallelism: + x = split_forward_gather_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="down") + + t = self.t_embedder(timestep, dtype=x.dtype) # [B, C] + t_mlp = self.t_block(t) # [B, C] + if x_mask is not None: + t0_timestep = torch.zeros_like(timestep) + t0 = self.t_embedder(t0_timestep, dtype=x.dtype) + t0_mlp = self.t_block(t0) + else: + t0 = None + t0_mlp = None + y = self.y_embedder(y, self.training) # [B, 1, N_token, C] + + if mask is not None: + if mask.shape[0] != y.shape[0]: + mask = mask.repeat(y.shape[0] // mask.shape[0], 1) + mask = mask.squeeze(1).squeeze(1) + y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) + y_lens = mask.sum(dim=1).tolist() + else: + y_lens = [y.shape[2]] * y.shape[0] + y = y.squeeze(1).view(1, -1, x.shape[-1]) + + # blocks + for i, block in enumerate(self.blocks): + if i == 0: + if self.enable_sequence_parallelism: + tpe = torch.chunk( + self.pos_embed_temporal, dist.get_world_size(get_sequence_parallel_group()), dim=1 + )[self.sp_rank].contiguous() + else: + tpe = self.pos_embed_temporal + else: + tpe = None + x = auto_grad_checkpoint(block, x, y, t_mlp, y_lens, tpe, x_mask, t0_mlp) + + if self.enable_sequence_parallelism: + x = gather_forward_split_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="up") + # x.shape: [B, N, C] + + # final process + x = self.final_layer(x, t, x_mask, t0) # [B, N, C=T_p * H_p * W_p * C_out] + x = self.unpatchify(x) # [B, C_out, T, H, W] + + # cast to float32 for better accuracy + x = x.to(torch.float32) + return x + + def unpatchify(self, x): + """ + Args: + x (torch.Tensor): of shape [B, N, C] + + Return: + x (torch.Tensor): of shape [B, C_out, T, H, W] + """ + + N_t, N_h, N_w = [self.input_size[i] // self.patch_size[i] for i in range(3)] + T_p, H_p, W_p = self.patch_size + x = rearrange( + x, + "B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)", + N_t=N_t, + N_h=N_h, + N_w=N_w, + T_p=T_p, + H_p=H_p, + W_p=W_p, + C_out=self.out_channels, + ) + return x + + def unpatchify_old(self, x): + c = self.out_channels + t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)] + pt, ph, pw = self.patch_size + + x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c)) + x = rearrange(x, "n t h w r p q c -> n c t r h p w q") + imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw)) + return imgs + + def get_spatial_pos_embed(self, grid_size=None): + if grid_size is None: + grid_size = self.input_size[1:] + pos_embed = get_2d_sincos_pos_embed( + self.hidden_size, + (grid_size[0] // self.patch_size[1], grid_size[1] // self.patch_size[2]), + scale=self.space_scale, + ) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False) + return pos_embed + + def get_temporal_pos_embed(self): + pos_embed = get_1d_sincos_pos_embed( + self.hidden_size, + self.input_size[0] // self.patch_size[0], + scale=self.time_scale, + ) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False) + return pos_embed + + def freeze_not_temporal(self): + for n, p in self.named_parameters(): + if "attn_temp" not in n: + p.requires_grad = False + + def freeze_text(self): + for n, p in self.named_parameters(): + if "cross_attn" in n: + p.requires_grad = False + + def initialize_temporal(self): + for block in self.blocks: + nn.init.constant_(block.attn_temp.proj.weight, 0) + nn.init.constant_(block.attn_temp.proj.bias, 0) + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + nn.init.normal_(self.t_block[1].weight, std=0.02) + + # Initialize caption embedding MLP: + nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02) + nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02) + + # Zero-out adaLN modulation layers in PixArt blocks: + for block in self.blocks: + nn.init.constant_(block.cross_attn.proj.weight, 0) + nn.init.constant_(block.cross_attn.proj.bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + +@MODELS.register_module("STDiT-XL/2") +def STDiT_XL_2(from_pretrained=None, **kwargs): + model = STDiT(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs) + if from_pretrained is not None: + load_checkpoint(model, from_pretrained) + return model diff --git a/src/videogen_hub/pipelines/opensora/opensora/models/stdit/stdit2.py b/src/videogen_hub/pipelines/opensora/opensora/models/stdit/stdit2.py new file mode 100644 index 0000000000000000000000000000000000000000..0a0b2a8f7fbf439addce2227b2cfb18aaaae4cdb --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/models/stdit/stdit2.py @@ -0,0 +1,524 @@ +import os + +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange +from rotary_embedding_torch import RotaryEmbedding +from timm.models.layers import DropPath +from timm.models.vision_transformer import Mlp +from transformers import PretrainedConfig, PreTrainedModel + +from videogen_hub.pipelines.opensora.opensora.acceleration.checkpoint import auto_grad_checkpoint +from videogen_hub.pipelines.opensora.opensora.models.layers.blocks import ( + Attention, + CaptionEmbedder, + MultiHeadCrossAttention, + PatchEmbed3D, + PositionEmbedding2D, + SizeEmbedder, + T2IFinalLayer, + TimestepEmbedder, + approx_gelu, + get_2d_sincos_pos_embed, + get_layernorm, + t2i_modulate, +) +from videogen_hub.pipelines.opensora.opensora.registry import MODELS +from videogen_hub.pipelines.opensora.opensora.utils.ckpt_utils import load_checkpoint + + +class STDiT2Block(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + mlp_ratio=4.0, + drop_path=0.0, + enable_flash_attn=False, + enable_layernorm_kernel=False, + enable_sequence_parallelism=False, + rope=None, + qk_norm=False, + qk_norm_legacy=False, + ): + super().__init__() + self.hidden_size = hidden_size + self.enable_flash_attn = enable_flash_attn + self._enable_sequence_parallelism = enable_sequence_parallelism + + # spatial branch + self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel) + self.attn = Attention( + hidden_size, + num_heads=num_heads, + qkv_bias=True, + enable_flash_attn=enable_flash_attn, + qk_norm=qk_norm, + qk_norm_legacy=qk_norm_legacy, + ) + self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5) + + # cross attn + self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads) + + # mlp branch + self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel) + self.mlp = Mlp( + in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0 + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + # temporal branch + self.norm_temp = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel) # new + self.attn_temp = Attention( + hidden_size, + num_heads=num_heads, + qkv_bias=True, + enable_flash_attn=self.enable_flash_attn, + rope=rope, + qk_norm=qk_norm, + qk_norm_legacy=qk_norm_legacy, + ) + self.scale_shift_table_temporal = nn.Parameter(torch.randn(3, hidden_size) / hidden_size**0.5) # new + + def t_mask_select(self, x_mask, x, masked_x, T, S): + # x: [B, (T, S), C] + # mased_x: [B, (T, S), C] + # x_mask: [B, T] + x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S) + masked_x = rearrange(masked_x, "B (T S) C -> B T S C", T=T, S=S) + x = torch.where(x_mask[:, :, None, None], x, masked_x) + x = rearrange(x, "B T S C -> B (T S) C") + return x + + def forward(self, x, y, t, t_tmp, mask=None, x_mask=None, t0=None, t0_tmp=None, T=None, S=None): + B, N, C = x.shape + + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + t.reshape(B, 6, -1) + ).chunk(6, dim=1) + shift_tmp, scale_tmp, gate_tmp = (self.scale_shift_table_temporal[None] + t_tmp.reshape(B, 3, -1)).chunk( + 3, dim=1 + ) + if x_mask is not None: + shift_msa_zero, scale_msa_zero, gate_msa_zero, shift_mlp_zero, scale_mlp_zero, gate_mlp_zero = ( + self.scale_shift_table[None] + t0.reshape(B, 6, -1) + ).chunk(6, dim=1) + shift_tmp_zero, scale_tmp_zero, gate_tmp_zero = ( + self.scale_shift_table_temporal[None] + t0_tmp.reshape(B, 3, -1) + ).chunk(3, dim=1) + + # modulate + x_m = t2i_modulate(self.norm1(x), shift_msa, scale_msa) + if x_mask is not None: + x_m_zero = t2i_modulate(self.norm1(x), shift_msa_zero, scale_msa_zero) + x_m = self.t_mask_select(x_mask, x_m, x_m_zero, T, S) + + # spatial branch + x_s = rearrange(x_m, "B (T S) C -> (B T) S C", T=T, S=S) + x_s = self.attn(x_s) + x_s = rearrange(x_s, "(B T) S C -> B (T S) C", T=T, S=S) + if x_mask is not None: + x_s_zero = gate_msa_zero * x_s + x_s = gate_msa * x_s + x_s = self.t_mask_select(x_mask, x_s, x_s_zero, T, S) + else: + x_s = gate_msa * x_s + x = x + self.drop_path(x_s) + + # modulate + x_m = t2i_modulate(self.norm_temp(x), shift_tmp, scale_tmp) + if x_mask is not None: + x_m_zero = t2i_modulate(self.norm_temp(x), shift_tmp_zero, scale_tmp_zero) + x_m = self.t_mask_select(x_mask, x_m, x_m_zero, T, S) + + # temporal branch + x_t = rearrange(x_m, "B (T S) C -> (B S) T C", T=T, S=S) + x_t = self.attn_temp(x_t) + x_t = rearrange(x_t, "(B S) T C -> B (T S) C", T=T, S=S) + if x_mask is not None: + x_t_zero = gate_tmp_zero * x_t + x_t = gate_tmp * x_t + x_t = self.t_mask_select(x_mask, x_t, x_t_zero, T, S) + else: + x_t = gate_tmp * x_t + x = x + self.drop_path(x_t) + + # cross attn + x = x + self.cross_attn(x, y, mask) + + # modulate + x_m = t2i_modulate(self.norm2(x), shift_mlp, scale_mlp) + if x_mask is not None: + x_m_zero = t2i_modulate(self.norm2(x), shift_mlp_zero, scale_mlp_zero) + x_m = self.t_mask_select(x_mask, x_m, x_m_zero, T, S) + + # mlp + x_mlp = self.mlp(x_m) + if x_mask is not None: + x_mlp_zero = gate_mlp_zero * x_mlp + x_mlp = gate_mlp * x_mlp + x_mlp = self.t_mask_select(x_mask, x_mlp, x_mlp_zero, T, S) + else: + x_mlp = gate_mlp * x_mlp + x = x + self.drop_path(x_mlp) + + return x + + +class STDiT2Config(PretrainedConfig): + model_type = "STDiT2" + + def __init__( + self, + input_size=(None, None, None), + input_sq_size=32, + in_channels=4, + patch_size=(1, 2, 2), + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + class_dropout_prob=0.1, + pred_sigma=True, + drop_path=0.0, + no_temporal_pos_emb=False, + caption_channels=4096, + model_max_length=120, + freeze=None, + qk_norm=False, + qk_norm_legacy=False, + enable_flash_attn=False, + enable_layernorm_kernel=False, + **kwargs, + ): + self.input_size = input_size + self.input_sq_size = input_sq_size + self.in_channels = in_channels + self.patch_size = patch_size + self.hidden_size = hidden_size + self.depth = depth + self.num_heads = num_heads + self.mlp_ratio = mlp_ratio + self.class_dropout_prob = class_dropout_prob + self.pred_sigma = pred_sigma + self.drop_path = drop_path + self.no_temporal_pos_emb = no_temporal_pos_emb + self.caption_channels = caption_channels + self.model_max_length = model_max_length + self.freeze = freeze + self.qk_norm = qk_norm + self.qk_norm_legacy = qk_norm_legacy + self.enable_flash_attn = enable_flash_attn + self.enable_layernorm_kernel = enable_layernorm_kernel + super().__init__(**kwargs) + + +@MODELS.register_module() +class STDiT2(PreTrainedModel): + config_class = STDiT2Config + + def __init__(self, config): + super().__init__(config) + self.pred_sigma = config.pred_sigma + self.in_channels = config.in_channels + self.out_channels = config.in_channels * 2 if config.pred_sigma else config.in_channels + self.hidden_size = config.hidden_size + self.num_heads = config.num_heads + self.no_temporal_pos_emb = config.no_temporal_pos_emb + self.depth = config.depth + self.mlp_ratio = config.mlp_ratio + self.enable_flash_attn = config.enable_flash_attn + self.enable_layernorm_kernel = config.enable_layernorm_kernel + + # support dynamic input + self.patch_size = config.patch_size + self.input_size = config.input_size + self.input_sq_size = config.input_sq_size + self.pos_embed = PositionEmbedding2D(config.hidden_size) + + self.x_embedder = PatchEmbed3D(config.patch_size, config.in_channels, config.hidden_size) + self.t_embedder = TimestepEmbedder(config.hidden_size) + self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(config.hidden_size, 6 * config.hidden_size, bias=True)) + self.t_block_temp = nn.Sequential( + nn.SiLU(), nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=True) + ) # new + self.y_embedder = CaptionEmbedder( + in_channels=config.caption_channels, + hidden_size=config.hidden_size, + uncond_prob=config.class_dropout_prob, + act_layer=approx_gelu, + token_num=config.model_max_length, + ) + + drop_path = [x.item() for x in torch.linspace(0, config.drop_path, config.depth)] + self.rope = RotaryEmbedding(dim=self.hidden_size // self.num_heads) # new + self.blocks = nn.ModuleList( + [ + STDiT2Block( + self.hidden_size, + self.num_heads, + mlp_ratio=self.mlp_ratio, + drop_path=drop_path[i], + enable_flash_attn=self.enable_flash_attn, + enable_layernorm_kernel=self.enable_layernorm_kernel, + rope=self.rope.rotate_queries_or_keys, + qk_norm=config.qk_norm, + qk_norm_legacy=config.qk_norm_legacy, + ) + for i in range(self.depth) + ] + ) + self.final_layer = T2IFinalLayer(config.hidden_size, np.prod(self.patch_size), self.out_channels) + + # multi_res + assert self.hidden_size % 3 == 0, "hidden_size must be divisible by 3" + self.csize_embedder = SizeEmbedder(self.hidden_size // 3) + self.ar_embedder = SizeEmbedder(self.hidden_size // 3) + self.fl_embedder = SizeEmbedder(self.hidden_size) # new + self.fps_embedder = SizeEmbedder(self.hidden_size) # new + + # init model + self.initialize_weights() + self.initialize_temporal() + if config.freeze is not None: + assert config.freeze in ["not_temporal", "text"] + if config.freeze == "not_temporal": + self.freeze_not_temporal() + elif config.freeze == "text": + self.freeze_text() + + def get_dynamic_size(self, x): + _, _, T, H, W = x.size() + if T % self.patch_size[0] != 0: + T += self.patch_size[0] - T % self.patch_size[0] + if H % self.patch_size[1] != 0: + H += self.patch_size[1] - H % self.patch_size[1] + if W % self.patch_size[2] != 0: + W += self.patch_size[2] - W % self.patch_size[2] + T = T // self.patch_size[0] + H = H // self.patch_size[1] + W = W // self.patch_size[2] + return (T, H, W) + + def forward( + self, x, timestep, y, mask=None, x_mask=None, num_frames=None, height=None, width=None, ar=None, fps=None + ): + """ + Forward pass of STDiT. + Args: + x (torch.Tensor): latent representation of video; of shape [B, C, T, H, W] + timestep (torch.Tensor): diffusion time steps; of shape [B] + y (torch.Tensor): representation of prompts; of shape [B, 1, N_token, C] + mask (torch.Tensor): mask for selecting prompt tokens; of shape [B, N_token] + + Returns: + x (torch.Tensor): output latent representation; of shape [B, C, T, H, W] + """ + B = x.shape[0] + dtype = self.x_embedder.proj.weight.dtype + x = x.to(dtype) + timestep = timestep.to(dtype) + y = y.to(dtype) + + # === process data info === + # 1. get dynamic size + hw = torch.cat([height[:, None], width[:, None]], dim=1) + rs = (height[0].item() * width[0].item()) ** 0.5 + csize = self.csize_embedder(hw, B) + + # 2. get aspect ratio + ar = ar.unsqueeze(1) + ar = self.ar_embedder(ar, B) + data_info = torch.cat([csize, ar], dim=1) + + # 3. get number of frames + fl = num_frames.unsqueeze(1) + fps = fps.unsqueeze(1) + fl = self.fl_embedder(fl, B) + fl = fl + self.fps_embedder(fps, B) + + # === get dynamic shape size === + _, _, Tx, Hx, Wx = x.size() + T, H, W = self.get_dynamic_size(x) + S = H * W + scale = rs / self.input_sq_size + base_size = round(S**0.5) + pos_emb = self.pos_embed(x, H, W, scale=scale, base_size=base_size) + + # embedding + x = self.x_embedder(x) # [B, N, C] + x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S) + x = x + pos_emb + x = rearrange(x, "B T S C -> B (T S) C") + + # prepare adaIN + t = self.t_embedder(timestep, dtype=x.dtype) # [B, C] + t_spc = t + data_info # [B, C] + t_tmp = t + fl # [B, C] + t_spc_mlp = self.t_block(t_spc) # [B, 6*C] + t_tmp_mlp = self.t_block_temp(t_tmp) # [B, 3*C] + if x_mask is not None: + t0_timestep = torch.zeros_like(timestep) + t0 = self.t_embedder(t0_timestep, dtype=x.dtype) + t0_spc = t0 + data_info + t0_tmp = t0 + fl + t0_spc_mlp = self.t_block(t0_spc) + t0_tmp_mlp = self.t_block_temp(t0_tmp) + else: + t0_spc = None + t0_tmp = None + t0_spc_mlp = None + t0_tmp_mlp = None + + # prepare y + y = self.y_embedder(y, self.training) # [B, 1, N_token, C] + + if mask is not None: + if mask.shape[0] != y.shape[0]: + mask = mask.repeat(y.shape[0] // mask.shape[0], 1) + mask = mask.squeeze(1).squeeze(1) + y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) + y_lens = mask.sum(dim=1).tolist() + else: + y_lens = [y.shape[2]] * y.shape[0] + y = y.squeeze(1).view(1, -1, x.shape[-1]) + + # blocks + for _, block in enumerate(self.blocks): + x = auto_grad_checkpoint( + block, + x, + y, + t_spc_mlp, + t_tmp_mlp, + y_lens, + x_mask, + t0_spc_mlp, + t0_tmp_mlp, + T, + S, + ) + # x.shape: [B, N, C] + + # final process + x = self.final_layer(x, t, x_mask, t0_spc, T, S) # [B, N, C=T_p * H_p * W_p * C_out] + x = self.unpatchify(x, T, H, W, Tx, Hx, Wx) # [B, C_out, T, H, W] + + # cast to float32 for better accuracy + x = x.to(torch.float32) + return x + + def unpatchify(self, x, N_t, N_h, N_w, R_t, R_h, R_w): + """ + Args: + x (torch.Tensor): of shape [B, N, C] + + Return: + x (torch.Tensor): of shape [B, C_out, T, H, W] + """ + + # N_t, N_h, N_w = [self.input_size[i] // self.patch_size[i] for i in range(3)] + T_p, H_p, W_p = self.patch_size + x = rearrange( + x, + "B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)", + N_t=N_t, + N_h=N_h, + N_w=N_w, + T_p=T_p, + H_p=H_p, + W_p=W_p, + C_out=self.out_channels, + ) + # unpad + x = x[:, :, :R_t, :R_h, :R_w] + return x + + def unpatchify_old(self, x): + c = self.out_channels + t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)] + pt, ph, pw = self.patch_size + + x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c)) + x = rearrange(x, "n t h w r p q c -> n c t r h p w q") + imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw)) + return imgs + + def get_spatial_pos_embed(self, H, W, scale=1.0, base_size=None): + pos_embed = get_2d_sincos_pos_embed( + self.hidden_size, + (H, W), + scale=scale, + base_size=base_size, + ) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False) + return pos_embed + + def freeze_not_temporal(self): + for n, p in self.named_parameters(): + if "attn_temp" not in n: + p.requires_grad = False + + def freeze_text(self): + for n, p in self.named_parameters(): + if "cross_attn" in n: + p.requires_grad = False + + def initialize_temporal(self): + for block in self.blocks: + nn.init.constant_(block.attn_temp.proj.weight, 0) + nn.init.constant_(block.attn_temp.proj.bias, 0) + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + nn.init.normal_(self.t_block[1].weight, std=0.02) + nn.init.normal_(self.t_block_temp[1].weight, std=0.02) + + # Initialize caption embedding MLP: + nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02) + nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02) + + # Zero-out adaLN modulation layers in PixArt blocks: + for block in self.blocks: + nn.init.constant_(block.cross_attn.proj.weight, 0) + nn.init.constant_(block.cross_attn.proj.bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + +@MODELS.register_module("STDiT2-XL/2") +def STDiT2_XL_2(from_pretrained=None, **kwargs): + if from_pretrained is not None: + if os.path.isdir(from_pretrained) or os.path.isfile(from_pretrained): + # if it is a directory or a file, we load the checkpoint manually + config = STDiT2Config(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs) + model = STDiT2(config) + load_checkpoint(model, from_pretrained) + return model + else: + # otherwise, we load the model from hugging face hub + return STDiT2.from_pretrained(from_pretrained) + else: + # create a new model + config = STDiT2Config(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs) + model = STDiT2(config) + return model diff --git a/src/videogen_hub/pipelines/opensora/opensora/models/stdit/stdit3.py b/src/videogen_hub/pipelines/opensora/opensora/models/stdit/stdit3.py new file mode 100644 index 0000000000000000000000000000000000000000..6824714ba7308f54e39cf1617f42c8100df42f5a --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/models/stdit/stdit3.py @@ -0,0 +1,471 @@ +import os + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +from einops import rearrange +from rotary_embedding_torch import RotaryEmbedding +from timm.models.layers import DropPath +from timm.models.vision_transformer import Mlp +from transformers import PretrainedConfig, PreTrainedModel + +from videogen_hub.pipelines.opensora.opensora.acceleration.checkpoint import auto_grad_checkpoint +from videogen_hub.pipelines.opensora.opensora.acceleration.communications import gather_forward_split_backward, split_forward_gather_backward +from videogen_hub.pipelines.opensora.opensora.acceleration.parallel_states import get_sequence_parallel_group +from videogen_hub.pipelines.opensora.opensora.models.layers.blocks import ( + Attention, + CaptionEmbedder, + MultiHeadCrossAttention, + PatchEmbed3D, + PositionEmbedding2D, + SeqParallelAttention, + SeqParallelMultiHeadCrossAttention, + SizeEmbedder, + T2IFinalLayer, + TimestepEmbedder, + approx_gelu, + get_layernorm, + t2i_modulate, +) +from videogen_hub.pipelines.opensora.opensora.registry import MODELS +from videogen_hub.pipelines.opensora.opensora.utils.ckpt_utils import load_checkpoint + + +class STDiT3Block(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + mlp_ratio=4.0, + drop_path=0.0, + rope=None, + qk_norm=False, + temporal=False, + enable_flash_attn=False, + enable_layernorm_kernel=False, + enable_sequence_parallelism=False, + ): + super().__init__() + self.temporal = temporal + self.hidden_size = hidden_size + self.enable_flash_attn = enable_flash_attn + self.enable_sequence_parallelism = enable_sequence_parallelism + + if self.enable_sequence_parallelism and not temporal: + attn_cls = SeqParallelAttention + mha_cls = SeqParallelMultiHeadCrossAttention + else: + attn_cls = Attention + mha_cls = MultiHeadCrossAttention + + self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel) + self.attn = attn_cls( + hidden_size, + num_heads=num_heads, + qkv_bias=True, + qk_norm=qk_norm, + rope=rope, + enable_flash_attn=enable_flash_attn, + ) + self.cross_attn = mha_cls(hidden_size, num_heads) + self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel) + self.mlp = Mlp( + in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0 + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5) + + def t_mask_select(self, x_mask, x, masked_x, T, S): + # x: [B, (T, S), C] + # mased_x: [B, (T, S), C] + # x_mask: [B, T] + x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S) + masked_x = rearrange(masked_x, "B (T S) C -> B T S C", T=T, S=S) + x = torch.where(x_mask[:, :, None, None], x, masked_x) + x = rearrange(x, "B T S C -> B (T S) C") + return x + + def forward( + self, + x, + y, + t, + mask=None, # text mask + x_mask=None, # temporal mask + t0=None, # t with timestamp=0 + T=None, # number of frames + S=None, # number of pixel patches + ): + # prepare modulate parameters + B, N, C = x.shape + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + t.reshape(B, 6, -1) + ).chunk(6, dim=1) + if x_mask is not None: + shift_msa_zero, scale_msa_zero, gate_msa_zero, shift_mlp_zero, scale_mlp_zero, gate_mlp_zero = ( + self.scale_shift_table[None] + t0.reshape(B, 6, -1) + ).chunk(6, dim=1) + + # modulate (attention) + x_m = t2i_modulate(self.norm1(x), shift_msa, scale_msa) + if x_mask is not None: + x_m_zero = t2i_modulate(self.norm1(x), shift_msa_zero, scale_msa_zero) + x_m = self.t_mask_select(x_mask, x_m, x_m_zero, T, S) + + # attention + if self.temporal: + x_m = rearrange(x_m, "B (T S) C -> (B S) T C", T=T, S=S) + x_m = self.attn(x_m) + x_m = rearrange(x_m, "(B S) T C -> B (T S) C", T=T, S=S) + else: + x_m = rearrange(x_m, "B (T S) C -> (B T) S C", T=T, S=S) + x_m = self.attn(x_m) + x_m = rearrange(x_m, "(B T) S C -> B (T S) C", T=T, S=S) + + # modulate (attention) + x_m_s = gate_msa * x_m + if x_mask is not None: + x_m_s_zero = gate_msa_zero * x_m + x_m_s = self.t_mask_select(x_mask, x_m_s, x_m_s_zero, T, S) + + # residual + x = x + self.drop_path(x_m_s) + + # cross attention + x = x + self.cross_attn(x, y, mask) + + # modulate (MLP) + x_m = t2i_modulate(self.norm2(x), shift_mlp, scale_mlp) + if x_mask is not None: + x_m_zero = t2i_modulate(self.norm2(x), shift_mlp_zero, scale_mlp_zero) + x_m = self.t_mask_select(x_mask, x_m, x_m_zero, T, S) + + # MLP + x_m = self.mlp(x_m) + + # modulate (MLP) + x_m_s = gate_mlp * x_m + if x_mask is not None: + x_m_s_zero = gate_mlp_zero * x_m + x_m_s = self.t_mask_select(x_mask, x_m_s, x_m_s_zero, T, S) + + # residual + x = x + self.drop_path(x_m_s) + + return x + + +class STDiT3Config(PretrainedConfig): + model_type = "STDiT3" + + def __init__( + self, + input_size=(None, None, None), + input_sq_size=512, + in_channels=4, + patch_size=(1, 2, 2), + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + class_dropout_prob=0.1, + pred_sigma=True, + drop_path=0.0, + caption_channels=4096, + model_max_length=300, + qk_norm=True, + enable_flash_attn=False, + enable_layernorm_kernel=False, + enable_sequence_parallelism=False, + only_train_temporal=False, + freeze_y_embedder=False, + skip_y_embedder=False, + **kwargs, + ): + self.input_size = input_size + self.input_sq_size = input_sq_size + self.in_channels = in_channels + self.patch_size = patch_size + self.hidden_size = hidden_size + self.depth = depth + self.num_heads = num_heads + self.mlp_ratio = mlp_ratio + self.class_dropout_prob = class_dropout_prob + self.pred_sigma = pred_sigma + self.drop_path = drop_path + self.caption_channels = caption_channels + self.model_max_length = model_max_length + self.qk_norm = qk_norm + self.enable_flash_attn = enable_flash_attn + self.enable_layernorm_kernel = enable_layernorm_kernel + self.enable_sequence_parallelism = enable_sequence_parallelism + self.only_train_temporal = only_train_temporal + self.freeze_y_embedder = freeze_y_embedder + self.skip_y_embedder = skip_y_embedder + super().__init__(**kwargs) + + +class STDiT3(PreTrainedModel): + config_class = STDiT3Config + + def __init__(self, config): + super().__init__(config) + self.pred_sigma = config.pred_sigma + self.in_channels = config.in_channels + self.out_channels = config.in_channels * 2 if config.pred_sigma else config.in_channels + + # model size related + self.depth = config.depth + self.mlp_ratio = config.mlp_ratio + self.hidden_size = config.hidden_size + self.num_heads = config.num_heads + + # computation related + self.drop_path = config.drop_path + self.enable_flash_attn = config.enable_flash_attn + self.enable_layernorm_kernel = config.enable_layernorm_kernel + self.enable_sequence_parallelism = config.enable_sequence_parallelism + + # input size related + self.patch_size = config.patch_size + self.input_sq_size = config.input_sq_size + self.pos_embed = PositionEmbedding2D(config.hidden_size) + self.rope = RotaryEmbedding(dim=self.hidden_size // self.num_heads) + + # embedding + self.x_embedder = PatchEmbed3D(config.patch_size, config.in_channels, config.hidden_size) + self.t_embedder = TimestepEmbedder(config.hidden_size) + self.fps_embedder = SizeEmbedder(self.hidden_size) + self.t_block = nn.Sequential( + nn.SiLU(), + nn.Linear(config.hidden_size, 6 * config.hidden_size, bias=True), + ) + self.y_embedder = CaptionEmbedder( + in_channels=config.caption_channels, + hidden_size=config.hidden_size, + uncond_prob=config.class_dropout_prob, + act_layer=approx_gelu, + token_num=config.model_max_length, + ) + + # spatial blocks + drop_path = [x.item() for x in torch.linspace(0, self.drop_path, config.depth)] + self.spatial_blocks = nn.ModuleList( + [ + STDiT3Block( + hidden_size=config.hidden_size, + num_heads=config.num_heads, + mlp_ratio=config.mlp_ratio, + drop_path=drop_path[i], + qk_norm=config.qk_norm, + enable_flash_attn=config.enable_flash_attn, + enable_layernorm_kernel=config.enable_layernorm_kernel, + enable_sequence_parallelism=config.enable_sequence_parallelism, + ) + for i in range(config.depth) + ] + ) + + # temporal blocks + drop_path = [x.item() for x in torch.linspace(0, self.drop_path, config.depth)] + self.temporal_blocks = nn.ModuleList( + [ + STDiT3Block( + hidden_size=config.hidden_size, + num_heads=config.num_heads, + mlp_ratio=config.mlp_ratio, + drop_path=drop_path[i], + qk_norm=config.qk_norm, + enable_flash_attn=config.enable_flash_attn, + enable_layernorm_kernel=config.enable_layernorm_kernel, + enable_sequence_parallelism=config.enable_sequence_parallelism, + # temporal + temporal=True, + rope=self.rope.rotate_queries_or_keys, + ) + for i in range(config.depth) + ] + ) + + # final layer + self.final_layer = T2IFinalLayer(config.hidden_size, np.prod(self.patch_size), self.out_channels) + + self.initialize_weights() + if config.only_train_temporal: + for param in self.parameters(): + param.requires_grad = False + for block in self.temporal_blocks: + for param in block.parameters(): + param.requires_grad = True + + if config.freeze_y_embedder: + for param in self.y_embedder.parameters(): + param.requires_grad = False + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize fps_embedder + nn.init.normal_(self.fps_embedder.mlp[0].weight, std=0.02) + nn.init.constant_(self.fps_embedder.mlp[0].bias, 0) + nn.init.constant_(self.fps_embedder.mlp[2].weight, 0) + nn.init.constant_(self.fps_embedder.mlp[2].bias, 0) + + # Initialize timporal blocks + for block in self.temporal_blocks: + nn.init.constant_(block.attn.proj.weight, 0) + nn.init.constant_(block.cross_attn.proj.weight, 0) + nn.init.constant_(block.mlp.fc2.weight, 0) + + def get_dynamic_size(self, x): + _, _, T, H, W = x.size() + if T % self.patch_size[0] != 0: + T += self.patch_size[0] - T % self.patch_size[0] + if H % self.patch_size[1] != 0: + H += self.patch_size[1] - H % self.patch_size[1] + if W % self.patch_size[2] != 0: + W += self.patch_size[2] - W % self.patch_size[2] + T = T // self.patch_size[0] + H = H // self.patch_size[1] + W = W // self.patch_size[2] + return (T, H, W) + + def encode_text(self, y, mask=None): + y = self.y_embedder(y, self.training) # [B, 1, N_token, C] + if mask is not None: + if mask.shape[0] != y.shape[0]: + mask = mask.repeat(y.shape[0] // mask.shape[0], 1) + mask = mask.squeeze(1).squeeze(1) + y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, self.hidden_size) + y_lens = mask.sum(dim=1).tolist() + else: + y_lens = [y.shape[2]] * y.shape[0] + y = y.squeeze(1).view(1, -1, self.hidden_size) + return y, y_lens + + def forward(self, x, timestep, y, mask=None, x_mask=None, fps=None, height=None, width=None, **kwargs): + dtype = self.x_embedder.proj.weight.dtype + B = x.size(0) + x = x.to(dtype) + timestep = timestep.to(dtype) + y = y.to(dtype) + + # === get pos embed === + _, _, Tx, Hx, Wx = x.size() + T, H, W = self.get_dynamic_size(x) + S = H * W + base_size = round(S**0.5) + resolution_sq = (height[0].item() * width[0].item()) ** 0.5 + scale = resolution_sq / self.input_sq_size + pos_emb = self.pos_embed(x, H, W, scale=scale, base_size=base_size) + + # === get timestep embed === + t = self.t_embedder(timestep, dtype=x.dtype) # [B, C] + fps = self.fps_embedder(fps.unsqueeze(1), B) + t = t + fps + t_mlp = self.t_block(t) + t0 = t0_mlp = None + if x_mask is not None: + t0_timestep = torch.zeros_like(timestep) + t0 = self.t_embedder(t0_timestep, dtype=x.dtype) + t0 = t0 + fps + t0_mlp = self.t_block(t0) + + # === get y embed === + if self.config.skip_y_embedder: + y_lens = mask + if isinstance(y_lens, torch.Tensor): + y_lens = y_lens.long().tolist() + else: + y, y_lens = self.encode_text(y, mask) + + # === get x embed === + x = self.x_embedder(x) # [B, N, C] + x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S) + x = x + pos_emb + + # shard over the sequence dim if sp is enabled + if self.enable_sequence_parallelism: + x = split_forward_gather_backward(x, get_sequence_parallel_group(), dim=2, grad_scale="down") + S = S // dist.get_world_size(get_sequence_parallel_group()) + + x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S) + + # === blocks === + for spatial_block, temporal_block in zip(self.spatial_blocks, self.temporal_blocks): + x = auto_grad_checkpoint(spatial_block, x, y, t_mlp, y_lens, x_mask, t0_mlp, T, S) + x = auto_grad_checkpoint(temporal_block, x, y, t_mlp, y_lens, x_mask, t0_mlp, T, S) + + if self.enable_sequence_parallelism: + x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S) + x = gather_forward_split_backward(x, get_sequence_parallel_group(), dim=2, grad_scale="up") + S = S * dist.get_world_size(get_sequence_parallel_group()) + x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S) + + # === final layer === + x = self.final_layer(x, t, x_mask, t0, T, S) + x = self.unpatchify(x, T, H, W, Tx, Hx, Wx) + + # cast to float32 for better accuracy + x = x.to(torch.float32) + return x + + def unpatchify(self, x, N_t, N_h, N_w, R_t, R_h, R_w): + """ + Args: + x (torch.Tensor): of shape [B, N, C] + + Return: + x (torch.Tensor): of shape [B, C_out, T, H, W] + """ + + # N_t, N_h, N_w = [self.input_size[i] // self.patch_size[i] for i in range(3)] + T_p, H_p, W_p = self.patch_size + x = rearrange( + x, + "B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)", + N_t=N_t, + N_h=N_h, + N_w=N_w, + T_p=T_p, + H_p=H_p, + W_p=W_p, + C_out=self.out_channels, + ) + # unpad + x = x[:, :, :R_t, :R_h, :R_w] + return x + + +@MODELS.register_module("STDiT3-XL/2") +def STDiT3_XL_2(from_pretrained=None, **kwargs): + if from_pretrained is not None and not os.path.isdir(from_pretrained): + model = STDiT3.from_pretrained(from_pretrained, **kwargs) + else: + config = STDiT3Config(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs) + model = STDiT3(config) + if from_pretrained is not None: + load_checkpoint(model, from_pretrained) + return model + + +@MODELS.register_module("STDiT3-3B/2") +def STDiT3_3B_2(from_pretrained=None, **kwargs): + # check if from_pretrained is a path + force_huggingface = kwargs.pop("force_huggingface", True) + if force_huggingface or (from_pretrained is not None and not os.path.isdir(from_pretrained)): + model = STDiT3.from_pretrained(from_pretrained, **kwargs) + else: + config = STDiT3Config(depth=28, hidden_size=1872, patch_size=(1, 2, 2), num_heads=26, **kwargs) + model = STDiT3(config) + if from_pretrained is not None: + load_checkpoint(model, from_pretrained) + return model diff --git a/src/videogen_hub/pipelines/opensora/opensora/models/text_encoder/__init__.py b/src/videogen_hub/pipelines/opensora/opensora/models/text_encoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..76b7379ab432a0f48d9efa4e538945104a3ef6a1 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/models/text_encoder/__init__.py @@ -0,0 +1,3 @@ +from videogen_hub.pipelines.opensora.opensora.models.text_encoder.classes import ClassEncoder +from videogen_hub.pipelines.opensora.opensora.models.text_encoder.clip import ClipEncoder +from videogen_hub.pipelines.opensora.opensora.models.text_encoder.t5 import T5Encoder diff --git a/src/videogen_hub/pipelines/opensora/opensora/models/text_encoder/classes.py b/src/videogen_hub/pipelines/opensora/opensora/models/text_encoder/classes.py new file mode 100644 index 0000000000000000000000000000000000000000..90ec0bc8645d65b0516c133dbf8036516c75a81e --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/models/text_encoder/classes.py @@ -0,0 +1,20 @@ +import torch + +from videogen_hub.pipelines.opensora.opensora.registry import MODELS + + +@MODELS.register_module("classes") +class ClassEncoder: + def __init__(self, num_classes, model_max_length=None, device="cuda", dtype=torch.float): + self.num_classes = num_classes + self.y_embedder = None + + self.model_max_length = model_max_length + self.output_dim = None + self.device = device + + def encode(self, text): + return dict(y=torch.tensor([int(t) for t in text]).to(self.device)) + + def null(self, n): + return torch.tensor([self.num_classes] * n).to(self.device) diff --git a/src/videogen_hub/pipelines/opensora/opensora/models/text_encoder/clip.py b/src/videogen_hub/pipelines/opensora/opensora/models/text_encoder/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..b93906ec7f2cdca08c332357580da2a17af683e4 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/models/text_encoder/clip.py @@ -0,0 +1,114 @@ +# Copyright 2024 Vchitect/Latte +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.# Modified from Latte +# +# This file is adapted from the Latte project. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# Latte: https://github.com/Vchitect/Latte +# DiT: https://github.com/facebookresearch/DiT/tree/main +# -------------------------------------------------------- + + +import torch +import torch.nn as nn +import transformers +from transformers import CLIPTextModel, CLIPTokenizer + +from videogen_hub.pipelines.opensora.opensora.registry import MODELS + +transformers.logging.set_verbosity_error() + + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class FrozenCLIPEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from Hugging Face)""" + + def __init__(self, path="openai/clip-vit-huge-patch14", device="cuda", max_length=77): + super().__init__() + self.tokenizer = CLIPTokenizer.from_pretrained(path) + self.transformer = CLIPTextModel.from_pretrained(path) + self.device = device + self.max_length = max_length + self._freeze() + + def _freeze(self): + self.transformer = self.transformer.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens) + + z = outputs.last_hidden_state + pooled_z = outputs.pooler_output + return z, pooled_z + + def encode(self, text): + return self(text) + + +@MODELS.register_module("clip") +class ClipEncoder: + """ + Embeds text prompt into vector representations. Also handles text dropout for classifier-free guidance. + """ + + def __init__( + self, + from_pretrained, + model_max_length=77, + device="cuda", + dtype=torch.float, + ): + super().__init__() + assert from_pretrained is not None, "Please specify the path to the T5 model" + + self.text_encoder = FrozenCLIPEmbedder(path=from_pretrained, max_length=model_max_length).to(device, dtype) + self.y_embedder = None + + self.model_max_length = model_max_length + self.output_dim = self.text_encoder.transformer.config.hidden_size + + def encode(self, text): + _, pooled_embeddings = self.text_encoder.encode(text) + y = pooled_embeddings.unsqueeze(1).unsqueeze(1) + return dict(y=y) + + def null(self, n): + null_y = self.y_embedder.y_embedding[None].repeat(n, 1, 1)[:, None] + return null_y + + def to(self, dtype): + self.text_encoder = self.text_encoder.to(dtype) + return self diff --git a/src/videogen_hub/pipelines/opensora/opensora/models/text_encoder/t5.py b/src/videogen_hub/pipelines/opensora/opensora/models/text_encoder/t5.py new file mode 100644 index 0000000000000000000000000000000000000000..7ffddc02b27b40e0900118031e28bda6b9d96259 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/models/text_encoder/t5.py @@ -0,0 +1,344 @@ +# Adapted from PixArt +# +# Copyright (C) 2023 PixArt-alpha/PixArt-alpha +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# PixArt: https://github.com/PixArt-alpha/PixArt-alpha +# T5: https://github.com/google-research/text-to-text-transfer-transformer +# -------------------------------------------------------- + +import html +import re + +import ftfy +import torch +from transformers import AutoTokenizer, T5EncoderModel + +from videogen_hub.pipelines.opensora.opensora.registry import MODELS + + +class T5Embedder: + available_models = ["DeepFloyd/t5-v1_1-xxl"] + + def __init__( + self, + device, + from_pretrained=None, + *, + cache_dir=None, + hf_token=None, + use_text_preprocessing=True, + t5_model_kwargs=None, + torch_dtype=None, + use_offload_folder=None, + model_max_length=120, + local_files_only=False, + ): + self.device = torch.device(device) + self.torch_dtype = torch_dtype or torch.bfloat16 + self.cache_dir = cache_dir + + if t5_model_kwargs is None: + t5_model_kwargs = { + "low_cpu_mem_usage": True, + "torch_dtype": self.torch_dtype, + } + + if use_offload_folder is not None: + t5_model_kwargs["offload_folder"] = use_offload_folder + t5_model_kwargs["device_map"] = { + "shared": self.device, + "encoder.embed_tokens": self.device, + "encoder.block.0": self.device, + "encoder.block.1": self.device, + "encoder.block.2": self.device, + "encoder.block.3": self.device, + "encoder.block.4": self.device, + "encoder.block.5": self.device, + "encoder.block.6": self.device, + "encoder.block.7": self.device, + "encoder.block.8": self.device, + "encoder.block.9": self.device, + "encoder.block.10": self.device, + "encoder.block.11": self.device, + "encoder.block.12": "disk", + "encoder.block.13": "disk", + "encoder.block.14": "disk", + "encoder.block.15": "disk", + "encoder.block.16": "disk", + "encoder.block.17": "disk", + "encoder.block.18": "disk", + "encoder.block.19": "disk", + "encoder.block.20": "disk", + "encoder.block.21": "disk", + "encoder.block.22": "disk", + "encoder.block.23": "disk", + "encoder.final_layer_norm": "disk", + "encoder.dropout": "disk", + } + else: + t5_model_kwargs["device_map"] = { + "shared": self.device, + "encoder": self.device, + } + + self.use_text_preprocessing = use_text_preprocessing + self.hf_token = hf_token + + assert from_pretrained in self.available_models + self.tokenizer = AutoTokenizer.from_pretrained( + from_pretrained, + cache_dir=cache_dir, + local_files_only=local_files_only, + ) + self.model = T5EncoderModel.from_pretrained( + from_pretrained, + cache_dir=cache_dir, + local_files_only=local_files_only, + **t5_model_kwargs, + ).eval() + self.model_max_length = model_max_length + + def get_text_embeddings(self, texts): + text_tokens_and_mask = self.tokenizer( + texts, + max_length=self.model_max_length, + padding="max_length", + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + + input_ids = text_tokens_and_mask["input_ids"].to(self.device) + attention_mask = text_tokens_and_mask["attention_mask"].to(self.device) + with torch.no_grad(): + text_encoder_embs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + )["last_hidden_state"].detach() + return text_encoder_embs, attention_mask + + +@MODELS.register_module("t5") +class T5Encoder: + def __init__( + self, + from_pretrained=None, + model_max_length=120, + device="cuda", + dtype=torch.float, + cache_dir=None, + shardformer=False, + local_files_only=False, + ): + assert from_pretrained is not None, "Please specify the path to the T5 model" + + self.t5 = T5Embedder( + device=device, + torch_dtype=dtype, + from_pretrained=from_pretrained, + cache_dir=cache_dir, + model_max_length=model_max_length, + local_files_only=local_files_only, + ) + self.t5.model.to(dtype=dtype) + self.y_embedder = None + + self.model_max_length = model_max_length + self.output_dim = self.t5.model.config.d_model + self.dtype = dtype + try: + import colossalai + except: + shardformer = False + + if shardformer: + self.shardformer_t5() + + def shardformer_t5(self): + try: + from colossalai.shardformer import ShardConfig, ShardFormer + + from ...acceleration.shardformer.policy.t5_encoder import T5EncoderPolicy + from ...utils.misc import requires_grad + except: + return + shard_config = ShardConfig( + tensor_parallel_process_group=None, + pipeline_stage_manager=None, + enable_tensor_parallelism=False, + enable_fused_normalization=False, + enable_flash_attention=False, + enable_jit_fused=True, + enable_sequence_parallelism=False, + enable_sequence_overlap=False, + ) + shard_former = ShardFormer(shard_config=shard_config) + optim_model, _ = shard_former.optimize(self.t5.model, policy=T5EncoderPolicy()) + self.t5.model = optim_model.to(self.dtype) + + # ensure the weights are frozen + requires_grad(self.t5.model, False) + + def encode(self, text): + caption_embs, emb_masks = self.t5.get_text_embeddings(text) + caption_embs = caption_embs[:, None] + return dict(y=caption_embs, mask=emb_masks) + + def null(self, n): + null_y = self.y_embedder.y_embedding[None].repeat(n, 1, 1)[:, None] + return null_y + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +BAD_PUNCT_REGEX = re.compile( + r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}" +) # noqa + + +def clean_caption(caption): + import urllib.parse as ul + + from bs4 import BeautifulSoup + + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(BAD_PUNCT_REGEX, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = basic_clean(caption) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + +def text_preprocessing(text, use_text_preprocessing: bool = True): + if use_text_preprocessing: + # The exact text cleaning as was in the training stage: + text = clean_caption(text) + text = clean_caption(text) + return text + else: + return text.lower().strip() diff --git a/src/videogen_hub/pipelines/opensora/opensora/models/vae/__init__.py b/src/videogen_hub/pipelines/opensora/opensora/models/vae/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..115653291e78caba3022afb8af3c135d4e519645 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/models/vae/__init__.py @@ -0,0 +1,3 @@ +from videogen_hub.pipelines.opensora.opensora.models.vae.discriminator import DISCRIMINATOR_3D +from videogen_hub.pipelines.opensora.opensora.models.vae.vae import VideoAutoencoderKL, VideoAutoencoderKLTemporalDecoder +from videogen_hub.pipelines.opensora.opensora.models.vae.vae_temporal import VAE_Temporal diff --git a/src/videogen_hub/pipelines/opensora/opensora/models/vae/discriminator.py b/src/videogen_hub/pipelines/opensora/opensora/models/vae/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..5881759e6fe1ed1d4f22e0f2ceff7df3d01213f5 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/models/vae/discriminator.py @@ -0,0 +1,422 @@ +import functools +import math + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from videogen_hub.pipelines.opensora.opensora.registry import MODELS +from videogen_hub.pipelines.opensora.opensora.utils.ckpt_utils import find_model, load_checkpoint + + +def cast_tuple(t, length=1): + return t if isinstance(t, tuple) else ((t,) * length) + + +def xavier_uniform_weight_init(m): + if isinstance(m, nn.Conv3d) or isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain("relu")) + if m.bias is not None: + nn.init.zeros_(m.bias) + # print("initialized module to xavier_uniform:", m) + + +# SCH: taken from Open Sora Plan +def n_layer_disc_weights_init(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find("BatchNorm") != -1: + nn.init.normal_(m.weight.data, 1.0, 0.02) + nn.init.constant_(m.bias.data, 0) + + +# SCH: own implementation modified on top of: discriminator with anti-aliased downsampling (blurpool Zhang et al.) +class BlurPool3D(nn.Module): + def __init__( + self, + channels, + pad_type="reflect", + filt_size=3, + stride=2, + pad_off=0, + device="cpu", + dtype=torch.bfloat16, + ): + super(BlurPool3D, self).__init__() + self.filt_size = filt_size + self.pad_off = pad_off + self.pad_sizes = [ + int(1.0 * (filt_size - 1) / 2), + int(np.ceil(1.0 * (filt_size - 1) / 2)), + int(1.0 * (filt_size - 1) / 2), + int(np.ceil(1.0 * (filt_size - 1) / 2)), + int(1.0 * (filt_size - 1) / 2), + int(np.ceil(1.0 * (filt_size - 1) / 2)), + ] + self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes] + self.stride = stride + self.off = int((self.stride - 1) / 2.0) + self.channels = channels + + if self.filt_size == 1: + a = np.array( + [ + 1.0, + ] + ) + elif self.filt_size == 2: + a = np.array([1.0, 1.0]) + elif self.filt_size == 3: + a = np.array([1.0, 2.0, 1.0]) + elif self.filt_size == 4: + a = np.array([1.0, 3.0, 3.0, 1.0]) + elif self.filt_size == 5: + a = np.array([1.0, 4.0, 6.0, 4.0, 1.0]) + elif self.filt_size == 6: + a = np.array([1.0, 5.0, 10.0, 10.0, 5.0, 1.0]) + elif self.filt_size == 7: + a = np.array([1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0]) + + filt_2d = a[:, None] * a[None, :] + filt_3d = torch.Tensor(a[:, None, None] * filt_2d[None, :, :]).to(device, dtype) + + filt = filt_3d / torch.sum(filt_3d) # SCH: modified to it 3D + self.register_buffer("filt", filt[None, None, :, :, :].repeat((self.channels, 1, 1, 1, 1))) + + self.pad = get_pad_layer(pad_type)(self.pad_sizes) + + def forward(self, inp): + if self.filt_size == 1: + if self.pad_off == 0: + return inp[:, :, :: self.stride, :: self.stride] + else: + return self.pad(inp)[:, :, :: self.stride, :: self.stride] + else: + return F.conv3d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1]) + + +class ResBlockDown(nn.Module): + """3D StyleGAN ResBlock for D.""" + + def __init__( + self, + in_channels, + filters, + activation_fn, + num_groups=32, + device="cpu", + dtype=torch.bfloat16, + ): + super().__init__() + + self.filters = filters + self.activation_fn = activation_fn + + # SCH: NOTE: although paper says conv (X->Y, Y->Y), original code implementation is (X->X, X->Y), we follow code + self.conv1 = nn.Conv3d( + in_channels, in_channels, (3, 3, 3), padding=1, device=device, dtype=dtype + ) # NOTE: init to xavier_uniform + self.norm1 = nn.GroupNorm(num_groups, in_channels, device=device, dtype=dtype) + + self.blur = BlurPool3D(in_channels, device=device, dtype=dtype) + + self.conv2 = nn.Conv3d( + in_channels, self.filters, (1, 1, 1), bias=False, device=device, dtype=dtype + ) # NOTE: init to xavier_uniform + self.conv3 = nn.Conv3d( + in_channels, self.filters, (3, 3, 3), padding=1, device=device, dtype=dtype + ) # NOTE: init to xavier_uniform + self.norm2 = nn.GroupNorm(num_groups, self.filters, device=device, dtype=dtype) + + # self.apply(xavier_uniform_weight_init) + + def forward(self, x): + residual = x + x = self.conv1(x) + x = self.norm1(x) + x = self.activation_fn(x) + + residual = self.blur(residual) + residual = self.conv2(residual) + + x = self.blur(x) + x = self.conv3(x) + x = self.norm2(x) + x = self.activation_fn(x) + out = (residual + x) / math.sqrt(2) + return out + + +@MODELS.register_module() +class NLayerDiscriminator(nn.Module): + """Defines a PatchGAN discriminator as in Pix2Pix + --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py + """ + + def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False, from_pretrained=None): + """Construct a PatchGAN discriminator + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + super(NLayerDiscriminator, self).__init__() + + norm_layer = nn.BatchNorm2d + + if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func != nn.BatchNorm2d + else: + use_bias = norm_layer != nn.BatchNorm2d + + kw = 4 + padw = 1 + sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2**n, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True), + ] + + nf_mult_prev = nf_mult + nf_mult = min(2**n_layers, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True), + ] + + sequence += [ + nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) + ] # output 1 channel prediction map + self.main = nn.Sequential(*sequence) + + if from_pretrained is not None: + load_checkpoint(self, from_pretrained) + + def forward(self, input): + """Standard forward.""" + return self.main(input) + + +class NLayerDiscriminator3D(nn.Module): + """Defines a 3D PatchGAN discriminator as in Pix2Pix but for 3D inputs.""" + + def __init__(self, input_nc=1, ndf=64, n_layers=3, use_actnorm=False): + """ + Construct a 3D PatchGAN discriminator + + Parameters: + input_nc (int) -- the number of channels in input volumes + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + use_actnorm (bool) -- flag to use actnorm instead of batchnorm + """ + super(NLayerDiscriminator3D, self).__init__() + if not use_actnorm: + norm_layer = nn.BatchNorm3d + else: + raise NotImplementedError("Not implemented.") + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func != nn.BatchNorm3d + else: + use_bias = norm_layer != nn.BatchNorm3d + + kw = 4 + padw = 1 + sequence = [nn.Conv3d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2**n, 8) + sequence += [ + nn.Conv3d( + ndf * nf_mult_prev, + ndf * nf_mult, + kernel_size=(kw, kw, kw), + stride=(1, 2, 2), + padding=padw, + bias=use_bias, + ), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True), + ] + + nf_mult_prev = nf_mult + nf_mult = min(2**n_layers, 8) + sequence += [ + nn.Conv3d( + ndf * nf_mult_prev, ndf * nf_mult, kernel_size=(kw, kw, kw), stride=1, padding=padw, bias=use_bias + ), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True), + ] + + sequence += [ + nn.Conv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) + ] # output 1 channel prediction map + self.main = nn.Sequential(*sequence) + + def forward(self, input): + """Standard forward.""" + return self.main(input) + + +class StyleGANDiscriminatorBlur(nn.Module): + """StyleGAN Discriminator. + + SCH: NOTE: + this discriminator requries the num_frames to be fixed during training; + in case we pre-train with image then train on video, this disciminator's Linear layer would have to be re-trained! + """ + + def __init__( + self, + image_size=(128, 128), + num_frames=17, + in_channels=3, + filters=128, + channel_multipliers=(2, 4, 4, 4, 4), + num_groups=32, + dtype=torch.bfloat16, + device="cpu", + ): + super().__init__() + + self.dtype = dtype + self.input_size = cast_tuple(image_size, 2) + self.filters = filters + self.activation_fn = nn.LeakyReLU(negative_slope=0.2) + self.channel_multipliers = channel_multipliers + + self.conv1 = nn.Conv3d( + in_channels, self.filters, (3, 3, 3), padding=1, device=device, dtype=dtype + ) # NOTE: init to xavier_uniform + + prev_filters = self.filters # record in_channels + self.num_blocks = len(self.channel_multipliers) + self.res_block_list = nn.ModuleList([]) + for i in range(self.num_blocks): + filters = self.filters * self.channel_multipliers[i] + self.res_block_list.append( + ResBlockDown(prev_filters, filters, self.activation_fn, device=device, dtype=dtype).apply( + xavier_uniform_weight_init + ) + ) + prev_filters = filters # update in_channels + + self.conv2 = nn.Conv3d( + prev_filters, prev_filters, (3, 3, 3), padding=1, device=device, dtype=dtype + ) # NOTE: init to xavier_uniform + # torch.nn.init.xavier_uniform_(self.conv2.weight) + + self.norm1 = nn.GroupNorm(num_groups, prev_filters, dtype=dtype, device=device) + + scale_factor = 2**self.num_blocks + if num_frames % scale_factor != 0: # SCH: NOTE: has first frame which would be padded before usage + time_scaled = num_frames // scale_factor + 1 + else: + time_scaled = num_frames / scale_factor + + assert ( + self.input_size[0] % scale_factor == 0 + ), f"image width {self.input_size[0]} is not divisible by scale factor {scale_factor}" + assert ( + self.input_size[1] % scale_factor == 0 + ), f"image height {self.input_size[1]} is not divisible by scale factor {scale_factor}" + w_scaled, h_scaled = self.input_size[0] / scale_factor, self.input_size[1] / scale_factor + in_features = int(prev_filters * time_scaled * w_scaled * h_scaled) # (C*T*W*H) + self.linear1 = nn.Linear(in_features, prev_filters, device=device, dtype=dtype) # NOTE: init to xavier_uniform + self.linear2 = nn.Linear(prev_filters, 1, device=device, dtype=dtype) # NOTE: init to xavier_uniform + + # self.apply(xavier_uniform_weight_init) + + def forward(self, x): + x = self.conv1(x) + # print("discriminator aft conv:", x.size()) + x = self.activation_fn(x) + + for i in range(self.num_blocks): + x = self.res_block_list[i](x) + # print("discriminator resblock down:", x.size()) + + x = self.conv2(x) + # print("discriminator aft conv2:", x.size()) + x = self.norm1(x) + x = self.activation_fn(x) + x = x.reshape((x.shape[0], -1)) # SCH: [B, (C * T * W * H)] ? + + # print("discriminator reshape:", x.size()) + x = self.linear1(x) + # print("discriminator aft linear1:", x.size()) + + x = self.activation_fn(x) + x = self.linear2(x) + # print("discriminator aft linear2:", x.size()) + return x + + +def load_checkpoint_with_inflation(model, ckpt_path): + """ + pre-train using image, then inflate to 3D videos + """ + if ckpt_path.endswith(".pt") or ckpt_path.endswith(".pth"): + state_dict = find_model(ckpt_path) + with torch.no_grad(): + for key in state_dict: + if key in model: + # central inflation + if state_dict[key].size() == model[key][:, :, 0, :, :].size(): + # temporal dimension + val = torch.zeros_like(model[key]) + centre = int(model[key].size(2) // 2) + val[:, :, centre, :, :] = state_dict[key] + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + print(f"Missing keys: {missing_keys}") + print(f"Unexpected keys: {unexpected_keys}") + else: + load_checkpoint(model, ckpt_path) # use the default function + + +@MODELS.register_module("DISCRIMINATOR_3D") +def DISCRIMINATOR_3D(from_pretrained=None, inflate_from_2d=False, use_pretrained=True, **kwargs): + model = StyleGANDiscriminatorBlur(**kwargs).apply(xavier_uniform_weight_init) + if from_pretrained is not None: + if use_pretrained: + if inflate_from_2d: + load_checkpoint_with_inflation(model, from_pretrained) + else: + load_checkpoint(model, from_pretrained, model_name="discriminator") + print("loaded discriminator") + else: + print(f"discriminator use_pretrained={use_pretrained}, initializing new discriminator") + + return model + + +@MODELS.register_module("N_Layer_DISCRIMINATOR_3D") +def DISCRIMINATOR_3D_N_Layer(from_pretrained=None, inflate_from_2d=False, use_pretrained=True, **kwargs): + model = NLayerDiscriminator3D( + input_nc=3, + n_layers=3, + ).apply(n_layer_disc_weights_init) + if from_pretrained is not None: + if use_pretrained: + if inflate_from_2d: + load_checkpoint_with_inflation(model, from_pretrained) + else: + load_checkpoint(model, from_pretrained, model_name="discriminator") + print("loaded discriminator") + else: + print(f"discriminator use_pretrained={use_pretrained}, initializing new discriminator") + + return model diff --git a/src/videogen_hub/pipelines/opensora/opensora/models/vae/losses.py b/src/videogen_hub/pipelines/opensora/opensora/models/vae/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..dd0686a4446d1ec24568eb49333e762a256602c4 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/models/vae/losses.py @@ -0,0 +1,274 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat + +from videogen_hub.pipelines.opensora.opensora.models.vae.lpips import LPIPS + + +def hinge_d_loss(logits_real, logits_fake): + loss_real = torch.mean(F.relu(1.0 - logits_real)) + loss_fake = torch.mean(F.relu(1.0 + logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +def vanilla_d_loss(logits_real, logits_fake): + d_loss = 0.5 * ( + torch.mean(torch.nn.functional.softplus(-logits_real)) + torch.mean(torch.nn.functional.softplus(logits_fake)) + ) + return d_loss + + +# from MAGVIT, used in place hof hinge_d_loss +def sigmoid_cross_entropy_with_logits(labels, logits): + # The final formulation is: max(x, 0) - x * z + log(1 + exp(-abs(x))) + zeros = torch.zeros_like(logits, dtype=logits.dtype) + condition = logits >= zeros + relu_logits = torch.where(condition, logits, zeros) + neg_abs_logits = torch.where(condition, -logits, logits) + return relu_logits - logits * labels + torch.log1p(torch.exp(neg_abs_logits)) + + +def lecam_reg(real_pred, fake_pred, ema_real_pred, ema_fake_pred): + assert real_pred.ndim == 0 and ema_fake_pred.ndim == 0 + lecam_loss = torch.mean(torch.pow(nn.ReLU()(real_pred - ema_fake_pred), 2)) + lecam_loss += torch.mean(torch.pow(nn.ReLU()(ema_real_pred - fake_pred), 2)) + return lecam_loss + + +def gradient_penalty_fn(images, output): + gradients = torch.autograd.grad( + outputs=output, + inputs=images, + grad_outputs=torch.ones(output.size(), device=images.device), + create_graph=True, + retain_graph=True, + only_inputs=True, + )[0] + + gradients = rearrange(gradients, "b ... -> b (...)") + return ((gradients.norm(2, dim=1) - 1) ** 2).mean() + + +class VAELoss(nn.Module): + def __init__( + self, + logvar_init=0.0, + perceptual_loss_weight=0.1, + kl_loss_weight=0.000001, + device="cpu", + dtype="bf16", + ): + super().__init__() + + if type(dtype) == str: + if dtype == "bf16": + dtype = torch.bfloat16 + elif dtype == "fp16": + dtype = torch.float16 + else: + raise NotImplementedError(f"dtype: {dtype}") + + # KL Loss + self.kl_loss_weight = kl_loss_weight + # Perceptual Loss + self.perceptual_loss_fn = LPIPS().eval().to(device, dtype) + self.perceptual_loss_weight = perceptual_loss_weight + self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) + + def forward( + self, + video, + recon_video, + posterior, + nll_weights=None, + no_perceptual=False, + ): + video = rearrange(video, "b c t h w -> (b t) c h w").contiguous() + recon_video = rearrange(recon_video, "b c t h w -> (b t) c h w").contiguous() + + # reconstruction loss + recon_loss = torch.abs(video - recon_video) + + # perceptual loss + if self.perceptual_loss_weight is not None and self.perceptual_loss_weight > 0.0 and not no_perceptual: + # handle channels + channels = video.shape[1] + assert channels in {1, 3} + if channels == 1: + input_vgg_input = repeat(video, "b 1 h w -> b c h w", c=3) + recon_vgg_input = repeat(recon_video, "b 1 h w -> b c h w", c=3) + else: + input_vgg_input = video + recon_vgg_input = recon_video + + perceptual_loss = self.perceptual_loss_fn(input_vgg_input, recon_vgg_input) + recon_loss = recon_loss + self.perceptual_loss_weight * perceptual_loss + + nll_loss = recon_loss / torch.exp(self.logvar) + self.logvar + + weighted_nll_loss = nll_loss + if nll_weights is not None: + weighted_nll_loss = nll_weights * nll_loss + weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] + nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + + # KL Loss + weighted_kl_loss = 0 + if self.kl_loss_weight is not None and self.kl_loss_weight > 0.0: + kl_loss = posterior.kl() + kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + weighted_kl_loss = kl_loss * self.kl_loss_weight + + return nll_loss, weighted_nll_loss, weighted_kl_loss + + +def adopt_weight(weight, global_step, threshold=0, value=0.0): + if global_step < threshold: + weight = value + return weight + + +class AdversarialLoss(nn.Module): + def __init__( + self, + discriminator_factor=1.0, + discriminator_start=50001, + generator_factor=0.5, + generator_loss_type="non-saturating", + ): + super().__init__() + self.discriminator_factor = discriminator_factor + self.discriminator_start = discriminator_start + self.generator_factor = generator_factor + self.generator_loss_type = generator_loss_type + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer): + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.generator_factor + return d_weight + + def forward( + self, + fake_logits, + nll_loss, + last_layer, + global_step, + is_training=True, + ): + # NOTE: following MAGVIT to allow non_saturating + assert self.generator_loss_type in ["hinge", "vanilla", "non-saturating"] + + if self.generator_loss_type == "hinge": + gen_loss = -torch.mean(fake_logits) + elif self.generator_loss_type == "non-saturating": + gen_loss = torch.mean( + sigmoid_cross_entropy_with_logits(labels=torch.ones_like(fake_logits), logits=fake_logits) + ) + else: + raise ValueError("Generator loss {} not supported".format(self.generator_loss_type)) + + if self.discriminator_factor is not None and self.discriminator_factor > 0.0: + try: + d_weight = self.calculate_adaptive_weight(nll_loss, gen_loss, last_layer) + except RuntimeError: + assert not is_training + d_weight = torch.tensor(0.0) + else: + d_weight = torch.tensor(0.0) + + disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_start) + weighted_gen_loss = d_weight * disc_factor * gen_loss + + return weighted_gen_loss + + +class LeCamEMA: + def __init__(self, ema_real=0.0, ema_fake=0.0, decay=0.999, dtype=torch.bfloat16, device="cpu"): + self.decay = decay + self.ema_real = torch.tensor(ema_real).to(device, dtype) + self.ema_fake = torch.tensor(ema_fake).to(device, dtype) + + def update(self, ema_real, ema_fake): + self.ema_real = self.ema_real * self.decay + ema_real * (1 - self.decay) + self.ema_fake = self.ema_fake * self.decay + ema_fake * (1 - self.decay) + + def get(self): + return self.ema_real, self.ema_fake + + +class DiscriminatorLoss(nn.Module): + def __init__( + self, + discriminator_factor=1.0, + discriminator_start=50001, + discriminator_loss_type="non-saturating", + lecam_loss_weight=None, + gradient_penalty_loss_weight=None, # SCH: following MAGVIT config.vqgan.grad_penalty_cost + ): + super().__init__() + + assert discriminator_loss_type in ["hinge", "vanilla", "non-saturating"] + self.discriminator_factor = discriminator_factor + self.discriminator_start = discriminator_start + self.lecam_loss_weight = lecam_loss_weight + self.gradient_penalty_loss_weight = gradient_penalty_loss_weight + self.discriminator_loss_type = discriminator_loss_type + + def forward( + self, + real_logits, + fake_logits, + global_step, + lecam_ema_real=None, + lecam_ema_fake=None, + real_video=None, + split="train", + ): + if self.discriminator_factor is not None and self.discriminator_factor > 0.0: + disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_start) + + if self.discriminator_loss_type == "hinge": + disc_loss = hinge_d_loss(real_logits, fake_logits) + elif self.discriminator_loss_type == "non-saturating": + if real_logits is not None: + real_loss = sigmoid_cross_entropy_with_logits( + labels=torch.ones_like(real_logits), logits=real_logits + ) + else: + real_loss = 0.0 + if fake_logits is not None: + fake_loss = sigmoid_cross_entropy_with_logits( + labels=torch.zeros_like(fake_logits), logits=fake_logits + ) + else: + fake_loss = 0.0 + disc_loss = 0.5 * (torch.mean(real_loss) + torch.mean(fake_loss)) + elif self.discriminator_loss_type == "vanilla": + disc_loss = vanilla_d_loss(real_logits, fake_logits) + else: + raise ValueError(f"Unknown GAN loss '{self.discriminator_loss_type}'.") + + weighted_d_adversarial_loss = disc_factor * disc_loss + + else: + weighted_d_adversarial_loss = 0 + + lecam_loss = torch.tensor(0.0) + if self.lecam_loss_weight is not None and self.lecam_loss_weight > 0.0: + real_pred = torch.mean(real_logits) + fake_pred = torch.mean(fake_logits) + lecam_loss = lecam_reg(real_pred, fake_pred, lecam_ema_real, lecam_ema_fake) + lecam_loss = lecam_loss * self.lecam_loss_weight + + gradient_penalty = torch.tensor(0.0) + if self.gradient_penalty_loss_weight is not None and self.gradient_penalty_loss_weight > 0.0: + assert real_video is not None + gradient_penalty = gradient_penalty_fn(real_video, real_logits) + gradient_penalty *= self.gradient_penalty_loss_weight + + return (weighted_d_adversarial_loss, lecam_loss, gradient_penalty) diff --git a/src/videogen_hub/pipelines/opensora/opensora/models/vae/lpips.py b/src/videogen_hub/pipelines/opensora/opensora/models/vae/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..e643cba1dd3481481a34fa9743cc98f6fb778f48 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/models/vae/lpips.py @@ -0,0 +1,167 @@ +import hashlib +import os +from collections import namedtuple + +import requests +import torch +import torch.nn as nn +from torchvision import models +from tqdm import tqdm + +URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"} + +CKPT_MAP = {"vgg_lpips": "vgg.pth"} + +MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"} + + +def md5_hash(path): + with open(path, "rb") as f: + content = f.read() + return hashlib.md5(content).hexdigest() + + +def download(url, local_path, chunk_size=1024): + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True) as r: + total_size = int(r.headers.get("content-length", 0)) + with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: + with open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + + +def get_ckpt_path(name, root, check=False): + assert name in URL_MAP + path = os.path.join(root, CKPT_MAP[name]) + if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): + print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) + download(URL_MAP[name], path) + md5 = md5_hash(path) + assert md5 == MD5_MAP[name], md5 + return path + + +class LPIPS(nn.Module): + # Learned perceptual metric + def __init__(self, use_dropout=True): + super().__init__() + self.scaling_layer = ScalingLayer() + self.chns = [64, 128, 256, 512, 512] # vg16 features + self.net = vgg16(pretrained=True, requires_grad=False) + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + self.load_from_pretrained() + for param in self.parameters(): + param.requires_grad = False + + def load_from_pretrained(self, name="vgg_lpips"): + ckpt = get_ckpt_path(name, "pretrained_models/taming/modules/autoencoder/lpips") + self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) + # print("loaded pretrained LPIPS loss from {}".format(ckpt)) + + @classmethod + def from_pretrained(cls, name="vgg_lpips"): + if name != "vgg_lpips": + raise NotImplementedError + model = cls() + ckpt = get_ckpt_path(name) + model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) + return model + + def forward(self, input, target): + in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + feats0, feats1, diffs = {}, {}, {} + lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + for kk in range(len(self.chns)): + feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] + val = res[0] + for l in range(1, len(self.chns)): + val += res[l] + return val + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer("shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]) + self.register_buffer("scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + """A single linear layer which does a 1x1 conv""" + + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + layers = ( + [ + nn.Dropout(), + ] + if (use_dropout) + else [] + ) + layers += [ + nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), + ] + self.model = nn.Sequential(*layers) + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = models.vgg16(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple("VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]) + out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + return out + + +def normalize_tensor(x, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) + return x / (norm_factor + eps) + + +def spatial_average(x, keepdim=True): + return x.mean([2, 3], keepdim=keepdim) diff --git a/src/videogen_hub/pipelines/opensora/opensora/models/vae/utils.py b/src/videogen_hub/pipelines/opensora/opensora/models/vae/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6b0ba6b81513954ac2a5635850e0670f10c85f24 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/models/vae/utils.py @@ -0,0 +1,50 @@ +import numpy as np +import torch + +"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" + + +class DiagonalGaussianDistribution(object): + def __init__( + self, + parameters, + deterministic=False, + ): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device, dtype=self.mean.dtype) + + def sample(self): + # torch.randn: standard normal distribution + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device, dtype=self.mean.dtype) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: # SCH: assumes other is a standard normal distribution + return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3, 4]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3, 4], + ) + + def nll(self, sample, dims=[1, 2, 3, 4]): + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims) + + def mode(self): + return self.mean diff --git a/src/videogen_hub/pipelines/opensora/opensora/models/vae/vae.py b/src/videogen_hub/pipelines/opensora/opensora/models/vae/vae.py new file mode 100644 index 0000000000000000000000000000000000000000..35042d008f84a4267b87c3dee471486d9cf575b7 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/models/vae/vae.py @@ -0,0 +1,290 @@ +import os + +import torch +import torch.nn as nn +from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder +from einops import rearrange +from transformers import PretrainedConfig, PreTrainedModel + +from videogen_hub.pipelines.opensora.opensora.registry import MODELS, build_module +from videogen_hub.pipelines.opensora.opensora.utils.ckpt_utils import load_checkpoint + + +@MODELS.register_module() +class VideoAutoencoderKL(nn.Module): + def __init__( + self, from_pretrained=None, micro_batch_size=None, cache_dir=None, local_files_only=False, subfolder=None + ): + super().__init__() + self.module = AutoencoderKL.from_pretrained( + from_pretrained, + cache_dir=cache_dir, + local_files_only=local_files_only, + subfolder=subfolder, + ) + self.out_channels = self.module.config.latent_channels + self.patch_size = (1, 8, 8) + self.micro_batch_size = micro_batch_size + + def encode(self, x): + # x: (B, C, T, H, W) + B = x.shape[0] + x = rearrange(x, "B C T H W -> (B T) C H W") + + if self.micro_batch_size is None: + x = self.module.encode(x).latent_dist.sample().mul_(0.18215) + else: + # NOTE: cannot be used for training + bs = self.micro_batch_size + x_out = [] + for i in range(0, x.shape[0], bs): + x_bs = x[i : i + bs] + x_bs = self.module.encode(x_bs).latent_dist.sample().mul_(0.18215) + x_out.append(x_bs) + x = torch.cat(x_out, dim=0) + x = rearrange(x, "(B T) C H W -> B C T H W", B=B) + return x + + def decode(self, x, **kwargs): + # x: (B, C, T, H, W) + B = x.shape[0] + x = rearrange(x, "B C T H W -> (B T) C H W") + if self.micro_batch_size is None: + x = self.module.decode(x / 0.18215).sample + else: + # NOTE: cannot be used for training + bs = self.micro_batch_size + x_out = [] + for i in range(0, x.shape[0], bs): + x_bs = x[i : i + bs] + x_bs = self.module.decode(x_bs / 0.18215).sample + x_out.append(x_bs) + x = torch.cat(x_out, dim=0) + x = rearrange(x, "(B T) C H W -> B C T H W", B=B) + return x + + def get_latent_size(self, input_size): + latent_size = [] + for i in range(3): + # assert ( + # input_size[i] is None or input_size[i] % self.patch_size[i] == 0 + # ), "Input size must be divisible by patch size" + latent_size.append(input_size[i] // self.patch_size[i] if input_size[i] is not None else None) + return latent_size + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + +@MODELS.register_module() +class VideoAutoencoderKLTemporalDecoder(nn.Module): + def __init__(self, from_pretrained=None, cache_dir=None, local_files_only=False): + super().__init__() + self.module = AutoencoderKLTemporalDecoder.from_pretrained( + from_pretrained, cache_dir=cache_dir, local_files_only=local_files_only + ) + self.out_channels = self.module.config.latent_channels + self.patch_size = (1, 8, 8) + + def encode(self, x): + raise NotImplementedError + + def decode(self, x, **kwargs): + B, _, T = x.shape[:3] + x = rearrange(x, "B C T H W -> (B T) C H W") + x = self.module.decode(x / 0.18215, num_frames=T).sample + x = rearrange(x, "(B T) C H W -> B C T H W", B=B) + return x + + def get_latent_size(self, input_size): + latent_size = [] + for i in range(3): + # assert ( + # input_size[i] is None or input_size[i] % self.patch_size[i] == 0 + # ), "Input size must be divisible by patch size" + latent_size.append(input_size[i] // self.patch_size[i] if input_size[i] is not None else None) + return latent_size + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + +class VideoAutoencoderPipelineConfig(PretrainedConfig): + model_type = "VideoAutoencoderPipeline" + + def __init__( + self, + vae_2d=None, + vae_temporal=None, + from_pretrained=None, + freeze_vae_2d=False, + cal_loss=False, + micro_frame_size=None, + shift=0.0, + scale=1.0, + **kwargs, + ): + self.vae_2d = vae_2d + self.vae_temporal = vae_temporal + self.from_pretrained = from_pretrained + self.freeze_vae_2d = freeze_vae_2d + self.cal_loss = cal_loss + self.micro_frame_size = micro_frame_size + self.shift = shift + self.scale = scale + super().__init__(**kwargs) + + +@MODELS.register_module() +class VideoAutoencoderPipeline(PreTrainedModel): + config_class = VideoAutoencoderPipelineConfig + + def __init__(self, config: VideoAutoencoderPipelineConfig): + super().__init__(config=config) + self.spatial_vae = build_module(config.vae_2d, MODELS) + self.temporal_vae = build_module(config.vae_temporal, MODELS) + self.cal_loss = config.cal_loss + self.micro_frame_size = config.micro_frame_size + self.micro_z_frame_size = self.temporal_vae.get_latent_size([config.micro_frame_size, None, None])[0] + + if config.freeze_vae_2d: + for param in self.spatial_vae.parameters(): + param.requires_grad = False + + self.out_channels = self.temporal_vae.out_channels + + # normalization parameters + scale = torch.tensor(config.scale) + shift = torch.tensor(config.shift) + if len(scale.shape) > 0: + scale = scale[None, :, None, None, None] + if len(shift.shape) > 0: + shift = shift[None, :, None, None, None] + self.register_buffer("scale", scale) + self.register_buffer("shift", shift) + + def encode(self, x): + x_z = self.spatial_vae.encode(x) + + if self.micro_frame_size is None: + posterior = self.temporal_vae.encode(x_z) + z = posterior.sample() + else: + z_list = [] + for i in range(0, x_z.shape[2], self.micro_frame_size): + x_z_bs = x_z[:, :, i : i + self.micro_frame_size] + posterior = self.temporal_vae.encode(x_z_bs) + z_list.append(posterior.sample()) + z = torch.cat(z_list, dim=2) + + if self.cal_loss: + return z, posterior, x_z + else: + return (z - self.shift) / self.scale + + def decode(self, z, num_frames=None): + if not self.cal_loss: + z = z * self.scale.to(z.dtype) + self.shift.to(z.dtype) + + if self.micro_frame_size is None: + x_z = self.temporal_vae.decode(z, num_frames=num_frames) + x = self.spatial_vae.decode(x_z) + else: + x_z_list = [] + for i in range(0, z.size(2), self.micro_z_frame_size): + z_bs = z[:, :, i : i + self.micro_z_frame_size] + print("self.micro_frame_size", self.micro_frame_size) + print("num_frames", num_frames) + x_z_bs = self.temporal_vae.decode(z_bs, num_frames=min(self.micro_frame_size, num_frames)) + x_z_list.append(x_z_bs) + num_frames -= self.micro_frame_size + x_z = torch.cat(x_z_list, dim=2) + x = self.spatial_vae.decode(x_z) + + if self.cal_loss: + return x, x_z + else: + return x + + def forward(self, x): + assert self.cal_loss, "This method is only available when cal_loss is True" + z, posterior, x_z = self.encode(x) + x_rec, x_z_rec = self.decode(z, num_frames=x_z.shape[2]) + return x_rec, x_z_rec, z, posterior, x_z + + def get_latent_size(self, input_size): + if self.micro_frame_size is None or input_size[0] is None: + return self.temporal_vae.get_latent_size(self.spatial_vae.get_latent_size(input_size)) + else: + sub_input_size = [self.micro_frame_size, input_size[1], input_size[2]] + sub_latent_size = self.temporal_vae.get_latent_size(self.spatial_vae.get_latent_size(sub_input_size)) + sub_latent_size[0] = sub_latent_size[0] * (input_size[0] // self.micro_frame_size) + remain_temporal_size = [input_size[0] % self.micro_frame_size, None, None] + if remain_temporal_size[0] > 0: + remain_size = self.temporal_vae.get_latent_size(remain_temporal_size) + sub_latent_size[0] += remain_size[0] + return sub_latent_size + + def get_temporal_last_layer(self): + return self.temporal_vae.decoder.conv_out.conv.weight + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + +@MODELS.register_module() +def OpenSoraVAE_V1_2( + micro_batch_size=4, + micro_frame_size=17, + from_pretrained=None, + local_files_only=False, + freeze_vae_2d=False, + cal_loss=False, +): + vae_2d = dict( + type="VideoAutoencoderKL", + from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers", + subfolder="vae", + micro_batch_size=micro_batch_size, + local_files_only=local_files_only, + ) + vae_temporal = dict( + type="VAE_Temporal_SD", + from_pretrained=None, + ) + shift = (-0.10, 0.34, 0.27, 0.98) + scale = (3.85, 2.32, 2.33, 3.06) + kwargs = dict( + vae_2d=vae_2d, + vae_temporal=vae_temporal, + freeze_vae_2d=freeze_vae_2d, + cal_loss=cal_loss, + micro_frame_size=micro_frame_size, + shift=shift, + scale=scale, + ) + + if from_pretrained is not None and not os.path.isdir(from_pretrained): + model = VideoAutoencoderPipeline.from_pretrained(from_pretrained, **kwargs) + else: + config = VideoAutoencoderPipelineConfig(**kwargs) + model = VideoAutoencoderPipeline(config) + + if from_pretrained: + load_checkpoint(model, from_pretrained) + return model diff --git a/src/videogen_hub/pipelines/opensora/opensora/models/vae/vae_temporal.py b/src/videogen_hub/pipelines/opensora/opensora/models/vae/vae_temporal.py new file mode 100644 index 0000000000000000000000000000000000000000..dbb47d961aae6e2266e2369d71b109da3f3e736b --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/models/vae/vae_temporal.py @@ -0,0 +1,435 @@ +from typing import Tuple, Union + +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from videogen_hub.pipelines.opensora.opensora.registry import MODELS +from videogen_hub.pipelines.opensora.opensora.utils.ckpt_utils import load_checkpoint + +from videogen_hub.pipelines.opensora.opensora.models.vae.utils import DiagonalGaussianDistribution + + +def cast_tuple(t, length=1): + return t if isinstance(t, tuple) else ((t,) * length) + + +def divisible_by(num, den): + return (num % den) == 0 + + +def is_odd(n): + return not divisible_by(n, 2) + + +def pad_at_dim(t, pad, dim=-1): + dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1) + zeros = (0, 0) * dims_from_right + return F.pad(t, (*zeros, *pad), mode="constant") + + +def exists(v): + return v is not None + + +class CausalConv3d(nn.Module): + def __init__( + self, + chan_in, + chan_out, + kernel_size: Union[int, Tuple[int, int, int]], + pad_mode="constant", + strides=None, # allow custom stride + **kwargs, + ): + super().__init__() + kernel_size = cast_tuple(kernel_size, 3) + + time_kernel_size, height_kernel_size, width_kernel_size = kernel_size + + assert is_odd(height_kernel_size) and is_odd(width_kernel_size) + + dilation = kwargs.pop("dilation", 1) + stride = strides[0] if strides is not None else kwargs.pop("stride", 1) + + self.pad_mode = pad_mode + time_pad = dilation * (time_kernel_size - 1) + (1 - stride) + height_pad = height_kernel_size // 2 + width_pad = width_kernel_size // 2 + + self.time_pad = time_pad + self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0) + + stride = strides if strides is not None else (stride, 1, 1) + dilation = (dilation, 1, 1) + self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def forward(self, x): + x = F.pad(x, self.time_causal_padding, mode=self.pad_mode) + x = self.conv(x) + return x + + +class ResBlock(nn.Module): + def __init__( + self, + in_channels, # SCH: added + filters, + conv_fn, + activation_fn=nn.SiLU, + use_conv_shortcut=False, + num_groups=32, + ): + super().__init__() + self.in_channels = in_channels + self.filters = filters + self.activate = activation_fn() + self.use_conv_shortcut = use_conv_shortcut + + # SCH: MAGVIT uses GroupNorm by default + self.norm1 = nn.GroupNorm(num_groups, in_channels) + self.conv1 = conv_fn(in_channels, self.filters, kernel_size=(3, 3, 3), bias=False) + self.norm2 = nn.GroupNorm(num_groups, self.filters) + self.conv2 = conv_fn(self.filters, self.filters, kernel_size=(3, 3, 3), bias=False) + if in_channels != filters: + if self.use_conv_shortcut: + self.conv3 = conv_fn(in_channels, self.filters, kernel_size=(3, 3, 3), bias=False) + else: + self.conv3 = conv_fn(in_channels, self.filters, kernel_size=(1, 1, 1), bias=False) + + def forward(self, x): + residual = x + x = self.norm1(x) + x = self.activate(x) + x = self.conv1(x) + x = self.norm2(x) + x = self.activate(x) + x = self.conv2(x) + if self.in_channels != self.filters: # SCH: ResBlock X->Y + residual = self.conv3(residual) + return x + residual + + +def get_activation_fn(activation): + if activation == "relu": + activation_fn = nn.ReLU + elif activation == "swish": + activation_fn = nn.SiLU + else: + raise NotImplementedError + return activation_fn + + +class Encoder(nn.Module): + """Encoder Blocks.""" + + def __init__( + self, + in_out_channels=4, + latent_embed_dim=512, # num channels for latent vector + filters=128, + num_res_blocks=4, + channel_multipliers=(1, 2, 2, 4), + temporal_downsample=(False, True, True), + num_groups=32, # for nn.GroupNorm + activation_fn="swish", + ): + super().__init__() + self.filters = filters + self.num_res_blocks = num_res_blocks + self.num_blocks = len(channel_multipliers) + self.channel_multipliers = channel_multipliers + self.temporal_downsample = temporal_downsample + self.num_groups = num_groups + self.embedding_dim = latent_embed_dim + + self.activation_fn = get_activation_fn(activation_fn) + self.activate = self.activation_fn() + self.conv_fn = CausalConv3d + self.block_args = dict( + conv_fn=self.conv_fn, + activation_fn=self.activation_fn, + use_conv_shortcut=False, + num_groups=self.num_groups, + ) + + # first layer conv + self.conv_in = self.conv_fn( + in_out_channels, + filters, + kernel_size=(3, 3, 3), + bias=False, + ) + + # ResBlocks and conv downsample + self.block_res_blocks = nn.ModuleList([]) + self.conv_blocks = nn.ModuleList([]) + + filters = self.filters + prev_filters = filters # record for in_channels + for i in range(self.num_blocks): + filters = self.filters * self.channel_multipliers[i] + block_items = nn.ModuleList([]) + for _ in range(self.num_res_blocks): + block_items.append(ResBlock(prev_filters, filters, **self.block_args)) + prev_filters = filters # update in_channels + self.block_res_blocks.append(block_items) + + if i < self.num_blocks - 1: + if self.temporal_downsample[i]: + t_stride = 2 if self.temporal_downsample[i] else 1 + s_stride = 1 + self.conv_blocks.append( + self.conv_fn( + prev_filters, filters, kernel_size=(3, 3, 3), strides=(t_stride, s_stride, s_stride) + ) + ) + prev_filters = filters # update in_channels + else: + # if no t downsample, don't add since this does nothing for pipeline models + self.conv_blocks.append(nn.Identity(prev_filters)) # Identity + prev_filters = filters # update in_channels + + # last layer res block + self.res_blocks = nn.ModuleList([]) + for _ in range(self.num_res_blocks): + self.res_blocks.append(ResBlock(prev_filters, filters, **self.block_args)) + prev_filters = filters # update in_channels + + # MAGVIT uses Group Normalization + self.norm1 = nn.GroupNorm(self.num_groups, prev_filters) + + self.conv2 = self.conv_fn(prev_filters, self.embedding_dim, kernel_size=(1, 1, 1), padding="same") + + def forward(self, x): + x = self.conv_in(x) + + for i in range(self.num_blocks): + for j in range(self.num_res_blocks): + x = self.block_res_blocks[i][j](x) + if i < self.num_blocks - 1: + x = self.conv_blocks[i](x) + for i in range(self.num_res_blocks): + x = self.res_blocks[i](x) + + x = self.norm1(x) + x = self.activate(x) + x = self.conv2(x) + return x + + +class Decoder(nn.Module): + """Decoder Blocks.""" + + def __init__( + self, + in_out_channels=4, + latent_embed_dim=512, + filters=128, + num_res_blocks=4, + channel_multipliers=(1, 2, 2, 4), + temporal_downsample=(False, True, True), + num_groups=32, # for nn.GroupNorm + activation_fn="swish", + ): + super().__init__() + self.filters = filters + self.num_res_blocks = num_res_blocks + self.num_blocks = len(channel_multipliers) + self.channel_multipliers = channel_multipliers + self.temporal_downsample = temporal_downsample + self.num_groups = num_groups + self.embedding_dim = latent_embed_dim + self.s_stride = 1 + + self.activation_fn = get_activation_fn(activation_fn) + self.activate = self.activation_fn() + self.conv_fn = CausalConv3d + self.block_args = dict( + conv_fn=self.conv_fn, + activation_fn=self.activation_fn, + use_conv_shortcut=False, + num_groups=self.num_groups, + ) + + filters = self.filters * self.channel_multipliers[-1] + prev_filters = filters + + # last conv + self.conv1 = self.conv_fn(self.embedding_dim, filters, kernel_size=(3, 3, 3), bias=True) + + # last layer res block + self.res_blocks = nn.ModuleList([]) + for _ in range(self.num_res_blocks): + self.res_blocks.append(ResBlock(filters, filters, **self.block_args)) + + # ResBlocks and conv upsample + self.block_res_blocks = nn.ModuleList([]) + self.num_blocks = len(self.channel_multipliers) + self.conv_blocks = nn.ModuleList([]) + # reverse to keep track of the in_channels, but append also in a reverse direction + for i in reversed(range(self.num_blocks)): + filters = self.filters * self.channel_multipliers[i] + # resblock handling + block_items = nn.ModuleList([]) + for _ in range(self.num_res_blocks): + block_items.append(ResBlock(prev_filters, filters, **self.block_args)) + prev_filters = filters # SCH: update in_channels + self.block_res_blocks.insert(0, block_items) # SCH: append in front + + # conv blocks with upsampling + if i > 0: + if self.temporal_downsample[i - 1]: + t_stride = 2 if self.temporal_downsample[i - 1] else 1 + # SCH: T-Causal Conv 3x3x3, f -> (t_stride * 2 * 2) * f, depth to space t_stride x 2 x 2 + self.conv_blocks.insert( + 0, + self.conv_fn( + prev_filters, prev_filters * t_stride * self.s_stride * self.s_stride, kernel_size=(3, 3, 3) + ), + ) + else: + self.conv_blocks.insert( + 0, + nn.Identity(prev_filters), + ) + + self.norm1 = nn.GroupNorm(self.num_groups, prev_filters) + + self.conv_out = self.conv_fn(filters, in_out_channels, 3) + + def forward(self, x): + x = self.conv1(x) + for i in range(self.num_res_blocks): + x = self.res_blocks[i](x) + for i in reversed(range(self.num_blocks)): + for j in range(self.num_res_blocks): + x = self.block_res_blocks[i][j](x) + if i > 0: + t_stride = 2 if self.temporal_downsample[i - 1] else 1 + x = self.conv_blocks[i - 1](x) + x = rearrange( + x, + "B (C ts hs ws) T H W -> B C (T ts) (H hs) (W ws)", + ts=t_stride, + hs=self.s_stride, + ws=self.s_stride, + ) + + x = self.norm1(x) + x = self.activate(x) + x = self.conv_out(x) + return x + + +@MODELS.register_module() +class VAE_Temporal(nn.Module): + def __init__( + self, + in_out_channels=4, + latent_embed_dim=4, + embed_dim=4, + filters=128, + num_res_blocks=4, + channel_multipliers=(1, 2, 2, 4), + temporal_downsample=(True, True, False), + num_groups=32, # for nn.GroupNorm + activation_fn="swish", + ): + super().__init__() + + self.time_downsample_factor = 2 ** sum(temporal_downsample) + # self.time_padding = self.time_downsample_factor - 1 + self.patch_size = (self.time_downsample_factor, 1, 1) + self.out_channels = in_out_channels + + # NOTE: following MAGVIT, conv in bias=False in encoder first conv + self.encoder = Encoder( + in_out_channels=in_out_channels, + latent_embed_dim=latent_embed_dim * 2, + filters=filters, + num_res_blocks=num_res_blocks, + channel_multipliers=channel_multipliers, + temporal_downsample=temporal_downsample, + num_groups=num_groups, # for nn.GroupNorm + activation_fn=activation_fn, + ) + self.quant_conv = CausalConv3d(2 * latent_embed_dim, 2 * embed_dim, 1) + + self.post_quant_conv = CausalConv3d(embed_dim, latent_embed_dim, 1) + self.decoder = Decoder( + in_out_channels=in_out_channels, + latent_embed_dim=latent_embed_dim, + filters=filters, + num_res_blocks=num_res_blocks, + channel_multipliers=channel_multipliers, + temporal_downsample=temporal_downsample, + num_groups=num_groups, # for nn.GroupNorm + activation_fn=activation_fn, + ) + + def get_latent_size(self, input_size): + latent_size = [] + for i in range(3): + if input_size[i] is None: + lsize = None + elif i == 0: + time_padding = ( + 0 + if (input_size[i] % self.time_downsample_factor == 0) + else self.time_downsample_factor - input_size[i] % self.time_downsample_factor + ) + lsize = (input_size[i] + time_padding) // self.patch_size[i] + else: + lsize = input_size[i] // self.patch_size[i] + latent_size.append(lsize) + return latent_size + + def encode(self, x): + time_padding = ( + 0 + if (x.shape[2] % self.time_downsample_factor == 0) + else self.time_downsample_factor - x.shape[2] % self.time_downsample_factor + ) + x = pad_at_dim(x, (time_padding, 0), dim=2) + encoded_feature = self.encoder(x) + moments = self.quant_conv(encoded_feature).to(x.dtype) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z, num_frames=None): + time_padding = ( + 0 + if (num_frames % self.time_downsample_factor == 0) + else self.time_downsample_factor - num_frames % self.time_downsample_factor + ) + z = self.post_quant_conv(z) + x = self.decoder(z) + x = x[:, :, time_padding:] + return x + + def forward(self, x, sample_posterior=True): + posterior = self.encode(x) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + recon_video = self.decode(z, num_frames=x.shape[2]) + return recon_video, posterior, z + + +@MODELS.register_module("VAE_Temporal_SD") +def VAE_Temporal_SD(from_pretrained=None, **kwargs): + model = VAE_Temporal( + in_out_channels=4, + latent_embed_dim=4, + embed_dim=4, + filters=128, + num_res_blocks=4, + channel_multipliers=(1, 2, 2, 4), + temporal_downsample=(False, True, True), + **kwargs, + ) + if from_pretrained is not None: + load_checkpoint(model, from_pretrained) + return model diff --git a/src/videogen_hub/pipelines/opensora/opensora/registry.py b/src/videogen_hub/pipelines/opensora/opensora/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..dc2785cc0ecfab0fa1f2cb6d210c8c37b9675ce7 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/registry.py @@ -0,0 +1,46 @@ +from copy import deepcopy + +import torch.nn as nn +from mmengine.registry import Registry + + +def build_module(module, builder, **kwargs): + """Build module from config or return the module itself. + + Args: + module (Union[dict, nn.Module]): The module to build. + builder (Registry): The registry to build module. + *args, **kwargs: Arguments passed to build function. + + Returns: + Any: The built module. + """ + if module is None: + return None + if isinstance(module, dict): + cfg = deepcopy(module) + for k, v in kwargs.items(): + cfg[k] = v + return builder.build(cfg) + elif isinstance(module, nn.Module): + return module + elif module is None: + return None + else: + raise TypeError(f"Only support dict and nn.Module, but got {type(module)}.") + + +MODELS = Registry( + "model", + locations=["videogen_hub.pipelines.opensora.opensora.models"], +) + +SCHEDULERS = Registry( + "scheduler", + locations=["videogen_hub.pipelines.opensora.opensora.schedulers"], +) + +DATASETS = Registry( + "dataset", + locations=["videogen_hub.pipelines.opensora.opensora.datasets"], +) diff --git a/src/videogen_hub/pipelines/opensora/opensora/schedulers/__init__.py b/src/videogen_hub/pipelines/opensora/opensora/schedulers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4828c6cc22e84ff4927e295f4d2e29809c910939 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/schedulers/__init__.py @@ -0,0 +1,3 @@ +from videogen_hub.pipelines.opensora.opensora.schedulers.dpms import DPMS +from videogen_hub.pipelines.opensora.opensora.schedulers.iddpm import IDDPM +from videogen_hub.pipelines.opensora.opensora.schedulers.rf import RFLOW diff --git a/src/videogen_hub/pipelines/opensora/opensora/schedulers/dpms/__init__.py b/src/videogen_hub/pipelines/opensora/opensora/schedulers/dpms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8298ee582796bdd884271f915061da7e39ac19d0 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/schedulers/dpms/__init__.py @@ -0,0 +1,56 @@ +from functools import partial + +from videogen_hub.pipelines.opensora.opensora.schedulers.dpms.dpm_solver import DPMS +from videogen_hub.pipelines.opensora.opensora.registry import SCHEDULERS + + +@SCHEDULERS.register_module("dpm-solver") +class DPM_SOLVER: + def __init__(self, num_sampling_steps=None, cfg_scale=4.0): + self.num_sampling_steps = num_sampling_steps + self.cfg_scale = cfg_scale + + def sample( + self, + model, + text_encoder, + z, + prompts, + device, + additional_args=None, + mask=None, + progress=True, + ): + assert mask is None, "mask is not supported in dpm-solver" + n = len(prompts) + model_args = text_encoder.encode(prompts) + y = model_args.pop("y") + null_y = text_encoder.null(n) + if additional_args is not None: + model_args.update(additional_args) + + dpms = DPMS( + partial(forward_with_dpmsolver, model), + condition=y, + uncondition=null_y, + cfg_scale=self.cfg_scale, + model_kwargs=model_args, + ) + samples = dpms.sample( + z, + steps=self.num_sampling_steps, + order=2, + skip_type="time_uniform", + method="multistep", + progress=progress, + ) + return samples + + +def forward_with_dpmsolver(self, x, timestep, y, **kwargs): + """ + dpm solver donnot need variance prediction + """ + # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + model_out = self.forward(x, timestep, y, **kwargs) + return model_out.chunk(2, dim=1)[0] diff --git a/src/videogen_hub/pipelines/opensora/opensora/schedulers/dpms/dpm_solver.py b/src/videogen_hub/pipelines/opensora/opensora/schedulers/dpms/dpm_solver.py new file mode 100644 index 0000000000000000000000000000000000000000..d422a0ae4ad99b0764d59b5aba0b0a04135de9cc --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/schedulers/dpms/dpm_solver.py @@ -0,0 +1,1572 @@ +# MIT License +# +# Copyright (c) 2022 Cheng Lu +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# +# This file is adapted from the dpm-solver project +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# PixArt: https://github.com/PixArt-alpha/PixArt-alpha +# dpm-solver: https://github.com/LuChengTHU/dpm-solver +# -------------------------------------------------------- + +import math + +import numpy as np +import torch +from tqdm import tqdm + + +def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + warmup_time = int(num_diffusion_timesteps * warmup_frac) + betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64) + return betas + + +def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): + """ + This is the deprecated API for creating beta schedules. + See get_named_beta_schedule() for the new library of schedules. + """ + if beta_schedule == "quad": + betas = ( + np.linspace( + beta_start**0.5, + beta_end**0.5, + num_diffusion_timesteps, + dtype=np.float64, + ) + ** 2 + ) + elif beta_schedule == "linear": + betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "warmup10": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) + elif beta_schedule == "warmup50": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) + elif beta_schedule == "const": + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 + betas = 1.0 / np.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64) + else: + raise NotImplementedError(beta_schedule) + assert betas.shape == (num_diffusion_timesteps,) + return betas + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + """ + Get a pre-defined beta schedule for the given name. + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. + scale = 1000 / num_diffusion_timesteps + return get_beta_schedule( + "linear", + beta_start=scale * 0.0001, + beta_end=scale * 0.02, + num_diffusion_timesteps=num_diffusion_timesteps, + ) + elif schedule_name == "squaredcos_cap_v2": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +class NoiseScheduleVP: + def __init__( + self, + schedule="discrete", + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20.0, + dtype=torch.float32, + ): + """Create a wrapper class for the forward SDE (VP type). + + *** + Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t. + We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. + *** + + The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). + We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). + Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have: + + log_alpha_t = self.marginal_log_mean_coeff(t) + sigma_t = self.marginal_std(t) + lambda_t = self.marginal_lambda(t) + + Moreover, as lambda(t) is an invertible function, we also support its inverse function: + + t = self.inverse_lambda(lambda_t) + + =============================================================== + + We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). + + 1. For discrete-time DPMs: + + For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: + t_i = (i + 1) / N + e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. + We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. + + Args: + betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) + alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) + + Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. + + **Important**: Please pay special attention for the args for `alphas_cumprod`: + The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that + q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). + Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have + alpha_{t_n} = \sqrt{\hat{alpha_n}}, + and + log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). + + + 2. For continuous-time DPMs: + + We support the linear VPSDE for the continuous time setting. The hyperparameters for the noise + schedule are the default settings in Yang Song's ScoreSDE: + + Args: + beta_min: A `float` number. The smallest beta for the linear schedule. + beta_max: A `float` number. The largest beta for the linear schedule. + T: A `float` number. The ending time of the forward process. + + =============================================================== + + Args: + schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, + 'linear' for continuous-time DPMs. + Returns: + A wrapper object of the forward SDE (VP type). + + =============================================================== + + Example: + + # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', betas=betas) + + # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + + # For continuous-time DPMs (VPSDE), linear schedule: + >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) + + """ + + if schedule not in ["discrete", "linear"]: + raise ValueError(f"Unsupported noise schedule {schedule}. The schedule needs to be 'discrete' or 'linear'") + + self.schedule = schedule + if schedule == "discrete": + if betas is not None: + log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) + else: + assert alphas_cumprod is not None + log_alphas = 0.5 * torch.log(alphas_cumprod) + self.T = 1.0 + self.log_alpha_array = ( + self.numerical_clip_alpha(log_alphas) + .reshape( + ( + 1, + -1, + ) + ) + .to(dtype=dtype) + ) + self.total_N = self.log_alpha_array.shape[1] + self.t_array = torch.linspace(0.0, 1.0, self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype) + else: + self.T = 1.0 + self.total_N = 1000 + self.beta_0 = continuous_beta_0 + self.beta_1 = continuous_beta_1 + + def numerical_clip_alpha(self, log_alphas, clipped_lambda=-5.1): + """ + For some beta schedules such as cosine schedule, the log-SNR has numerical isssues. + We clip the log-SNR near t=T within -5.1 to ensure the stability. + Such a trick is very useful for diffusion models with the cosine schedule, such as i-DDPM, guided-diffusion and GLIDE. + """ + log_sigmas = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_alphas)) + lambs = log_alphas - log_sigmas + idx = torch.searchsorted(torch.flip(lambs, [0]), clipped_lambda) + if idx > 0: + log_alphas = log_alphas[:-idx] + return log_alphas + + def marginal_log_mean_coeff(self, t): + """ + Compute log(alpha_t) of a given continuous-time label t in [0, T]. + """ + if self.schedule == "discrete": + return interpolate_fn( + t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device) + ).reshape((-1)) + elif self.schedule == "linear": + return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + + def marginal_alpha(self, t): + """ + Compute alpha_t of a given continuous-time label t in [0, T]. + """ + return torch.exp(self.marginal_log_mean_coeff(t)) + + def marginal_std(self, t): + """ + Compute sigma_t of a given continuous-time label t in [0, T]. + """ + return torch.sqrt(1.0 - torch.exp(2.0 * self.marginal_log_mean_coeff(t))) + + def marginal_lambda(self, t): + """ + Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. + """ + log_mean_coeff = self.marginal_log_mean_coeff(t) + log_std = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_mean_coeff)) + return log_mean_coeff - log_std + + def inverse_lambda(self, lamb): + """ + Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. + """ + if self.schedule == "linear": + tmp = 2.0 * (self.beta_1 - self.beta_0) * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb)) + Delta = self.beta_0**2 + tmp + return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) + elif self.schedule == "discrete": + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2.0 * lamb) + t = interpolate_fn( + log_alpha.reshape((-1, 1)), + torch.flip(self.log_alpha_array.to(lamb.device), [1]), + torch.flip(self.t_array.to(lamb.device), [1]), + ) + return t.reshape((-1,)) + + +def model_wrapper( + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + condition=None, + unconditional_condition=None, + guidance_scale=1.0, + classifier_fn=None, + classifier_kwargs={}, +): + """Create a wrapper function for the noise prediction model. + + DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to + firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. + + We support four types of the diffusion model by setting `model_type`: + + 1. "noise": noise prediction model. (Trained by predicting noise). + + 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). + + 3. "v": velocity prediction model. (Trained by predicting the velocity). + The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2]. + + [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." + arXiv preprint arXiv:2202.00512 (2022). + [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." + arXiv preprint arXiv:2210.02303 (2022). + + 4. "score": marginal score function. (Trained by denoising score matching). + Note that the score function and the noise prediction model follows a simple relationship: + ``` + noise(x_t, t) = -sigma_t * score(x_t, t) + ``` + + We support three types of guided sampling by DPMs by setting `guidance_type`: + 1. "uncond": unconditional sampling by DPMs. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + The input `classifier_fn` has the following format: + `` + classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) + `` + + [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," + in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. + + 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. + The input `model` has the following format: + `` + model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score + `` + And if cond == `unconditional_condition`, the model output is the unconditional DPM output. + + [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." + arXiv preprint arXiv:2207.12598 (2022). + + + The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) + or continuous-time labels (i.e. epsilon to T). + + We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise: + `` + def model_fn(x, t_continuous) -> noise: + t_input = get_model_input_time(t_continuous) + return noise_pred(model, x, t_input, **model_kwargs) + `` + where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver. + + =============================================================== + + Args: + model: A diffusion model with the corresponding format described above. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + model_type: A `str`. The parameterization type of the diffusion model. + "noise" or "x_start" or "v" or "score". + model_kwargs: A `dict`. A dict for the other inputs of the model function. + guidance_type: A `str`. The type of the guidance for sampling. + "uncond" or "classifier" or "classifier-free". + condition: A pytorch tensor. The condition for the guided sampling. + Only used for "classifier" or "classifier-free" guidance type. + unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. + Only used for "classifier-free" guidance type. + guidance_scale: A `float`. The scale for the guided sampling. + classifier_fn: A classifier function. Only used for the classifier guidance. + classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. + Returns: + A noise prediction model that accepts the noised data and the continuous time as the inputs. + """ + + def get_model_input_time(t_continuous): + """ + Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. + For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. + For continuous-time DPMs, we just use `t_continuous`. + """ + if noise_schedule.schedule == "discrete": + return (t_continuous - 1.0 / noise_schedule.total_N) * 1000.0 + else: + return t_continuous + + def noise_pred_fn(x, t_continuous, cond=None): + t_input = get_model_input_time(t_continuous) + if cond is None: + output = model(x, t_input, **model_kwargs) + else: + output = model(x, t_input, cond, **model_kwargs) + if model_type == "noise": + return output + elif model_type == "x_start": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + return (x - expand_dims(alpha_t, x.dim()) * output) / expand_dims(sigma_t, x.dim()) + elif model_type == "v": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + return expand_dims(alpha_t, x.dim()) * output + expand_dims(sigma_t, x.dim()) * x + elif model_type == "score": + sigma_t = noise_schedule.marginal_std(t_continuous) + return -expand_dims(sigma_t, x.dim()) * output + + def cond_grad_fn(x, t_input): + """ + Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). + """ + with torch.enable_grad(): + x_in = x.detach().requires_grad_(True) + log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) + return torch.autograd.grad(log_prob.sum(), x_in)[0] + + def model_fn(x, t_continuous): + """ + The noise predicition model function that is used for DPM-Solver. + """ + if guidance_type == "uncond": + return noise_pred_fn(x, t_continuous) + elif guidance_type == "classifier": + assert classifier_fn is not None + t_input = get_model_input_time(t_continuous) + cond_grad = cond_grad_fn(x, t_input) + sigma_t = noise_schedule.marginal_std(t_continuous) + noise = noise_pred_fn(x, t_continuous) + return noise - guidance_scale * expand_dims(sigma_t, x.dim()) * cond_grad + elif guidance_type == "classifier-free": + if guidance_scale == 1.0 or unconditional_condition is None: + return noise_pred_fn(x, t_continuous, cond=condition) + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) + return noise_uncond + guidance_scale * (noise - noise_uncond) + + assert model_type in ["noise", "x_start", "v", "score"] + assert guidance_type in ["uncond", "classifier", "classifier-free"] + return model_fn + + +class DPM_Solver: + def __init__( + self, + model_fn, + noise_schedule, + algorithm_type="dpmsolver++", + correcting_x0_fn=None, + correcting_xt_fn=None, + thresholding_max_val=1.0, + dynamic_thresholding_ratio=0.995, + ): + """Construct a DPM-Solver. + + We support both DPM-Solver (`algorithm_type="dpmsolver"`) and DPM-Solver++ (`algorithm_type="dpmsolver++"`). + + We also support the "dynamic thresholding" method in Imagen[1]. For pixel-space diffusion models, you + can set both `algorithm_type="dpmsolver++"` and `correcting_x0_fn="dynamic_thresholding"` to use the + dynamic thresholding. The "dynamic thresholding" can greatly improve the sample quality for pixel-space + DPMs with large guidance scales. Note that the thresholding method is **unsuitable** for latent-space + DPMs (such as stable-diffusion). + + To support advanced algorithms in image-to-image applications, we also support corrector functions for + both x0 and xt. + + Args: + model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]): + `` + def model_fn(x, t_continuous): + return noise + `` + The shape of `x` is `(batch_size, **shape)`, and the shape of `t_continuous` is `(batch_size,)`. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + algorithm_type: A `str`. Either "dpmsolver" or "dpmsolver++". + correcting_x0_fn: A `str` or a function with the following format: + ``` + def correcting_x0_fn(x0, t): + x0_new = ... + return x0_new + ``` + This function is to correct the outputs of the data prediction model at each sampling step. e.g., + ``` + x0_pred = data_pred_model(xt, t) + if correcting_x0_fn is not None: + x0_pred = correcting_x0_fn(x0_pred, t) + xt_1 = update(x0_pred, xt, t) + ``` + If `correcting_x0_fn="dynamic_thresholding"`, we use the dynamic thresholding proposed in Imagen[1]. + correcting_xt_fn: A function with the following format: + ``` + def correcting_xt_fn(xt, t, step): + x_new = ... + return x_new + ``` + This function is to correct the intermediate samples xt at each sampling step. e.g., + ``` + xt = ... + xt = correcting_xt_fn(xt, t, step) + ``` + thresholding_max_val: A `float`. The max value for thresholding. + Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`. + dynamic_thresholding_ratio: A `float`. The ratio for dynamic thresholding (see Imagen[1] for details). + Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`. + + [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, + Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models + with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b. + """ + self.model = lambda x, t: model_fn(x, t.expand((x.shape[0]))) + self.noise_schedule = noise_schedule + assert algorithm_type in ["dpmsolver", "dpmsolver++"] + self.algorithm_type = algorithm_type + if correcting_x0_fn == "dynamic_thresholding": + self.correcting_x0_fn = self.dynamic_thresholding_fn + else: + self.correcting_x0_fn = correcting_x0_fn + self.correcting_xt_fn = correcting_xt_fn + self.dynamic_thresholding_ratio = dynamic_thresholding_ratio + self.thresholding_max_val = thresholding_max_val + + def dynamic_thresholding_fn(self, x0, t): + """ + The dynamic thresholding method. + """ + dims = x0.dim() + p = self.dynamic_thresholding_ratio + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims) + x0 = torch.clamp(x0, -s, s) / s + return x0 + + def noise_prediction_fn(self, x, t): + """ + Return the noise prediction model. + """ + return self.model(x, t) + + def data_prediction_fn(self, x, t): + """ + Return the data prediction model (with corrector). + """ + noise = self.noise_prediction_fn(x, t) + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + x0 = (x - sigma_t * noise) / alpha_t + if self.correcting_x0_fn is not None: + x0 = self.correcting_x0_fn(x0, t) + return x0 + + def model_fn(self, x, t): + """ + Convert the model to the noise prediction model or the data prediction model. + """ + if self.algorithm_type == "dpmsolver++": + return self.data_prediction_fn(x, t) + else: + return self.noise_prediction_fn(x, t) + + def get_time_steps(self, skip_type, t_T, t_0, N, device): + """Compute the intermediate time steps for sampling. + + Args: + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + N: A `int`. The total number of the spacing of the time steps. + device: A torch device. + Returns: + A pytorch tensor of the time steps, with the shape (N + 1,). + """ + if skip_type == "logSNR": + lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) + lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) + logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) + return self.noise_schedule.inverse_lambda(logSNR_steps) + elif skip_type == "time_uniform": + return torch.linspace(t_T, t_0, N + 1).to(device) + elif skip_type == "time_quadratic": + t_order = 2 + return torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device) + else: + raise ValueError( + f"Unsupported skip_type {skip_type}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'" + ) + + def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): + """ + Get the order of each step for sampling by the singlestep DPM-Solver. + + We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast". + Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is: + - If order == 1: + We take `steps` of DPM-Solver-1 (i.e. DDIM). + - If order == 2: + - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of DPM-Solver-2. + - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If order == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2. + + ============================================ + Args: + order: A `int`. The max order for the solver (2 or 3). + steps: A `int`. The total number of function evaluations (NFE). + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + device: A torch device. + Returns: + orders: A list of the solver order of each step. + """ + if order == 3: + K = steps // 3 + 1 + if steps % 3 == 0: + orders = [ + 3, + ] * ( + K - 2 + ) + [2, 1] + elif steps % 3 == 1: + orders = [ + 3, + ] * ( + K - 1 + ) + [1] + else: + orders = [ + 3, + ] * ( + K - 1 + ) + [2] + elif order == 2: + if steps % 2 == 0: + K = steps // 2 + orders = [ + 2, + ] * K + else: + K = steps // 2 + 1 + orders = [ + 2, + ] * ( + K - 1 + ) + [1] + elif order == 1: + K = 1 + orders = [ + 1, + ] * steps + else: + raise ValueError("'order' must be '1' or '2' or '3'.") + if skip_type == "logSNR": + # To reproduce the results in DPM-Solver paper + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) + else: + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[ + torch.cumsum( + torch.tensor( + [ + 0, + ] + + orders + ), + 0, + ).to(device) + ] + return timesteps_outer, orders + + def denoise_to_zero_fn(self, x, s): + """ + Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. + """ + return self.data_prediction_fn(x, s) + + def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False): + """ + DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = sigma_t / sigma_s * x - alpha_t * phi_1 * model_s + else: + phi_1 = torch.expm1(h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = torch.exp(log_alpha_t - log_alpha_s) * x - (sigma_t * phi_1) * model_s + return (x_t, {"model_s": model_s}) if return_intermediate else x_t + + def singlestep_dpm_solver_second_update( + self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type="dpmsolver" + ): + """ + Singlestep solver DPM-Solver-2 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + r1: A `float`. The hyperparameter of the second-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ["dpmsolver", "taylor"]: + raise ValueError(f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}") + if r1 is None: + r1 = 0.5 + ns = self.noise_schedule + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + s1 = ns.inverse_lambda(lambda_s1) + log_alpha_s, log_alpha_s1, log_alpha_t = ( + ns.marginal_log_mean_coeff(s), + ns.marginal_log_mean_coeff(s1), + ns.marginal_log_mean_coeff(t), + ) + sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t) + alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) + + if self.algorithm_type == "dpmsolver++": + phi_11 = torch.expm1(-r1 * h) + phi_1 = torch.expm1(-h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s + model_s1 = self.model_fn(x_s1, s1) + if solver_type == "dpmsolver": + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + - (0.5 / r1) * (alpha_t * phi_1) * (model_s1 - model_s) + ) + elif solver_type == "taylor": + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (1.0 / r1) * (alpha_t * (phi_1 / h + 1.0)) * (model_s1 - model_s) + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_1 = torch.expm1(h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = torch.exp(log_alpha_s1 - log_alpha_s) * x - (sigma_s1 * phi_11) * model_s + model_s1 = self.model_fn(x_s1, s1) + if solver_type == "dpmsolver": + x_t = ( + torch.exp(log_alpha_t - log_alpha_s) * x + - (sigma_t * phi_1) * model_s + - (0.5 / r1) * (sigma_t * phi_1) * (model_s1 - model_s) + ) + elif solver_type == "taylor": + x_t = ( + torch.exp(log_alpha_t - log_alpha_s) * x + - (sigma_t * phi_1) * model_s + - (1.0 / r1) * (sigma_t * (phi_1 / h - 1.0)) * (model_s1 - model_s) + ) + if return_intermediate: + return x_t, {"model_s": model_s, "model_s1": model_s1} + else: + return x_t + + def singlestep_dpm_solver_third_update( + self, + x, + s, + t, + r1=1.0 / 3.0, + r2=2.0 / 3.0, + model_s=None, + model_s1=None, + return_intermediate=False, + solver_type="dpmsolver", + ): + """ + Singlestep solver DPM-Solver-3 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + r1: A `float`. The hyperparameter of the third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`). + If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ["dpmsolver", "taylor"]: + raise ValueError(f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}") + if r1 is None: + r1 = 1.0 / 3.0 + if r2 is None: + r2 = 2.0 / 3.0 + ns = self.noise_schedule + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + lambda_s2 = lambda_s + r2 * h + s1 = ns.inverse_lambda(lambda_s1) + s2 = ns.inverse_lambda(lambda_s2) + log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ( + ns.marginal_log_mean_coeff(s), + ns.marginal_log_mean_coeff(s1), + ns.marginal_log_mean_coeff(s2), + ns.marginal_log_mean_coeff(t), + ) + sigma_s, sigma_s1, sigma_s2, sigma_t = ( + ns.marginal_std(s), + ns.marginal_std(s1), + ns.marginal_std(s2), + ns.marginal_std(t), + ) + alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t) + + if self.algorithm_type == "dpmsolver++": + phi_11 = torch.expm1(-r1 * h) + phi_12 = torch.expm1(-r2 * h) + phi_1 = torch.expm1(-h) + phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.0 + phi_2 = phi_1 / h + 1.0 + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + (sigma_s2 / sigma_s) * x + - (alpha_s2 * phi_12) * model_s + + r2 / r1 * (alpha_s2 * phi_22) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == "dpmsolver": + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (1.0 / r2) * (alpha_t * phi_2) * (model_s2 - model_s) + ) + elif solver_type == "taylor": + D1_0 = (1.0 / r1) * (model_s1 - model_s) + D1_1 = (1.0 / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (alpha_t * phi_2) * D1 + - (alpha_t * phi_3) * D2 + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_12 = torch.expm1(r2 * h) + phi_1 = torch.expm1(h) + phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.0 + phi_2 = phi_1 / h - 1.0 + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = (torch.exp(log_alpha_s1 - log_alpha_s)) * x - (sigma_s1 * phi_11) * model_s + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + (torch.exp(log_alpha_s2 - log_alpha_s)) * x + - (sigma_s2 * phi_12) * model_s + - r2 / r1 * (sigma_s2 * phi_22) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == "dpmsolver": + x_t = ( + (torch.exp(log_alpha_t - log_alpha_s)) * x + - (sigma_t * phi_1) * model_s + - (1.0 / r2) * (sigma_t * phi_2) * (model_s2 - model_s) + ) + elif solver_type == "taylor": + D1_0 = (1.0 / r1) * (model_s1 - model_s) + D1_1 = (1.0 / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + (torch.exp(log_alpha_t - log_alpha_s)) * x + - (sigma_t * phi_1) * model_s + - (sigma_t * phi_2) * D1 + - (sigma_t * phi_3) * D2 + ) + + if return_intermediate: + return x_t, {"model_s": model_s, "model_s1": model_s1, "model_s2": model_s2} + else: + return x_t + + def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"): + """ + Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ["dpmsolver", "taylor"]: + raise ValueError(f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}") + ns = self.noise_schedule + model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1] + t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1] + lambda_prev_1, lambda_prev_0, lambda_t = ( + ns.marginal_lambda(t_prev_1), + ns.marginal_lambda(t_prev_0), + ns.marginal_lambda(t), + ) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0 = h_0 / h + D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1) + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + if solver_type == "dpmsolver": + x_t = (sigma_t / sigma_prev_0) * x - (alpha_t * phi_1) * model_prev_0 - 0.5 * (alpha_t * phi_1) * D1_0 + elif solver_type == "taylor": + x_t = ( + (sigma_t / sigma_prev_0) * x + - (alpha_t * phi_1) * model_prev_0 + + (alpha_t * (phi_1 / h + 1.0)) * D1_0 + ) + else: + phi_1 = torch.expm1(h) + if solver_type == "dpmsolver": + x_t = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - 0.5 * (sigma_t * phi_1) * D1_0 + ) + elif solver_type == "taylor": + x_t = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - (sigma_t * (phi_1 / h - 1.0)) * D1_0 + ) + return x_t + + def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"): + """ + Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + model_prev_2, model_prev_1, model_prev_0 = model_prev_list + t_prev_2, t_prev_1, t_prev_0 = t_prev_list + lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ( + ns.marginal_lambda(t_prev_2), + ns.marginal_lambda(t_prev_1), + ns.marginal_lambda(t_prev_0), + ns.marginal_lambda(t), + ) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_1 = lambda_prev_1 - lambda_prev_2 + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0, r1 = h_0 / h, h_1 / h + D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1) + D1_1 = (1.0 / r1) * (model_prev_1 - model_prev_2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + phi_2 = phi_1 / h + 1.0 + phi_3 = phi_2 / h - 0.5 + return ( + (sigma_t / sigma_prev_0) * x + - (alpha_t * phi_1) * model_prev_0 + + (alpha_t * phi_2) * D1 + - (alpha_t * phi_3) * D2 + ) + else: + phi_1 = torch.expm1(h) + phi_2 = phi_1 / h - 1.0 + phi_3 = phi_2 / h - 0.5 + return ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - (sigma_t * phi_2) * D1 + - (sigma_t * phi_3) * D2 + ) + + def singlestep_dpm_solver_update( + self, x, s, t, order, return_intermediate=False, solver_type="dpmsolver", r1=None, r2=None + ): + """ + Singlestep DPM-Solver with the order `order` from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + r1: A `float`. The hyperparameter of the second-order or third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) + elif order == 2: + return self.singlestep_dpm_solver_second_update( + x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1 + ) + elif order == 3: + return self.singlestep_dpm_solver_third_update( + x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1, r2=r2 + ) + else: + raise ValueError(f"Solver order must be 1 or 2 or 3, got {order}") + + def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type="dpmsolver"): + """ + Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1]) + elif order == 2: + return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + elif order == 3: + return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + else: + raise ValueError(f"Solver order must be 1 or 2 or 3, got {order}") + + def dpm_solver_adaptive( + self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type="dpmsolver" + ): + """ + The adaptive step size solver based on singlestep DPM-Solver. + + Args: + x: A pytorch tensor. The initial value at time `t_T`. + order: A `int`. The (higher) order of the solver. We only support order == 2 or 3. + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + h_init: A `float`. The initial step size (for logSNR). + atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1]. + rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05. + theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1]. + t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the + current time and `t_0` is less than `t_err`. The default setting is 1e-5. + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_0: A pytorch tensor. The approximated solution at time `t_0`. + + [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021. + """ + ns = self.noise_schedule + s = t_T * torch.ones((1,)).to(x) + lambda_s = ns.marginal_lambda(s) + lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x)) + h = h_init * torch.ones_like(s).to(x) + x_prev = x + nfe = 0 + if order == 2: + r1 = 0.5 + lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update( + x, s, t, r1=r1, solver_type=solver_type, **kwargs + ) + elif order == 3: + r1, r2 = 1.0 / 3.0, 2.0 / 3.0 + lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update( + x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type + ) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update( + x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs + ) + else: + raise ValueError(f"For adaptive step size solver, order must be 2 or 3, got {order}") + while torch.abs((s - t_0)).mean() > t_err: + t = ns.inverse_lambda(lambda_s + h) + x_lower, lower_noise_kwargs = lower_update(x, s, t) + x_higher = higher_update(x, s, t, **lower_noise_kwargs) + delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev))) + norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) + E = norm_fn((x_higher - x_lower) / delta).max() + if torch.all(E <= 1.0): + x = x_higher + s = t + x_prev = x_lower + lambda_s = ns.marginal_lambda(s) + h = torch.min(theta * h * torch.float_power(E, -1.0 / order).float(), lambda_0 - lambda_s) + nfe += order + print("adaptive solver nfe", nfe) + return x + + def add_noise(self, x, t, noise=None): + """ + Compute the noised input xt = alpha_t * x + sigma_t * noise. + + Args: + x: A `torch.Tensor` with shape `(batch_size, *shape)`. + t: A `torch.Tensor` with shape `(t_size,)`. + Returns: + xt with shape `(t_size, batch_size, *shape)`. + """ + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + if noise is None: + noise = torch.randn((t.shape[0], *x.shape), device=x.device) + x = x.reshape((-1, *x.shape)) + xt = expand_dims(alpha_t, x.dim()) * x + expand_dims(sigma_t, x.dim()) * noise + return xt.squeeze(0) if t.shape[0] == 1 else xt + + def inverse( + self, + x, + steps=20, + t_start=None, + t_end=None, + order=2, + skip_type="time_uniform", + method="multistep", + lower_order_final=True, + denoise_to_zero=False, + solver_type="dpmsolver", + atol=0.0078, + rtol=0.05, + return_intermediate=False, + ): + """ + Inverse the sample `x` from time `t_start` to `t_end` by DPM-Solver. + For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total time steps during training. + """ + t_0 = 1.0 / self.noise_schedule.total_N if t_start is None else t_start + t_T = self.noise_schedule.T if t_end is None else t_end + assert ( + t_0 > 0 and t_T > 0 + ), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" + return self.sample( + x, + steps=steps, + t_start=t_0, + t_end=t_T, + order=order, + skip_type=skip_type, + method=method, + lower_order_final=lower_order_final, + denoise_to_zero=denoise_to_zero, + solver_type=solver_type, + atol=atol, + rtol=rtol, + return_intermediate=return_intermediate, + ) + + def sample( + self, + x, + steps=20, + t_start=None, + t_end=None, + order=2, + skip_type="time_uniform", + method="multistep", + lower_order_final=True, + denoise_to_zero=False, + solver_type="dpmsolver", + atol=0.0078, + rtol=0.05, + return_intermediate=False, + progress=True, + ): + """ + Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. + + ===================================================== + + We support the following algorithms for both noise prediction model and data prediction model: + - 'singlestep': + Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver. + We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps). + The total number of function evaluations (NFE) == `steps`. + Given a fixed NFE == `steps`, the sampling procedure is: + - If `order` == 1: + - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2. + - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If `order` == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2. + - 'multistep': + Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`. + We initialize the first `order` values by lower order multistep solvers. + Given a fixed NFE == `steps`, the sampling procedure is: + Denote K = steps. + - If `order` == 1: + - We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2. + - If `order` == 3: + - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3. + - 'singlestep_fixed': + Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3). + We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE. + - 'adaptive': + Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper). + We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`. + You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs + (NFE) and the sample quality. + - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2. + - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3. + + ===================================================== + + Some advices for choosing the algorithm: + - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs: + Use singlestep DPM-Solver or DPM-Solver++ ("DPM-Solver-fast" in the paper) with `order = 3`. + e.g., DPM-Solver: + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + e.g., DPM-Solver++: + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + - For **guided sampling with large guidance scale** by DPMs: + Use multistep DPM-Solver with `algorithm_type="dpmsolver++"` and `order = 2`. + e.g. + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2, + skip_type='time_uniform', method='multistep') + + We support three types of `skip_type`: + - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images** + - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**. + - 'time_quadratic': quadratic time for the time steps. + + ===================================================== + Args: + x: A pytorch tensor. The initial value at time `t_start` + e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution. + steps: A `int`. The total number of function evaluations (NFE). + t_start: A `float`. The starting time of the sampling. + If `T` is None, we use self.noise_schedule.T (default is 1.0). + t_end: A `float`. The ending time of the sampling. + If `t_end` is None, we use 1. / self.noise_schedule.total_N. + e.g. if total_N == 1000, we have `t_end` == 1e-3. + For discrete-time DPMs: + - We recommend `t_end` == 1. / self.noise_schedule.total_N. + For continuous-time DPMs: + - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15. + order: A `int`. The order of DPM-Solver. + skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'. + method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'. + denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step. + Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1). + + This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and + score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID + for diffusion models sampling by diffusion SDEs for low-resolutional images + (such as CIFAR-10). However, we observed that such trick does not matter for + high-resolutional images. As it needs an additional NFE, we do not recommend + it for high-resolutional images. + lower_order_final: A `bool`. Whether to use lower order solvers at the final steps. + Only valid for `method=multistep` and `steps < 15`. We empirically find that + this trick is a key to stabilizing the sampling by DPM-Solver with very few steps + (especially for steps <= 10). So we recommend to set it to be `True`. + solver_type: A `str`. The taylor expansion type for the solver. `dpmsolver` or `taylor`. We recommend `dpmsolver`. + atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + return_intermediate: A `bool`. Whether to save the xt at each step. + When set to `True`, method returns a tuple (x0, intermediates); when set to False, method returns only x0. + Returns: + x_end: A pytorch tensor. The approximated solution at time `t_end`. + + """ + t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end + t_T = self.noise_schedule.T if t_start is None else t_start + assert ( + t_0 > 0 and t_T > 0 + ), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" + if return_intermediate: + assert method in [ + "multistep", + "singlestep", + "singlestep_fixed", + ], "Cannot use adaptive solver when saving intermediate values" + if self.correcting_xt_fn is not None: + assert method in [ + "multistep", + "singlestep", + "singlestep_fixed", + ], "Cannot use adaptive solver when correcting_xt_fn is not None" + device = x.device + intermediates = [] + with torch.no_grad(): + if method == "adaptive": + x = self.dpm_solver_adaptive( + x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type + ) + elif method == "multistep": + assert steps >= order + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + assert timesteps.shape[0] - 1 == steps + # Init the initial values. + step = 0 + t = timesteps[step] + t_prev_list = [t] + model_prev_list = [self.model_fn(x, t)] + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + # Init the first `order` values by lower order multistep DPM-Solver. + for step in range(1, order): + t = timesteps[step] + x = self.multistep_dpm_solver_update( + x, model_prev_list, t_prev_list, t, step, solver_type=solver_type + ) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + t_prev_list.append(t) + model_prev_list.append(self.model_fn(x, t)) + # Compute the remaining values by `order`-th order multistep DPM-Solver. + progress_fn = tqdm if progress else lambda x: x + for step in progress_fn(range(order, steps + 1)): + t = timesteps[step] + # We only use lower order for steps < 10 + if lower_order_final and steps < 10: + step_order = min(order, steps + 1 - step) + else: + step_order = order + x = self.multistep_dpm_solver_update( + x, model_prev_list, t_prev_list, t, step_order, solver_type=solver_type + ) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + for i in range(order - 1): + t_prev_list[i] = t_prev_list[i + 1] + model_prev_list[i] = model_prev_list[i + 1] + t_prev_list[-1] = t + # We do not need to evaluate the final model value. + if step < steps: + model_prev_list[-1] = self.model_fn(x, t) + elif method in ["singlestep", "singlestep_fixed"]: + if method == "singlestep": + timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver( + steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device + ) + elif method == "singlestep_fixed": + K = steps // order + orders = [ + order, + ] * K + timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) + for step, order in enumerate(orders): + s, t = timesteps_outer[step], timesteps_outer[step + 1] + timesteps_inner = self.get_time_steps( + skip_type=skip_type, t_T=s.item(), t_0=t.item(), N=order, device=device + ) + lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner) + h = lambda_inner[-1] - lambda_inner[0] + r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h + r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h + x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + else: + raise ValueError(f"Got wrong method {method}") + if denoise_to_zero: + t = torch.ones((1,)).to(device) * t_0 + x = self.denoise_to_zero_fn(x, t) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step + 1) + if return_intermediate: + intermediates.append(x) + return (x, intermediates) if return_intermediate else x + + +############################################################# +# other utility functions +############################################################# + + +def interpolate_fn(x, xp, yp): + """ + A piecewise linear function y = f(x), using xp and yp as keypoints. + We implement f(x) in a differentiable way (i.e. applicable for autograd). + The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) + + Args: + x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). + xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. + yp: PyTorch tensor with shape [C, K]. + Returns: + The function values f(x), with shape [N, C]. + """ + N, K = x.shape[0], xp.shape[1] + all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) + sorted_all_x, x_indices = torch.sort(all_x, dim=2) + x_idx = torch.argmin(x_indices, dim=2) + cand_start_idx = x_idx - 1 + start_idx = torch.where( + torch.eq(x_idx, 0), + torch.tensor(1, device=x.device), + torch.where( + torch.eq(x_idx, K), + torch.tensor(K - 2, device=x.device), + cand_start_idx, + ), + ) + end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) + start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) + end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) + start_idx2 = torch.where( + torch.eq(x_idx, 0), + torch.tensor(0, device=x.device), + torch.where( + torch.eq(x_idx, K), + torch.tensor(K - 2, device=x.device), + cand_start_idx, + ), + ) + y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) + start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) + end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) + return start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) + + +def expand_dims(v, dims): + """ + Expand the tensor `v` to the dim `dims`. + + Args: + `v`: a PyTorch tensor with shape [N]. + `dim`: a `int`. + Returns: + a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. + """ + return v[(...,) + (None,) * (dims - 1)] + + +def DPMS( + model, + condition, + uncondition, + cfg_scale, + model_type="noise", + noise_schedule="linear", + guidance_type="classifier-free", + model_kwargs=None, + diffusion_steps=1000, +): + if model_kwargs is None: + model_kwargs = {} + betas = torch.tensor(get_named_beta_schedule(noise_schedule, diffusion_steps)) + + ## 1. Define the noise schedule. + noise_schedule = NoiseScheduleVP(schedule="discrete", betas=betas) + + ## 2. Convert your discrete-time `model` to the continuous-time + ## noise prediction model. Here is an example for a diffusion model + ## `model` with the noise prediction type ("noise") . + model_fn = model_wrapper( + model, + noise_schedule, + model_type=model_type, + model_kwargs=model_kwargs, + guidance_type=guidance_type, + condition=condition, + unconditional_condition=uncondition, + guidance_scale=cfg_scale, + ) + ## 3. Define dpm-solver and sample by multistep DPM-Solver. + return DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") diff --git a/src/videogen_hub/pipelines/opensora/opensora/schedulers/iddpm/__init__.py b/src/videogen_hub/pipelines/opensora/opensora/schedulers/iddpm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..600ec1df1f0d21502734b24c0823415c1c71b66f --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/schedulers/iddpm/__init__.py @@ -0,0 +1,103 @@ +from functools import partial + +import torch + + +from videogen_hub.pipelines.opensora.opensora.schedulers.iddpm import gaussian_diffusion as gd +from videogen_hub.pipelines.opensora.opensora.schedulers.iddpm.respace import SpacedDiffusion, space_timesteps +from videogen_hub.pipelines.opensora.opensora.schedulers.iddpm.speed import SpeeDiffusion +from videogen_hub.pipelines.opensora.opensora.registry import SCHEDULERS + + +@SCHEDULERS.register_module("iddpm") +class IDDPM(SpacedDiffusion): + def __init__( + self, + num_sampling_steps=None, + timestep_respacing=None, + noise_schedule="linear", + use_kl=False, + sigma_small=False, + predict_xstart=False, + learn_sigma=True, + rescale_learned_sigmas=False, + diffusion_steps=1000, + cfg_scale=4.0, + cfg_channel=None, + ): + betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) + if use_kl: + loss_type = gd.LossType.RESCALED_KL + elif rescale_learned_sigmas: + loss_type = gd.LossType.RESCALED_MSE + else: + loss_type = gd.LossType.MSE + if num_sampling_steps is not None: + assert timestep_respacing is None + timestep_respacing = str(num_sampling_steps) + if timestep_respacing is None or timestep_respacing == "": + timestep_respacing = [diffusion_steps] + super().__init__( + use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), + betas=betas, + model_mean_type=(gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X), + model_var_type=( + (gd.ModelVarType.FIXED_LARGE if not sigma_small else gd.ModelVarType.FIXED_SMALL) + if not learn_sigma + else gd.ModelVarType.LEARNED_RANGE + ), + loss_type=loss_type, + ) + + self.cfg_scale = cfg_scale + self.cfg_channel = cfg_channel + + def sample( + self, + model, + text_encoder, + z, + prompts, + device, + additional_args=None, + mask=None, + progress=True, + ): + n = len(prompts) + z = torch.cat([z, z], 0) + model_args = text_encoder.encode(prompts) + y_null = text_encoder.null(n) + model_args["y"] = torch.cat([model_args["y"], y_null], 0) + if additional_args is not None: + model_args.update(additional_args) + forward = partial(forward_with_cfg, model, cfg_scale=self.cfg_scale, cfg_channel=self.cfg_channel) + samples = self.p_sample_loop( + forward, + z.shape, + z, + clip_denoised=False, + model_kwargs=model_args, + progress=progress, + device=device, + mask=mask, + ) + samples, _ = samples.chunk(2, dim=0) + return samples + + +def forward_with_cfg(model, x, timestep, y, cfg_scale, cfg_channel=None, **kwargs): + # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + half = x[: len(x) // 2] + combined = torch.cat([half, half], dim=0) + if "x_mask" in kwargs and kwargs["x_mask"] is not None: + if len(kwargs["x_mask"]) != len(x): + kwargs["x_mask"] = torch.cat([kwargs["x_mask"], kwargs["x_mask"]], dim=0) + model_out = model.forward(combined, timestep, y, **kwargs) + model_out = model_out["x"] if isinstance(model_out, dict) else model_out + if cfg_channel is None: + cfg_channel = model_out.shape[1] // 2 + eps, rest = model_out[:, :cfg_channel], model_out[:, cfg_channel:] + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + return torch.cat([eps, rest], dim=1) diff --git a/src/videogen_hub/pipelines/opensora/opensora/schedulers/iddpm/diffusion_utils.py b/src/videogen_hub/pipelines/opensora/opensora/schedulers/iddpm/diffusion_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0e15d3c5d76ff16d62778299520a5c357eea7784 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/schedulers/iddpm/diffusion_utils.py @@ -0,0 +1,89 @@ +# Adapted from DiT + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# DiT: https://github.com/facebookresearch/DiT/tree/main +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# -------------------------------------------------------- + + +import numpy as np +import torch + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)] + + return 0.5 * ( + -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) + + +def approx_standard_normal_cdf(x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. + """ + return 0.5 * (1.0 + torch.tanh(np.sqrt(2.0 / torch.pi) * (x + 0.044715 * torch.pow(x, 3)))) + + +def continuous_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a continuous Gaussian distribution. + :param x: the targets + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + centered_x = x - means + inv_stdv = torch.exp(-log_scales) + normalized_x = centered_x * inv_stdv + log_probs = torch.distributions.Normal(torch.zeros_like(x), torch.ones_like(x)).log_prob(normalized_x) + return log_probs + + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. + :param x: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + assert x.shape == means.shape == log_scales.shape + centered_x = x - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = approx_standard_normal_cdf(min_in) + log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = torch.where( + x < -0.999, + log_cdf_plus, + torch.where(x > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == x.shape + return log_probs diff --git a/src/videogen_hub/pipelines/opensora/opensora/schedulers/iddpm/gaussian_diffusion.py b/src/videogen_hub/pipelines/opensora/opensora/schedulers/iddpm/gaussian_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..93505508c5c181228f2c0c09a8c45614decdb1eb --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/schedulers/iddpm/gaussian_diffusion.py @@ -0,0 +1,904 @@ +# Adapted from DiT + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# DiT: https://github.com/facebookresearch/DiT/tree/main +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# -------------------------------------------------------- + +import enum +from typing import Callable, List + +import numpy as np +import torch +from einops import rearrange + +from videogen_hub.pipelines.opensora.opensora.schedulers.iddpm.diffusion_utils import discretized_gaussian_log_likelihood, normal_kl + + +def mean_flat(tensor: torch.Tensor, mask=None): + """ + Take the mean over all non-batch dimensions. + """ + if mask is None: + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + else: + assert tensor.dim() == 5 + assert tensor.shape[2] == mask.shape[1] + tensor = rearrange(tensor, "b c t h w -> b t (c h w)") + denom = mask.sum(dim=1) * tensor.shape[-1] + loss = (tensor * mask.unsqueeze(2)).sum(dim=1).sum(dim=1) / denom + return loss + + +class ModelMeanType(enum.Enum): + """ + Which type of output the model predicts. + """ + + PREVIOUS_X = enum.auto() # the model predicts x_{t-1} + START_X = enum.auto() # the model predicts x_0 + EPSILON = enum.auto() # the model predicts epsilon + + +class ModelVarType(enum.Enum): + """ + What is used as the model's output variance. + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. + """ + + LEARNED = enum.auto() + FIXED_SMALL = enum.auto() + FIXED_LARGE = enum.auto() + LEARNED_RANGE = enum.auto() + + +class LossType(enum.Enum): + MSE = enum.auto() # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = enum.auto() # use raw MSE loss (with RESCALED_KL when learning variances) + KL = enum.auto() # use the variational lower-bound + RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + + +def _warmup_beta(beta_start: float, beta_end: float, num_diffusion_timesteps: int, warmup_frac: float) -> torch.Tensor: + betas = beta_end * torch.ones(num_diffusion_timesteps, dtype=torch.float64) + warmup_time = int(num_diffusion_timesteps * warmup_frac) + betas[:warmup_time] = torch.linspace(beta_start, beta_end, warmup_time, dtype=torch.float64) + return betas + + +def get_beta_schedule( + beta_schedule: str, *, beta_start: float, beta_end: float, num_diffusion_timesteps: int +) -> torch.Tensor: + """ + This is the deprecated API for creating beta schedules. + See get_named_beta_schedule() for the new library of schedules. + """ + if beta_schedule == "quad": + betas = ( + torch.linspace( + beta_start**0.5, + beta_end**0.5, + num_diffusion_timesteps, + dtype=torch.float64, + ) + ** 2 + ) + elif beta_schedule == "linear": + betas = torch.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=torch.float64) + elif beta_schedule == "warmup10": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) + elif beta_schedule == "warmup50": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) + elif beta_schedule == "const": + betas = beta_end * torch.ones(num_diffusion_timesteps, dtype=torch.float64) + elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 + betas = 1.0 / torch.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=torch.float64) + else: + raise NotImplementedError(beta_schedule) + assert betas.shape == (num_diffusion_timesteps,) + return betas + + +def betas_for_alpha_bar(num_diffusion_timesteps: int, alpha_bar: Callable, max_beta: float = 0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return torch.DoubleTensor(betas) + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + """ + Get a pre-defined beta schedule for the given name. + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. + scale = 1000 / num_diffusion_timesteps + return get_beta_schedule( + "linear", + beta_start=scale * 0.0001, + beta_end=scale * 0.02, + num_diffusion_timesteps=num_diffusion_timesteps, + ) + elif schedule_name == "squaredcos_cap_v2": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: torch.cos((t + 0.008) / 1.008 * torch.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +class GaussianDiffusion: + """ + Utilities for training and sampling diffusion models. + Original ported from this codebase: + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + """ + + def __init__( + self, + *, + betas: torch.Tensor, + model_mean_type: str, + model_var_type: str, + loss_type: str, + device: str = "cuda", + ): + if device == "cuda": + device = torch.device(f"cuda:{torch.cuda.current_device()}") + elif device == "cpu": + device = torch.device("cpu") + else: + raise ValueError(f"Unknown device: {device}") + self.device = device + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + self.loss_type = loss_type + + # Use float64 for accuracy. + self.betas = betas.to(self.device) + assert len(self.betas.shape) == 1, "betas must be 1-D" + assert (self.betas > 0).all() and (self.betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = torch.cat([torch.tensor([1.0], device=self.device), self.alphas_cumprod[:-1]]) + self.alphas_cumprod_next = torch.cat([self.alphas_cumprod[1:], torch.tensor([0.0], device=self.device)]) + assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = torch.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = ( + torch.log(torch.cat([self.posterior_variance[1].unsqueeze(0), self.posterior_variance[1:]])) + if len(self.posterior_variance) > 1 + else torch.DoubleTensor([]) + ) + + self.posterior_mean_coef1 = self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - self.alphas_cumprod) + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + In other words, sample from q(x_t | x_0). + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. + """ + if noise is None: + noise = torch.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. + """ + if model_kwargs is None: + model_kwargs = {} + + B, C = x.shape[:2] + assert t.shape == (B,) + model_output = model(x, t, **model_kwargs) + if isinstance(model_output, tuple): + model_output, extra = model_output + else: + extra = None + + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + assert model_output.shape == (B, C * 2, *x.shape[2:]) + model_output, model_var_values = torch.split(model_output, C, dim=1) + min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) + max_log = _extract_into_tensor(torch.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = torch.exp(model_log_variance) + else: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. + ModelVarType.FIXED_LARGE: ( + torch.cat(self.posterior_variance[1].unsqueeze(0), self.betas[1:]), + torch.log(torch.cat(self.posterior_variance[1].unsqueeze(0), self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)) + model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) + + assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + "extra": extra, + } + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + ) + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). + """ + gradient = cond_fn(x, t, **model_kwargs) + new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() + return new_mean + + def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + See condition_mean() for details on cond_fn. + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). + """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t) + return out + + def p_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + mask=None, + ): + """ + Sample x_{t-1} from the model at the given timestep. + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. + """ + if mask is not None: + if mask.shape[0] != x.shape[0]: + mask = mask.repeat(2, 1) # HACK + mask_t = (mask * len(self.betas)).to(torch.int) + + # x0: copy unchanged x values + # x_noise: add noise to x values + x0 = x.clone() + x_noise = x0 * _extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) + torch.randn_like( + x + ) * _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) + + # active noise addition + # WARNING: this is a hacky implementation + mask_t_equall = (mask_t == t.unsqueeze(1))[:, None, :, None, None] + x = torch.where(mask_t_equall, x_noise, x0) + + # create x_mask + mask_t_upper = (mask_t > t.unsqueeze(1))[:, None, :, None, None] + batch_size = x.shape[0] + model_kwargs["x_mask"] = mask_t_upper.reshape(batch_size, -1).to(torch.bool) + + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + noise = torch.randn_like(x) + nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0 + if cond_fn is not None: + out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) + sample = out["mean"] + nonzero_mask * torch.exp(0.5 * out["log_variance"]) * noise + + if mask is not None: + mask_t_lower = (mask_t < t.unsqueeze(1))[:, None, :, None, None] + sample = torch.where(mask_t_lower, x0, sample) + + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + mask=None, + ): + """ + Generate samples from the model. + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). + :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. + :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :return: a non-differentiable batch of samples. + """ + final = None + for sample in self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + mask=mask, + ): + final = sample + return final["sample"] + + def p_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + mask=None, + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + Arguments are the same as p_sample_loop(). + Returns a generator over dicts, where each dict is the return value of + p_sample(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = torch.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = torch.tensor([i] * shape[0], device=device) + with torch.no_grad(): + out = self.p_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + mask=mask, + ) + yield out + img = out["sample"] + + def ddim_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t-1} from the model using DDIM. + Same usage as p_sample(). + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + sigma = eta * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * torch.sqrt(1 - alpha_bar / alpha_bar_prev) + # Equation 12. + noise = torch.randn_like(x) + mean_pred = out["pred_xstart"] * torch.sqrt(alpha_bar_prev) + torch.sqrt(1 - alpha_bar_prev - sigma**2) * eps + nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def ddim_reverse_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t+1} from the model using DDIM reverse ODE. + """ + assert eta == 0.0, "Reverse ODE only for deterministic path" + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"] + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) + alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) + + # Equation 12. reversed + mean_pred = out["pred_xstart"] * torch.sqrt(alpha_bar_next) + torch.sqrt(1 - alpha_bar_next) * eps + + return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} + + def ddim_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Generate samples from the model using DDIM. + Same usage as p_sample_loop(). + """ + final = None + for sample in self.ddim_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, + ): + final = sample + return final["sample"] + + def ddim_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Use DDIM to sample from the model and yield intermediate samples from + each timestep of DDIM. + Same usage as p_sample_loop_progressive(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = torch.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = torch.tensor([i] * shape[0], device=device) + with torch.no_grad(): + out = self.ddim_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + eta=eta, + ) + yield out + img = out["sample"] + + def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None, mask=None): + """ + Get a term for the variational lower-bound. + The resulting units are bits (rather than nats, as one might expect). + This allows for comparison to other papers. + :return: a dict with the following keys: + - 'output': a shape [N] tensor of NLLs or KLs. + - 'pred_xstart': the x_0 predictions. + """ + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t) + out = self.p_mean_variance(model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs) + kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]) + kl = mean_flat(kl, mask=mask) / np.log(2.0) + + decoder_nll = -discretized_gaussian_log_likelihood( + x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] + ) + assert decoder_nll.shape == x_start.shape + decoder_nll = mean_flat(decoder_nll, mask=mask) / np.log(2.0) + + # At the first timestep return the decoder NLL, + # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) + output = torch.where((t == 0), decoder_nll, kl) + return {"output": output, "pred_xstart": out["pred_xstart"]} + + def training_losses(self, model, x_start, model_kwargs=None, noise=None, mask=None, weights=None, t=None): + """ + Compute training losses for a single timestep. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. + """ + # sample timestep + t = torch.randint(0, self.num_timesteps, (x_start.shape[0],), device=x_start.device) + + if model_kwargs is None: + model_kwargs = {} + if noise is None: + noise = torch.randn_like(x_start) + x_t = self.q_sample(x_start, t, noise=noise) + if mask is not None: + t0 = torch.zeros_like(t) + x_t0 = self.q_sample(x_start, t0, noise=noise) + x_t = torch.where(mask[:, None, :, None, None], x_t, x_t0) + + terms = {} + + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + assert mask is None, "mask not supported for KL loss" + terms["loss"] = self._vb_terms_bpd( + model=model, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + model_kwargs=model_kwargs, + )["output"] + if self.loss_type == LossType.RESCALED_KL: + terms["loss"] *= self.num_timesteps + elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + model_output = model(x_t, t, **model_kwargs) + + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, C = x_t.shape[:2] + assert model_output.shape == (B, C * 2, *x_t.shape[2:]) + model_output, model_var_values = torch.split(model_output, C, dim=1) + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = torch.cat([model_output.detach(), model_var_values], dim=1) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + mask=mask, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. + terms["vb"] *= self.num_timesteps / 1000.0 + + target = { + ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0], + ModelMeanType.START_X: x_start, + ModelMeanType.EPSILON: noise, + }[self.model_mean_type] + assert model_output.shape == target.shape == x_start.shape + if weights is None: + terms["mse"] = mean_flat((target - model_output) ** 2, mask=mask) + else: + weight = _extract_into_tensor(weights, t, target.shape) + terms["mse"] = mean_flat(weight * (target - model_output) ** 2, mask=mask) + if "vb" in terms: + terms["loss"] = terms["mse"] + terms["vb"] + else: + terms["loss"] = terms["mse"] + else: + raise NotImplementedError(self.loss_type) + + return terms + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) + return mean_flat(kl_prior) / np.log(2.0) + + def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): + """ + Compute the entire variational lower-bound, measured in bits-per-dim, + as well as other related quantities. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param clip_denoised: if True, clip denoised samples. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - total_bpd: the total variational lower-bound, per batch element. + - prior_bpd: the prior term in the lower-bound. + - vb: an [N x T] tensor of terms in the lower-bound. + - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. + - mse: an [N x T] tensor of epsilon MSEs for each timestep. + """ + device = x_start.device + batch_size = x_start.shape[0] + + vb = [] + xstart_mse = [] + mse = [] + for t in list(range(self.num_timesteps))[::-1]: + t_batch = torch.tensor([t] * batch_size, device=device) + noise = torch.randn_like(x_start) + x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) + # Calculate VLB term at the current timestep + with torch.no_grad(): + out = self._vb_terms_bpd( + model, + x_start=x_start, + x_t=x_t, + t=t_batch, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + ) + vb.append(out["output"]) + xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) + eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) + mse.append(mean_flat((eps - noise) ** 2)) + + vb = torch.stack(vb, dim=1) + xstart_mse = torch.stack(xstart_mse, dim=1) + mse = torch.stack(mse, dim=1) + + prior_bpd = self._prior_bpd(x_start) + total_bpd = vb.sum(dim=1) + prior_bpd + return { + "total_bpd": total_bpd, + "prior_bpd": prior_bpd, + "vb": vb, + "xstart_mse": xstart_mse, + "mse": mse, + } + + +def _extract_into_tensor(arr: torch.Tensor, timesteps: torch.Tensor, broadcast_shape: List[int]): + """ + Extract values from a 1-D numpy array for a batch of indices. + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. + """ + res = arr.to(timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res + torch.zeros(broadcast_shape, device=timesteps.device) diff --git a/src/videogen_hub/pipelines/opensora/opensora/schedulers/iddpm/respace.py b/src/videogen_hub/pipelines/opensora/opensora/schedulers/iddpm/respace.py new file mode 100644 index 0000000000000000000000000000000000000000..88f074f55fb00a07d5226a36d92f57c3f0714b3b --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/schedulers/iddpm/respace.py @@ -0,0 +1,130 @@ +# Adapted from DiT + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# DiT: https://github.com/facebookresearch/DiT/tree/main +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# -------------------------------------------------------- + + +import torch + +from videogen_hub.pipelines.opensora.opensora.schedulers.iddpm.gaussian_diffusion import GaussianDiffusion + + +def space_timesteps(num_timesteps, section_counts): + """ + Create a list of timesteps to use from an original diffusion process, + given the number of timesteps we want to take from equally-sized portions + of the original process. + For example, if there's 300 timesteps and the section counts are [10,15,20] + then the first 100 timesteps are strided to be 10 timesteps, the second 100 + are strided to be 15 timesteps, and the final 100 are strided to be 20. + If the stride is a string starting with "ddim", then the fixed striding + from the DDIM paper is used, and only one section is allowed. + :param num_timesteps: the number of diffusion steps in the original + process to divide up. + :param section_counts: either a list of numbers, or a string containing + comma-separated numbers, indicating the step count + per section. As a special case, use "ddimN" where N + is a number of steps to use the striding from the + DDIM paper. + :return: a set of diffusion steps from the original process to use. + """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError(f"cannot create exactly {num_timesteps} steps with an integer stride") + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError(f"cannot divide section of {size} steps into {section_count}") + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +class SpacedDiffusion(GaussianDiffusion): + """ + A diffusion process which can skip steps in a base diffusion process. + :param use_timesteps: a collection (sequence or set) of timesteps from the + original diffusion process to retain. + :param kwargs: the kwargs to create the base diffusion process. + """ + + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + + base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = torch.FloatTensor(new_betas) + super().__init__(**kwargs) + try: + from colossalai.utils import get_current_device + except ImportError: + get_current_device = lambda: torch.device("cuda") + self.map_tensor = torch.tensor(self.timestep_map, device=get_current_device()) + + def p_mean_variance(self, model, *args, **kwargs): # pylint: disable=signature-differs + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + def training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs + return super().training_losses(self._wrap_model(model), *args, **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) + + def _wrap_model(self, model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel(model, self.map_tensor, self.original_num_steps) + + def _scale_timesteps(self, t): + # Scaling is done by the wrapped model. + return t + + +class _WrappedModel: + def __init__(self, model, map_tensor, original_num_steps): + self.model = model + self.map_tensor = map_tensor + # self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def __call__(self, x, ts, **kwargs): + new_ts = self.map_tensor[ts].to(device=ts.device, dtype=ts.dtype) + # if self.rescale_timesteps: + # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) + return self.model(x, new_ts, **kwargs) diff --git a/src/videogen_hub/pipelines/opensora/opensora/schedulers/iddpm/speed.py b/src/videogen_hub/pipelines/opensora/opensora/schedulers/iddpm/speed.py new file mode 100644 index 0000000000000000000000000000000000000000..8213500f35c8b6f162ff5512dd01d7736ffb8107 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/schedulers/iddpm/speed.py @@ -0,0 +1,75 @@ +import numpy as np +import torch +import torch.nn.functional as F + + +from videogen_hub.pipelines.opensora.opensora.schedulers.iddpm import gaussian_diffusion as gd +from videogen_hub.pipelines.opensora.opensora.schedulers.iddpm.respace import SpacedDiffusion, space_timesteps +from videogen_hub.pipelines.opensora.opensora.registry import SCHEDULERS + + +@SCHEDULERS.register_module("iddpm-speed") +class SpeeDiffusion(SpacedDiffusion): + def __init__( + self, + num_sampling_steps=None, + timestep_respacing=None, + noise_schedule="linear", + use_kl=False, + sigma_small=False, + predict_xstart=False, + learn_sigma=True, + rescale_learned_sigmas=False, + diffusion_steps=1000, + cfg_scale=4.0, + ): + betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) + if use_kl: + loss_type = gd.LossType.RESCALED_KL + elif rescale_learned_sigmas: + loss_type = gd.LossType.RESCALED_MSE + else: + loss_type = gd.LossType.MSE + if num_sampling_steps is not None: + assert timestep_respacing is None + timestep_respacing = str(num_sampling_steps) + if timestep_respacing is None or timestep_respacing == "": + timestep_respacing = [diffusion_steps] + super().__init__( + use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), + betas=betas, + model_mean_type=(gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X), + model_var_type=( + (gd.ModelVarType.FIXED_LARGE if not sigma_small else gd.ModelVarType.FIXED_SMALL) + if not learn_sigma + else gd.ModelVarType.LEARNED_RANGE + ), + loss_type=loss_type, + ) + + self.cfg_scale = cfg_scale + # we fallback to numpy here as argmax_cuda is not implemented for Bool + grad = np.gradient(self.sqrt_one_minus_alphas_cumprod.cpu()) + self.meaningful_steps = np.argmax(grad < 5e-5) + 1 + + # p2 weighting from: Perception Prioritized Training of Diffusion Models + self.p2_gamma = 1 + self.p2_k = 1 + self.snr = 1.0 / (1 - self.alphas_cumprod) - 1 + sqrt_one_minus_alphas_bar = self.sqrt_one_minus_alphas_cumprod + p = torch.tanh(1e6 * (torch.gradient(sqrt_one_minus_alphas_bar)[0] - 1e-4)) + 1.5 + self.p = F.normalize(p, p=1, dim=0) + self.weights = 1 / (self.p2_k + self.snr) ** self.p2_gamma + + def t_sample(self, n, device): + t = torch.multinomial(self.p, n // 2 + 1, replacement=True).to(device) + dual_t = torch.where(t < self.meaningful_steps, self.meaningful_steps - t, t - self.meaningful_steps) + t = torch.cat([t, dual_t], dim=0)[:n] + return t + + def training_losses(self, model, x, *args, **kwargs): # pylint: disable=signature-differs + t = self.t_sample(x.shape[0], x.device) + return super().training_losses(model, x, t, weights=self.weights, *args, **kwargs) + + def sample(self, *args, **kwargs): + raise NotImplementedError("SpeeDiffusion is only for training") diff --git a/src/videogen_hub/pipelines/opensora/opensora/schedulers/iddpm/timestep_sampler.py b/src/videogen_hub/pipelines/opensora/opensora/schedulers/iddpm/timestep_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..52b6717d528f398cd08f34c347b7fb69f4d5a9a3 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/schedulers/iddpm/timestep_sampler.py @@ -0,0 +1,150 @@ +# Adapted from DiT + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# DiT: https://github.com/facebookresearch/DiT/tree/main +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# -------------------------------------------------------- + +from abc import ABC, abstractmethod + +import numpy as np +import torch as th +import torch.distributed as dist + + +def create_named_schedule_sampler(name, diffusion): + """ + Create a ScheduleSampler from a library of pre-defined samplers. + :param name: the name of the sampler. + :param diffusion: the diffusion object to sample for. + """ + if name == "uniform": + return UniformSampler(diffusion) + elif name == "loss-second-moment": + return LossSecondMomentResampler(diffusion) + else: + raise NotImplementedError(f"unknown schedule sampler: {name}") + + +class ScheduleSampler(ABC): + """ + A distribution over timesteps in the diffusion process, intended to reduce + variance of the objective. + By default, samplers perform unbiased importance sampling, in which the + objective's mean is unchanged. + However, subclasses may override sample() to change how the resampled + terms are reweighted, allowing for actual changes in the objective. + """ + + @abstractmethod + def weights(self): + """ + Get a numpy array of weights, one per diffusion step. + The weights needn't be normalized, but must be positive. + """ + + def sample(self, batch_size, device): + """ + Importance-sample timesteps for a batch. + :param batch_size: the number of timesteps. + :param device: the torch device to save to. + :return: a tuple (timesteps, weights): + - timesteps: a tensor of timestep indices. + - weights: a tensor of weights to scale the resulting losses. + """ + w = self.weights() + p = w / np.sum(w) + indices_np = np.random.choice(len(p), size=(batch_size,), p=p) + indices = th.from_numpy(indices_np).long().to(device) + weights_np = 1 / (len(p) * p[indices_np]) + weights = th.from_numpy(weights_np).float().to(device) + return indices, weights + + +class UniformSampler(ScheduleSampler): + def __init__(self, diffusion): + self.diffusion = diffusion + self._weights = np.ones([diffusion.num_timesteps]) + + def weights(self): + return self._weights + + +class LossAwareSampler(ScheduleSampler): + def update_with_local_losses(self, local_ts, local_losses): + """ + Update the reweighting using losses from a model. + Call this method from each rank with a batch of timesteps and the + corresponding losses for each of those timesteps. + This method will perform synchronization to make sure all of the ranks + maintain the exact same reweighting. + :param local_ts: an integer Tensor of timesteps. + :param local_losses: a 1D Tensor of losses. + """ + batch_sizes = [th.tensor([0], dtype=th.int32, device=local_ts.device) for _ in range(dist.get_world_size())] + dist.all_gather( + batch_sizes, + th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), + ) + + # Pad all_gather batches to be the maximum batch size. + batch_sizes = [x.item() for x in batch_sizes] + max_bs = max(batch_sizes) + + timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] + loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] + dist.all_gather(timestep_batches, local_ts) + dist.all_gather(loss_batches, local_losses) + timesteps = [x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]] + losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] + self.update_with_all_losses(timesteps, losses) + + @abstractmethod + def update_with_all_losses(self, ts, losses): + """ + Update the reweighting using losses from a model. + Sub-classes should override this method to update the reweighting + using losses from the model. + This method directly updates the reweighting without synchronizing + between workers. It is called by update_with_local_losses from all + ranks with identical arguments. Thus, it should have deterministic + behavior to maintain state across workers. + :param ts: a list of int timesteps. + :param losses: a list of float losses, one per timestep. + """ + + +class LossSecondMomentResampler(LossAwareSampler): + def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): + self.diffusion = diffusion + self.history_per_term = history_per_term + self.uniform_prob = uniform_prob + self._loss_history = np.zeros([diffusion.num_timesteps, history_per_term], dtype=np.float64) + self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) + + def weights(self): + if not self._warmed_up(): + return np.ones([self.diffusion.num_timesteps], dtype=np.float64) + weights = np.sqrt(np.mean(self._loss_history**2, axis=-1)) + weights /= np.sum(weights) + weights *= 1 - self.uniform_prob + weights += self.uniform_prob / len(weights) + return weights + + def update_with_all_losses(self, ts, losses): + for t, loss in zip(ts, losses): + if self._loss_counts[t] == self.history_per_term: + # Shift out the oldest loss term. + self._loss_history[t, :-1] = self._loss_history[t, 1:] + self._loss_history[t, -1] = loss + else: + self._loss_history[t, self._loss_counts[t]] = loss + self._loss_counts[t] += 1 + + def _warmed_up(self): + return (self._loss_counts == self.history_per_term).all() diff --git a/src/videogen_hub/pipelines/opensora/opensora/schedulers/rf/__init__.py b/src/videogen_hub/pipelines/opensora/opensora/schedulers/rf/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..869574d106774d7d4d63ffad62dafe00ae38414a --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/schedulers/rf/__init__.py @@ -0,0 +1,103 @@ +import torch +from tqdm import tqdm + + +from videogen_hub.pipelines.opensora.opensora.schedulers.rf.rectified_flow import RFlowScheduler, timestep_transform +from videogen_hub.pipelines.opensora.opensora.registry import SCHEDULERS + + +@SCHEDULERS.register_module("rflow") +class RFLOW: + def __init__( + self, + num_sampling_steps=10, + num_timesteps=1000, + cfg_scale=4.0, + use_discrete_timesteps=False, + use_timestep_transform=False, + **kwargs, + ): + self.num_sampling_steps = num_sampling_steps + self.num_timesteps = num_timesteps + self.cfg_scale = cfg_scale + self.use_discrete_timesteps = use_discrete_timesteps + self.use_timestep_transform = use_timestep_transform + + self.scheduler = RFlowScheduler( + num_timesteps=num_timesteps, + num_sampling_steps=num_sampling_steps, + use_discrete_timesteps=use_discrete_timesteps, + use_timestep_transform=use_timestep_transform, + **kwargs, + ) + + def sample( + self, + model, + text_encoder, + z, + prompts, + device, + additional_args=None, + mask=None, + guidance_scale=None, + progress=True, + ): + # if no specific guidance scale is provided, use the default scale when initializing the scheduler + if guidance_scale is None: + guidance_scale = self.cfg_scale + + n = len(prompts) + # text encoding + model_args = text_encoder.encode(prompts) + y_null = text_encoder.null(n) + model_args["y"] = torch.cat([model_args["y"], y_null], 0) + if additional_args is not None: + model_args.update(additional_args) + + # prepare timesteps + timesteps = [(1.0 - i / self.num_sampling_steps) * self.num_timesteps for i in range(self.num_sampling_steps)] + if self.use_discrete_timesteps: + timesteps = [int(round(t)) for t in timesteps] + timesteps = [torch.tensor([t] * z.shape[0], device=device) for t in timesteps] + if self.use_timestep_transform: + timesteps = [timestep_transform(t, additional_args, num_timesteps=self.num_timesteps) for t in timesteps] + + if mask is not None: + noise_added = torch.zeros_like(mask, dtype=torch.bool) + noise_added = noise_added | (mask == 1) + + progress_wrap = tqdm if progress else (lambda x: x) + for i, t in progress_wrap(enumerate(timesteps)): + # mask for adding noise + if mask is not None: + mask_t = mask * self.num_timesteps + x0 = z.clone() + x_noise = self.scheduler.add_noise(x0, torch.randn_like(x0), t) + + mask_t_upper = mask_t >= t.unsqueeze(1) + model_args["x_mask"] = mask_t_upper.repeat(2, 1) + mask_add_noise = mask_t_upper & ~noise_added + + z = torch.where(mask_add_noise[:, None, :, None, None], x_noise, x0) + noise_added = mask_t_upper + + # classifier-free guidance + z_in = torch.cat([z, z], 0) + t = torch.cat([t, t], 0) + pred = model(z_in, t, **model_args).chunk(2, dim=1)[0] + pred_cond, pred_uncond = pred.chunk(2, dim=0) + v_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) + + # update z + dt = timesteps[i] - timesteps[i + 1] if i < len(timesteps) - 1 else timesteps[i] + dt = dt / self.num_timesteps + z = z + v_pred * dt[:, None, None, None, None] + + if mask is not None: + z = torch.where(mask_t_upper[:, None, :, None, None], z, x0) + + return z + + def training_losses(self, model, x_start, model_kwargs=None, noise=None, mask=None, weights=None, t=None): + return self.scheduler.training_losses(model, x_start, model_kwargs, noise, mask, weights, t) diff --git a/src/videogen_hub/pipelines/opensora/opensora/schedulers/rf/rectified_flow.py b/src/videogen_hub/pipelines/opensora/opensora/schedulers/rf/rectified_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..4c07416d908a192832e8f088dab8fdfa7637879d --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/schedulers/rf/rectified_flow.py @@ -0,0 +1,124 @@ +import torch +from torch.distributions import LogisticNormal + +from videogen_hub.pipelines.opensora.opensora.schedulers.iddpm.gaussian_diffusion import _extract_into_tensor, mean_flat + +# some code are inspired by https://github.com/magic-research/piecewise-rectified-flow/blob/main/scripts/train_perflow.py +# and https://github.com/magic-research/piecewise-rectified-flow/blob/main/src/scheduler_perflow.py + + +def timestep_transform( + t, + model_kwargs, + base_resolution=512 * 512, + base_num_frames=1, + scale=1.0, + num_timesteps=1, +): + t = t / num_timesteps + resolution = model_kwargs["height"] * model_kwargs["width"] + ratio_space = (resolution / base_resolution).sqrt() + # NOTE: currently, we do not take fps into account + # NOTE: temporal_reduction is hardcoded, this should be equal to the temporal reduction factor of the vae + if model_kwargs["num_frames"][0] == 1: + num_frames = torch.ones_like(model_kwargs["num_frames"]) + else: + num_frames = model_kwargs["num_frames"] // 17 * 5 + ratio_time = (num_frames / base_num_frames).sqrt() + + ratio = ratio_space * ratio_time * scale + new_t = ratio * t / (1 + (ratio - 1) * t) + + new_t = new_t * num_timesteps + return new_t + + +class RFlowScheduler: + def __init__( + self, + num_timesteps=1000, + num_sampling_steps=10, + use_discrete_timesteps=False, + sample_method="uniform", + loc=0.0, + scale=1.0, + use_timestep_transform=False, + transform_scale=1.0, + ): + self.num_timesteps = num_timesteps + self.num_sampling_steps = num_sampling_steps + self.use_discrete_timesteps = use_discrete_timesteps + + # sample method + assert sample_method in ["uniform", "logit-normal"] + assert ( + sample_method == "uniform" or not use_discrete_timesteps + ), "Only uniform sampling is supported for discrete timesteps" + self.sample_method = sample_method + if sample_method == "logit-normal": + self.distribution = LogisticNormal(torch.tensor([loc]), torch.tensor([scale])) + self.sample_t = lambda x: self.distribution.sample((x.shape[0],))[:, 0].to(x.device) + + # timestep transform + self.use_timestep_transform = use_timestep_transform + self.transform_scale = transform_scale + + def training_losses(self, model, x_start, model_kwargs=None, noise=None, mask=None, weights=None, t=None): + """ + Compute training losses for a single timestep. + Arguments format copied /schedulers/iddpm/gaussian_diffusion.py/training_losses + Note: t is int tensor and should be rescaled from [0, num_timesteps-1] to [1,0] + """ + if t is None: + if self.use_discrete_timesteps: + t = torch.randint(0, self.num_timesteps, (x_start.shape[0],), device=x_start.device) + elif self.sample_method == "uniform": + t = torch.rand((x_start.shape[0],), device=x_start.device) * self.num_timesteps + elif self.sample_method == "logit-normal": + t = self.sample_t(x_start) * self.num_timesteps + + if self.use_timestep_transform: + t = timestep_transform(t, model_kwargs, scale=self.transform_scale, num_timesteps=self.num_timesteps) + + if model_kwargs is None: + model_kwargs = {} + if noise is None: + noise = torch.randn_like(x_start) + assert noise.shape == x_start.shape + + x_t = self.add_noise(x_start, noise, t) + if mask is not None: + t0 = torch.zeros_like(t) + x_t0 = self.add_noise(x_start, noise, t0) + x_t = torch.where(mask[:, None, :, None, None], x_t, x_t0) + + terms = {} + model_output = model(x_t, t, **model_kwargs) + velocity_pred = model_output.chunk(2, dim=1)[0] + if weights is None: + loss = mean_flat((velocity_pred - (x_start - noise)).pow(2), mask=mask) + else: + weight = _extract_into_tensor(weights, t, x_start.shape) + loss = mean_flat(weight * (velocity_pred - (x_start - noise)).pow(2), mask=mask) + terms["loss"] = loss + + return terms + + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + """ + compatible with diffusers add_noise() + """ + timepoints = timesteps.float() / self.num_timesteps + timepoints = 1 - timepoints # [1,1/1000] + + # timepoint (bsz) noise: (bsz, 4, frame, w ,h) + # expand timepoint to noise shape + timepoints = timepoints.unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(1) + timepoints = timepoints.repeat(1, noise.shape[1], noise.shape[2], noise.shape[3], noise.shape[4]) + + return timepoints * original_samples + (1 - timepoints) * noise diff --git a/src/videogen_hub/pipelines/opensora/opensora/utils/__init__.py b/src/videogen_hub/pipelines/opensora/opensora/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora/opensora/utils/ckpt_utils.py b/src/videogen_hub/pipelines/opensora/opensora/utils/ckpt_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..18bebc6ca1ec071a3c5af2b95eccdbbf3fc810c9 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/utils/ckpt_utils.py @@ -0,0 +1,230 @@ +import functools +import operator +import os +from typing import Tuple + +import torch +import torch.distributed as dist +from torchvision.datasets.utils import download_url + +from videogen_hub.pipelines.opensora.opensora.utils.misc import get_logger + +hf_endpoint = os.environ.get("HF_ENDPOINT") +if hf_endpoint is None: + hf_endpoint = "https://huggingface.co" + +pretrained_models = { + "DiT-XL-2-512x512.pt": "https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-512x512.pt", + "DiT-XL-2-256x256.pt": "https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-256x256.pt", + "Latte-XL-2-256x256-ucf101.pt": hf_endpoint + "/maxin-cn/Latte/resolve/main/ucf101.pt", + "PixArt-XL-2-256x256.pth": hf_endpoint + "/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-256x256.pth", + "PixArt-XL-2-SAM-256x256.pth": hf_endpoint + "/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-SAM-256x256.pth", + "PixArt-XL-2-512x512.pth": hf_endpoint + "/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-512x512.pth", + "PixArt-XL-2-1024-MS.pth": hf_endpoint + "/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-1024-MS.pth", + "OpenSora-v1-16x256x256.pth": hf_endpoint + "/hpcai-tech/Open-Sora/resolve/main/OpenSora-v1-16x256x256.pth", + "OpenSora-v1-HQ-16x256x256.pth": hf_endpoint + "/hpcai-tech/Open-Sora/resolve/main/OpenSora-v1-HQ-16x256x256.pth", + "OpenSora-v1-HQ-16x512x512.pth": hf_endpoint + "/hpcai-tech/Open-Sora/resolve/main/OpenSora-v1-HQ-16x512x512.pth", + "PixArt-Sigma-XL-2-256x256.pth": hf_endpoint + + "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-256x256.pth", + "PixArt-Sigma-XL-2-512-MS.pth": hf_endpoint + + "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-512-MS.pth", + "PixArt-Sigma-XL-2-1024-MS.pth": hf_endpoint + + "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-1024-MS.pth", + "PixArt-Sigma-XL-2-2K-MS.pth": hf_endpoint + "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-2K-MS.pth", +} + + +def reparameter(ckpt, name=None, model=None): + model_name = name + name = os.path.basename(name) + if not dist.is_initialized() or dist.get_rank() == 0: + get_logger().info("loading pretrained model: %s", model_name) + if name in ["DiT-XL-2-512x512.pt", "DiT-XL-2-256x256.pt"]: + ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2) + del ckpt["pos_embed"] + if name in ["Latte-XL-2-256x256-ucf101.pt"]: + ckpt = ckpt["ema"] + ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2) + del ckpt["pos_embed"] + del ckpt["temp_embed"] + if name in [ + "PixArt-XL-2-256x256.pth", + "PixArt-XL-2-SAM-256x256.pth", + "PixArt-XL-2-512x512.pth", + "PixArt-XL-2-1024-MS.pth", + "PixArt-Sigma-XL-2-256x256.pth", + "PixArt-Sigma-XL-2-512-MS.pth", + "PixArt-Sigma-XL-2-1024-MS.pth", + "PixArt-Sigma-XL-2-2K-MS.pth", + ]: + ckpt = ckpt["state_dict"] + ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2) + if "pos_embed" in ckpt: + del ckpt["pos_embed"] + + if name in [ + "PixArt-1B-2.pth", + ]: + ckpt = ckpt["state_dict"] + if "pos_embed" in ckpt: + del ckpt["pos_embed"] + + # no need pos_embed + if "pos_embed_temporal" in ckpt: + del ckpt["pos_embed_temporal"] + if "pos_embed" in ckpt: + del ckpt["pos_embed"] + # different text length + if "y_embedder.y_embedding" in ckpt: + if ckpt["y_embedder.y_embedding"].shape[0] < model.y_embedder.y_embedding.shape[0]: + get_logger().info( + "Extend y_embedding from %s to %s", + ckpt["y_embedder.y_embedding"].shape[0], + model.y_embedder.y_embedding.shape[0], + ) + additional_length = model.y_embedder.y_embedding.shape[0] - ckpt["y_embedder.y_embedding"].shape[0] + new_y_embedding = torch.zeros(additional_length, model.y_embedder.y_embedding.shape[1]) + new_y_embedding[:] = ckpt["y_embedder.y_embedding"][-1] + ckpt["y_embedder.y_embedding"] = torch.cat([ckpt["y_embedder.y_embedding"], new_y_embedding], dim=0) + elif ckpt["y_embedder.y_embedding"].shape[0] > model.y_embedder.y_embedding.shape[0]: + get_logger().info( + "Shrink y_embedding from %s to %s", + ckpt["y_embedder.y_embedding"].shape[0], + model.y_embedder.y_embedding.shape[0], + ) + ckpt["y_embedder.y_embedding"] = ckpt["y_embedder.y_embedding"][: model.y_embedder.y_embedding.shape[0]] + # stdit3 special case + if type(model).__name__ == "STDiT3" and "PixArt-Sigma" in name: + ckpt_keys = list(ckpt.keys()) + for key in ckpt_keys: + if "blocks." in key: + ckpt[key.replace("blocks.", "spatial_blocks.")] = ckpt[key] + del ckpt[key] + + return ckpt + + +def find_model(model_name, model=None): + """ + Finds a pre-trained DiT model, downloading it if necessary. Alternatively, loads a model from a local path. + """ + if model_name in pretrained_models: # Find/download our pre-trained DiT checkpoints + model_ckpt = download_model(model_name) + model_ckpt = reparameter(model_ckpt, model_name, model=model) + else: # Load a custom DiT checkpoint: + assert os.path.isfile(model_name), f"Could not find DiT checkpoint at {model_name}" + model_ckpt = torch.load(model_name, map_location=lambda storage, loc: storage) + model_ckpt = reparameter(model_ckpt, model_name, model=model) + return model_ckpt + + +def download_model(model_name=None, local_path=None, url=None): + """ + Downloads a pre-trained DiT model from the web. + """ + if model_name is not None: + assert model_name in pretrained_models + local_path = f"pretrained_models/{model_name}" + web_path = pretrained_models[model_name] + else: + assert local_path is not None + assert url is not None + web_path = url + if not os.path.isfile(local_path): + os.makedirs("pretrained_models", exist_ok=True) + dir_name = os.path.dirname(local_path) + file_name = os.path.basename(local_path) + download_url(web_path, dir_name, file_name) + model = torch.load(local_path, map_location=lambda storage, loc: storage) + return model + + +def load_from_sharded_state_dict(model, ckpt_path, model_name="model", strict=False): + import os + from contextlib import suppress + + # Attempt to import colossalAI first + colossal_imported = False + with suppress(ImportError): + from colossalai.checkpoint import GeneralCheckpointIO + colossal_imported = True + + if not colossal_imported: + # Fall back to torch if colossalAI import fails + import torch + from torch import load as torch_load + + def load_model_with_fallback(model, ckpt_path): + if colossal_imported: + ckpt_io = GeneralCheckpointIO() + ckpt_io.load_model(model, ckpt_path) + elif os.path.exists(os.path.join(ckpt_path, 'model' + ".safetensors")): + import safetensors.torch + state_dict = safetensors.torch.load_file(os.path.join(ckpt_path, 'model' + ".safetensors")) + model.load_state_dict(state_dict) + else: + model.load_state_dict(torch_load(os.path.join(ckpt_path, 'model'))) + + print(os.getcwd()) + print(f"path={os.path.join(ckpt_path, 'model')}") + relative_path = os.path.join(ckpt_path, "model") + global_path = os.path.join(os.getcwd(), relative_path) + + # Load the model using the appropriate method + load_model_with_fallback(model, ckpt_path) + + +def model_sharding(model: torch.nn.Module): + global_rank = dist.get_rank() + world_size = dist.get_world_size() + for _, param in model.named_parameters(): + padding_size = (world_size - param.numel() % world_size) % world_size + if padding_size > 0: + padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size]) + else: + padding_param = param.data.view(-1) + splited_params = padding_param.split(padding_param.numel() // world_size) + splited_params = splited_params[global_rank] + param.data = splited_params + + +def model_gathering(model: torch.nn.Module, model_shape_dict: dict): + global_rank = dist.get_rank() + global_size = dist.get_world_size() + for name, param in model.named_parameters(): + all_params = [torch.empty_like(param.data) for _ in range(global_size)] + dist.all_gather(all_params, param.data, group=dist.group.WORLD) + if int(global_rank) == 0: + all_params = torch.cat(all_params) + param.data = remove_padding(all_params, model_shape_dict[name]).view(model_shape_dict[name]) + dist.barrier() + + +def remove_padding(tensor: torch.Tensor, original_shape: Tuple) -> torch.Tensor: + return tensor[: functools.reduce(operator.mul, original_shape)] + + +def record_model_param_shape(model: torch.nn.Module) -> dict: + param_shape = {} + for name, param in model.named_parameters(): + param_shape[name] = param.shape + return param_shape + + +def load_checkpoint(model, ckpt_path, save_as_pt=False, model_name="model", strict=False): + if ckpt_path.endswith(".pt") or ckpt_path.endswith(".pth"): + state_dict = find_model(ckpt_path, model=model) + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=strict) + get_logger().info("Missing keys: %s", missing_keys) + get_logger().info("Unexpected keys: %s", unexpected_keys) + elif os.path.isdir(ckpt_path): + load_from_sharded_state_dict(model, ckpt_path, model_name, strict=strict) + get_logger().info("Model checkpoint loaded from %s", ckpt_path) + if save_as_pt: + save_path = os.path.join(ckpt_path, model_name + "_ckpt.pt") + torch.save(model.state_dict(), save_path) + get_logger().info("Model checkpoint saved to %s", save_path) + else: + raise ValueError(f"Invalid checkpoint path: {ckpt_path}") + + diff --git a/src/videogen_hub/pipelines/opensora/opensora/utils/config_utils.py b/src/videogen_hub/pipelines/opensora/opensora/utils/config_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f20138ba7911bd7facb6a00b726c55b9cdc7a14c --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/utils/config_utils.py @@ -0,0 +1,170 @@ +import argparse +import json +import os +from glob import glob + +from mmengine.config import Config + + +def parse_args(training=False): + parser = argparse.ArgumentParser() + + # model config + parser.add_argument("config", help="model config file path") + + # ====================================================== + # General + # ====================================================== + parser.add_argument("--seed", default=None, type=int, help="seed for reproducibility") + parser.add_argument( + "--ckpt-path", + default=None, + type=str, + help="path to model ckpt; will overwrite cfg.model.from_pretrained if specified", + ) + parser.add_argument("--batch-size", default=None, type=int, help="batch size") + parser.add_argument("--outputs", default=None, type=str, help="the dir to save model weights") + parser.add_argument("--flash-attn", default=None, type=str2bool, help="enable flash attention") + parser.add_argument("--layernorm-kernel", default=None, type=str2bool, help="enable layernorm kernel") + parser.add_argument("--resolution", default=None, type=str, help="multi resolution") + parser.add_argument("--data-path", default=None, type=str, help="path to data csv") + parser.add_argument("--dtype", default=None, type=str, help="data type") + + # ====================================================== + # Inference + # ====================================================== + if not training: + # output + parser.add_argument("--save-dir", default=None, type=str, help="path to save generated samples") + parser.add_argument("--sample-name", default=None, type=str, help="sample name, default is sample_idx") + parser.add_argument("--start-index", default=None, type=int, help="start index for sample name") + parser.add_argument("--end-index", default=None, type=int, help="end index for sample name") + parser.add_argument("--num-sample", default=None, type=int, help="number of samples to generate for one prompt") + parser.add_argument("--prompt-as-path", action="store_true", help="use prompt as path to save samples") + parser.add_argument("--verbose", default=None, type=int, help="verbose level") + + # prompt + parser.add_argument("--prompt-path", default=None, type=str, help="path to prompt txt file") + parser.add_argument("--prompt", default=None, type=str, nargs="+", help="prompt list") + parser.add_argument("--llm-refine", default=None, type=str2bool, help="enable LLM refine") + parser.add_argument("--prompt-generator", default=None, type=str, help="prompt generator") + + # image/video + parser.add_argument("--num-frames", default=None, type=str, help="number of frames") + parser.add_argument("--fps", default=None, type=int, help="fps") + parser.add_argument("--save-fps", default=None, type=int, help="save fps") + parser.add_argument("--image-size", default=None, type=int, nargs=2, help="image size") + parser.add_argument("--frame-interval", default=None, type=int, help="frame interval") + parser.add_argument("--aspect-ratio", default=None, type=str, help="aspect ratio (h:w)") + parser.add_argument("--watermark", default=None, type=str2bool, help="watermark video") + + # hyperparameters + parser.add_argument("--num-sampling-steps", default=None, type=int, help="sampling steps") + parser.add_argument("--cfg-scale", default=None, type=float, help="balance between cond & uncond") + + # reference + parser.add_argument("--loop", default=None, type=int, help="loop") + parser.add_argument("--condition-frame-length", default=None, type=int, help="condition frame length") + parser.add_argument("--reference-path", default=None, type=str, nargs="+", help="reference path") + parser.add_argument("--mask-strategy", default=None, type=str, nargs="+", help="mask strategy") + parser.add_argument("--aes", default=None, type=float, help="aesthetic score") + parser.add_argument("--flow", default=None, type=float, help="flow score") + parser.add_argument("--camera-motion", default=None, type=str, help="camera motion") + # ====================================================== + # Training + # ====================================================== + else: + parser.add_argument("--lr", default=None, type=float, help="learning rate") + parser.add_argument("--wandb", default=None, type=bool, help="enable wandb") + parser.add_argument("--load", default=None, type=str, help="path to continue training") + parser.add_argument("--start-from-scratch", action="store_true", help="start training from scratch") + parser.add_argument("--warmup-steps", default=None, type=int, help="warmup steps") + + return parser.parse_args() + + +def merge_args(cfg, args, training=False): + if args.ckpt_path is not None: + cfg.model["from_pretrained"] = args.ckpt_path + if cfg.get("discriminator") is not None: + cfg.discriminator["from_pretrained"] = args.ckpt_path + args.ckpt_path = None + if args.flash_attn is not None: + cfg.model["enable_flash_attn"] = args.flash_attn + args.enable_flash_attn = None + if args.layernorm_kernel is not None: + cfg.model["enable_layernorm_kernel"] = args.layernorm_kernel + args.enable_layernorm_kernel = None + if args.data_path is not None: + cfg.dataset["data_path"] = args.data_path + args.data_path = None + # NOTE: for vae inference (reconstruction) + if not training and "dataset" in cfg: + if args.image_size is not None: + cfg.dataset["image_size"] = args.image_size + if args.num_frames is not None: + cfg.dataset["num_frames"] = args.num_frames + if not training: + if args.cfg_scale is not None: + cfg.scheduler["cfg_scale"] = args.cfg_scale + args.cfg_scale = None + if args.num_sampling_steps is not None: + cfg.scheduler["num_sampling_steps"] = args.num_sampling_steps + args.num_sampling_steps = None + + for k, v in vars(args).items(): + if v is not None: + cfg[k] = v + + return cfg + + +def read_config(config_path): + cfg = Config.fromfile(config_path) + return cfg + + +def parse_configs(training=False): + args = parse_args(training) + cfg = read_config(args.config) + cfg = merge_args(cfg, args, training) + return cfg + + +def define_experiment_workspace(cfg, get_last_workspace=False): + """ + This function creates a folder for experiment tracking. + + Args: + args: The parsed arguments. + + Returns: + exp_dir: The path to the experiment folder. + """ + # Make outputs folder (holds all experiment subfolders) + os.makedirs(cfg.outputs, exist_ok=True) + experiment_index = len(glob(f"{cfg.outputs}/*")) + if get_last_workspace: + experiment_index -= 1 + + # Create an experiment folder + model_name = cfg.model["type"].replace("/", "-") + exp_name = f"{experiment_index:03d}-{model_name}" + exp_dir = f"{cfg.outputs}/{exp_name}" + return exp_name, exp_dir + + +def save_training_config(cfg, experiment_dir): + with open(f"{experiment_dir}/config.txt", "w") as f: + json.dump(cfg, f, indent=4) + + +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("Boolean value expected.") diff --git a/src/videogen_hub/pipelines/opensora/opensora/utils/inference_utils.py b/src/videogen_hub/pipelines/opensora/opensora/utils/inference_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a5935855ae35befab4395529073df944b7b67bf3 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/utils/inference_utils.py @@ -0,0 +1,322 @@ +import json +import os +import re + +import torch + +from videogen_hub.pipelines.opensora.opensora.datasets import IMG_FPS +from videogen_hub.pipelines.opensora.opensora.datasets.utils import read_from_path + + +def prepare_multi_resolution_info(info_type, batch_size, image_size, num_frames, fps, device, dtype): + if info_type is None: + return dict() + elif info_type == "PixArtMS": + hw = torch.tensor([image_size], device=device, dtype=dtype).repeat(batch_size, 1) + ar = torch.tensor([[image_size[0] / image_size[1]]], device=device, dtype=dtype).repeat(batch_size, 1) + return dict(ar=ar, hw=hw) + elif info_type in ["STDiT2", "OpenSora"]: + fps = fps if num_frames > 1 else IMG_FPS + fps = torch.tensor([fps], device=device, dtype=dtype).repeat(batch_size) + height = torch.tensor([image_size[0]], device=device, dtype=dtype).repeat(batch_size) + width = torch.tensor([image_size[1]], device=device, dtype=dtype).repeat(batch_size) + num_frames = torch.tensor([num_frames], device=device, dtype=dtype).repeat(batch_size) + ar = torch.tensor([image_size[0] / image_size[1]], device=device, dtype=dtype).repeat(batch_size) + return dict(height=height, width=width, num_frames=num_frames, ar=ar, fps=fps) + else: + raise NotImplementedError + + +def load_prompts(prompt_path, start_idx=None, end_idx=None): + with open(prompt_path, "r") as f: + prompts = [line.strip() for line in f.readlines()] + prompts = prompts[start_idx:end_idx] + return prompts + + +def get_save_path_name( + save_dir, + sample_name=None, # prefix + sample_idx=None, # sample index + prompt=None, # used prompt + prompt_as_path=False, # use prompt as path + num_sample=1, # number of samples to generate for one prompt + k=None, # kth sample +): + if sample_name is None: + sample_name = "" if prompt_as_path else "sample" + sample_name_suffix = prompt if prompt_as_path else f"_{sample_idx:04d}" + save_path = os.path.join(save_dir, f"{sample_name}{sample_name_suffix}") + if num_sample != 1: + save_path = f"{save_path}-{k}" + return save_path + + +def append_score_to_prompts(prompts, aes=None, flow=None, camera_motion=None): + new_prompts = [] + for prompt in prompts: + new_prompt = prompt + if aes is not None and "aesthetic score:" not in prompt: + new_prompt = f"{new_prompt} aesthetic score: {aes:.1f}." + if flow is not None and "motion score:" not in prompt: + new_prompt = f"{new_prompt} motion score: {flow:.1f}." + if camera_motion is not None and "camera motion:" not in prompt: + new_prompt = f"{new_prompt} camera motion: {camera_motion}." + new_prompts.append(new_prompt) + return new_prompts + + +def extract_json_from_prompts(prompts, reference, mask_strategy): + ret_prompts = [] + for i, prompt in enumerate(prompts): + parts = re.split(r"(?=[{])", prompt) + assert len(parts) <= 2, f"Invalid prompt: {prompt}" + ret_prompts.append(parts[0]) + if len(parts) > 1: + additional_info = json.loads(parts[1]) + for key in additional_info: + assert key in ["reference_path", "mask_strategy"], f"Invalid key: {key}" + if key == "reference_path": + reference[i] = additional_info[key] + elif key == "mask_strategy": + mask_strategy[i] = additional_info[key] + return ret_prompts, reference, mask_strategy + + +def collect_references_batch(reference_paths, vae, image_size): + refs_x = [] # refs_x: [batch, ref_num, C, T, H, W] + for reference_path in reference_paths: + if reference_path == "": + refs_x.append([]) + continue + ref_path = reference_path.split(";") + ref = [] + for r_path in ref_path: + r = read_from_path(r_path, image_size, transform_name="resize_crop") + r_x = vae.encode(r.unsqueeze(0).to(vae.device, vae.dtype)) + r_x = r_x.squeeze(0) + ref.append(r_x) + refs_x.append(ref) + return refs_x + + +def extract_prompts_loop(prompts, num_loop): + ret_prompts = [] + for prompt in prompts: + if prompt.startswith("|0|"): + prompt_list = prompt.split("|")[1:] + text_list = [] + for i in range(0, len(prompt_list), 2): + start_loop = int(prompt_list[i]) + text = prompt_list[i + 1] + end_loop = int(prompt_list[i + 2]) if i + 2 < len(prompt_list) else num_loop + 1 + text_list.extend([text] * (end_loop - start_loop)) + prompt = text_list[num_loop] + ret_prompts.append(prompt) + return ret_prompts + + +def split_prompt(prompt_text): + if prompt_text.startswith("|0|"): + # this is for prompts which look like + # |0| a beautiful day |1| a sunny day |2| a rainy day + # we want to parse it into a list of prompts with the loop index + prompt_list = prompt_text.split("|")[1:] + text_list = [] + loop_idx = [] + for i in range(0, len(prompt_list), 2): + start_loop = int(prompt_list[i]) + text = prompt_list[i + 1].strip() + text_list.append(text) + loop_idx.append(start_loop) + return text_list, loop_idx + else: + return [prompt_text], None + + +def merge_prompt(text_list, loop_idx_list=None): + if loop_idx_list is None: + return text_list[0] + else: + prompt = "" + for i, text in enumerate(text_list): + prompt += f"|{loop_idx_list[i]}|{text}" + return prompt + + +MASK_DEFAULT = ["0", "0", "0", "0", "1", "0"] + + +def parse_mask_strategy(mask_strategy): + mask_batch = [] + if mask_strategy == "" or mask_strategy is None: + return mask_batch + + mask_strategy = mask_strategy.split(";") + for mask in mask_strategy: + mask_group = mask.split(",") + num_group = len(mask_group) + assert num_group >= 1 and num_group <= 6, f"Invalid mask strategy: {mask}" + mask_group.extend(MASK_DEFAULT[num_group:]) + for i in range(5): + mask_group[i] = int(mask_group[i]) + mask_group[5] = float(mask_group[5]) + mask_batch.append(mask_group) + return mask_batch + + +def find_nearest_point(value, point, max_value): + t = value // point + if value % point > point / 2 and t < max_value // point - 1: + t += 1 + return t * point + + +def apply_mask_strategy(z, refs_x, mask_strategys, loop_i, align=None): + masks = [] + no_mask = True + for i, mask_strategy in enumerate(mask_strategys): + no_mask = False + mask = torch.ones(z.shape[2], dtype=torch.float, device=z.device) + mask_strategy = parse_mask_strategy(mask_strategy) + for mst in mask_strategy: + loop_id, m_id, m_ref_start, m_target_start, m_length, edit_ratio = mst + if loop_id != loop_i: + continue + ref = refs_x[i][m_id] + + if m_ref_start < 0: + # ref: [C, T, H, W] + m_ref_start = ref.shape[1] + m_ref_start + if m_target_start < 0: + # z: [B, C, T, H, W] + m_target_start = z.shape[2] + m_target_start + if align is not None: + m_ref_start = find_nearest_point(m_ref_start, align, ref.shape[1]) + m_target_start = find_nearest_point(m_target_start, align, z.shape[2]) + m_length = min(m_length, z.shape[2] - m_target_start, ref.shape[1] - m_ref_start) + z[i, :, m_target_start : m_target_start + m_length] = ref[:, m_ref_start : m_ref_start + m_length] + mask[m_target_start : m_target_start + m_length] = edit_ratio + masks.append(mask) + if no_mask: + return None + masks = torch.stack(masks) + return masks + + +def append_generated(vae, generated_video, refs_x, mask_strategy, loop_i, condition_frame_length, condition_frame_edit): + ref_x = vae.encode(generated_video) + for j, refs in enumerate(refs_x): + if refs is None: + refs_x[j] = [ref_x[j]] + else: + refs.append(ref_x[j]) + if mask_strategy[j] is None or mask_strategy[j] == "": + mask_strategy[j] = "" + else: + mask_strategy[j] += ";" + mask_strategy[ + j + ] += f"{loop_i},{len(refs)-1},-{condition_frame_length},0,{condition_frame_length},{condition_frame_edit}" + return refs_x, mask_strategy + + +def dframe_to_frame(num): + assert num % 5 == 0, f"Invalid num: {num}" + return num // 5 * 17 + + +OPENAI_CLIENT = None +REFINE_PROMPTS = None +REFINE_PROMPTS_PATH = "assets/texts/t2v_pllava.txt" +REFINE_PROMPTS_TEMPLATE = """ +You need to refine user's input prompt. The user's input prompt is used for video generation task. You need to refine the user's prompt to make it more suitable for the task. Here are some examples of refined prompts: +{} + +The refined prompt should pay attention to all objects in the video. The description should be useful for AI to re-generate the video. The description should be no more than six sentences. The refined prompt should be in English. +""" +RANDOM_PROMPTS = None +RANDOM_PROMPTS_TEMPLATE = """ +You need to generate one input prompt for video generation task. The prompt should be suitable for the task. Here are some examples of refined prompts: +{} + +The prompt should pay attention to all objects in the video. The description should be useful for AI to re-generate the video. The description should be no more than six sentences. The prompt should be in English. +""" + + +def get_openai_response(sys_prompt, usr_prompt, model="gpt-4o"): + global OPENAI_CLIENT + if OPENAI_CLIENT is None: + from openai import OpenAI + + OPENAI_CLIENT = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) + + completion = OPENAI_CLIENT.chat.completions.create( + model=model, + messages=[ + { + "role": "system", + "content": sys_prompt, + }, # <-- This is the system message that provides context to the model + { + "role": "user", + "content": usr_prompt, + }, # <-- This is the user message for which the model will generate a response + ], + ) + + return completion.choices[0].message.content + + +def get_random_prompt_by_openai(): + global RANDOM_PROMPTS + if RANDOM_PROMPTS is None: + examples = load_prompts(REFINE_PROMPTS_PATH) + RANDOM_PROMPTS = RANDOM_PROMPTS_TEMPLATE.format("\n".join(examples)) + + response = get_openai_response(RANDOM_PROMPTS, "Generate one example.") + return response + + +def refine_prompt_by_openai(prompt): + global REFINE_PROMPTS + if REFINE_PROMPTS is None: + examples = load_prompts(REFINE_PROMPTS_PATH) + REFINE_PROMPTS = REFINE_PROMPTS_TEMPLATE.format("\n".join(examples)) + + response = get_openai_response(REFINE_PROMPTS, prompt) + return response + + +def has_openai_key(): + return "OPENAI_API_KEY" in os.environ + + +def refine_prompts_by_openai(prompts): + new_prompts = [] + for prompt in prompts: + try: + if prompt.strip() == "": + new_prompt = get_random_prompt_by_openai() + print(f"[Info] Empty prompt detected, generate random prompt: {new_prompt}") + else: + new_prompt = refine_prompt_by_openai(prompt) + print(f"[Info] Refine prompt: {prompt} -> {new_prompt}") + new_prompts.append(new_prompt) + except Exception as e: + print(f"[Warning] Failed to refine prompt: {prompt} due to {e}") + new_prompts.append(prompt) + return new_prompts + + +def add_watermark( + input_video_path, watermark_image_path="./assets/images/watermark/watermark.png", output_video_path=None +): + # execute this command in terminal with subprocess + # return if the process is successful + if output_video_path is None: + output_video_path = input_video_path.replace(".mp4", "_watermark.mp4") + cmd = f'ffmpeg -y -i {input_video_path} -i {watermark_image_path} -filter_complex "[1][0]scale2ref=oh*mdar:ih*0.1[logo][video];[video][logo]overlay" {output_video_path}' + exit_code = os.system(cmd) + is_success = exit_code == 0 + return is_success diff --git a/src/videogen_hub/pipelines/opensora/opensora/utils/lr_scheduler.py b/src/videogen_hub/pipelines/opensora/opensora/utils/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..e0f75f521884a2a120d5c5f858317624d24b5e35 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/utils/lr_scheduler.py @@ -0,0 +1,22 @@ +from torch.optim.lr_scheduler import _LRScheduler + + +class LinearWarmupLR(_LRScheduler): + """Linearly warmup learning rate and then linearly decay. + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + warmup_steps (int, optional): Number of warmup steps, defaults to 0 + last_step (int, optional): The index of last step, defaults to -1. When last_step=-1, + the schedule is started from the beginning or When last_step=-1, sets initial lr as lr. + """ + + def __init__(self, optimizer, warmup_steps: int = 0, last_epoch: int = -1): + self.warmup_steps = warmup_steps + super().__init__(optimizer, last_epoch=last_epoch) + + def get_lr(self): + if self.last_epoch < self.warmup_steps: + return [(self.last_epoch + 1) / (self.warmup_steps + 1) * lr for lr in self.base_lrs] + else: + return self.base_lrs diff --git a/src/videogen_hub/pipelines/opensora/opensora/utils/misc.py b/src/videogen_hub/pipelines/opensora/opensora/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..b2f6730974d6e294cb90582c81a812078f48b609 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/opensora/utils/misc.py @@ -0,0 +1,411 @@ +import collections +import importlib +import logging +import os +import time +from collections import OrderedDict +from collections.abc import Sequence +from itertools import repeat +from typing import Tuple + +import numpy as np +import torch +import torch.distributed as dist + +# ====================================================== +# Logging +# ====================================================== + + +def is_distributed(): + return os.environ.get("WORLD_SIZE", None) is not None + + +def is_main_process(): + return not is_distributed() or dist.get_rank() == 0 + + +def get_world_size(): + if is_distributed(): + return dist.get_world_size() + else: + return 1 + + +def create_logger(logging_dir=None): + """ + Create a logger that writes to a log file and stdout. + """ + if is_main_process(): # real logger + additional_args = dict() + if logging_dir is not None: + additional_args["handlers"] = [ + logging.StreamHandler(), + logging.FileHandler(f"{logging_dir}/log.txt"), + ] + logging.basicConfig( + level=logging.INFO, + format="[\033[34m%(asctime)s\033[0m] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + **additional_args, + ) + logger = logging.getLogger(__name__) + else: # dummy logger (does nothing) + logger = logging.getLogger(__name__) + logger.addHandler(logging.NullHandler()) + return logger + + +def get_logger(): + return logging.getLogger(__name__) + + +def print_rank(var_name, var_value, rank=0): + if dist.get_rank() == rank: + print(f"[Rank {rank}] {var_name}: {var_value}") + + +def print_0(*args, **kwargs): + if dist.get_rank() == 0: + print(*args, **kwargs) + + +def create_tensorboard_writer(exp_dir): + from torch.utils.tensorboard import SummaryWriter + + tensorboard_dir = f"{exp_dir}/tensorboard" + os.makedirs(tensorboard_dir, exist_ok=True) + writer = SummaryWriter(tensorboard_dir) + return writer + + +# ====================================================== +# String +# ====================================================== + + +def format_numel_str(numel: int) -> str: + B = 1024**3 + M = 1024**2 + K = 1024 + if numel >= B: + return f"{numel / B:.2f} B" + elif numel >= M: + return f"{numel / M:.2f} M" + elif numel >= K: + return f"{numel / K:.2f} K" + else: + return f"{numel}" + + +def get_timestamp(): + timestamp = time.strftime("%Y%m%d-%H%M%S", time.localtime(time.time())) + return timestamp + + +def format_time(seconds): + days = int(seconds / 3600 / 24) + seconds = seconds - days * 3600 * 24 + hours = int(seconds / 3600) + seconds = seconds - hours * 3600 + minutes = int(seconds / 60) + seconds = seconds - minutes * 60 + secondsf = int(seconds) + seconds = seconds - secondsf + millis = int(seconds * 1000) + + f = "" + i = 1 + if days > 0: + f += str(days) + "D" + i += 1 + if hours > 0 and i <= 2: + f += str(hours) + "h" + i += 1 + if minutes > 0 and i <= 2: + f += str(minutes) + "m" + i += 1 + if secondsf > 0 and i <= 2: + f += str(secondsf) + "s" + i += 1 + if millis > 0 and i <= 2: + f += str(millis) + "ms" + i += 1 + if f == "": + f = "0ms" + return f + + +class BColors: + HEADER = "\033[95m" + OKBLUE = "\033[94m" + OKCYAN = "\033[96m" + OKGREEN = "\033[92m" + WARNING = "\033[93m" + FAIL = "\033[91m" + ENDC = "\033[0m" + BOLD = "\033[1m" + UNDERLINE = "\033[4m" + + +# ====================================================== +# PyTorch +# ====================================================== + + +def requires_grad(model: torch.nn.Module, flag: bool = True) -> None: + """ + Set requires_grad flag for all parameters in a model. + """ + for p in model.parameters(): + p.requires_grad = flag + + +def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: + dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM) + tensor.div_(dist.get_world_size()) + return tensor + + +def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]: + num_params = 0 + num_params_trainable = 0 + for p in model.parameters(): + num_params += p.numel() + if p.requires_grad: + num_params_trainable += p.numel() + return num_params, num_params_trainable + + +def count_params(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def to_tensor(data): + """Convert objects of various python types to :obj:`torch.Tensor`. + + Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, + :class:`Sequence`, :class:`int` and :class:`float`. + + Args: + data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to + be converted. + """ + + if isinstance(data, torch.Tensor): + return data + elif isinstance(data, np.ndarray): + return torch.from_numpy(data) + elif isinstance(data, Sequence) and not isinstance(data, str): + return torch.tensor(data) + elif isinstance(data, int): + return torch.LongTensor([data]) + elif isinstance(data, float): + return torch.FloatTensor([data]) + else: + raise TypeError(f"type {type(data)} cannot be converted to tensor.") + + +def to_ndarray(data): + if isinstance(data, torch.Tensor): + return data.numpy() + elif isinstance(data, np.ndarray): + return data + elif isinstance(data, Sequence): + return np.array(data) + elif isinstance(data, int): + return np.ndarray([data], dtype=int) + elif isinstance(data, float): + return np.array([data], dtype=float) + else: + raise TypeError(f"type {type(data)} cannot be converted to ndarray.") + + +def to_torch_dtype(dtype): + if isinstance(dtype, torch.dtype): + return dtype + elif isinstance(dtype, str): + dtype_mapping = { + "float64": torch.float64, + "float32": torch.float32, + "float16": torch.float16, + "fp32": torch.float32, + "fp16": torch.float16, + "half": torch.float16, + "bf16": torch.bfloat16, + } + if dtype not in dtype_mapping: + raise ValueError + dtype = dtype_mapping[dtype] + return dtype + else: + raise ValueError + + +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return x + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple + + +def convert_SyncBN_to_BN2d(model_cfg): + for k in model_cfg: + v = model_cfg[k] + if k == "norm_cfg" and v["type"] == "SyncBN": + v["type"] = "BN2d" + elif isinstance(v, dict): + convert_SyncBN_to_BN2d(v) + + +def get_topk(x, dim=4, k=5): + x = to_tensor(x) + inds = x[..., dim].topk(k)[1] + return x[inds] + + +def param_sigmoid(x, alpha): + ret = 1 / (1 + (-alpha * x).exp()) + return ret + + +def inverse_param_sigmoid(x, alpha, eps=1e-5): + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1 / x2) / alpha + + +def inverse_sigmoid(x, eps=1e-5): + """Inverse function of sigmoid. + + Args: + x (Tensor): The tensor to do the + inverse. + eps (float): EPS avoid numerical + overflow. Defaults 1e-5. + Returns: + Tensor: The x has passed the inverse + function of sigmoid, has same + shape with input. + """ + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1 / x2) + + +# ====================================================== +# Python +# ====================================================== + + +def count_columns(df, columns): + cnt_dict = OrderedDict() + num_samples = len(df) + + for col in columns: + d_i = df[col].value_counts().to_dict() + for k in d_i: + d_i[k] = (d_i[k], d_i[k] / num_samples) + cnt_dict[col] = d_i + + return cnt_dict + + +def try_import(name): + """Try to import a module. + + Args: + name (str): Specifies what module to import in absolute or relative + terms (e.g. either pkg.mod or ..mod). + Returns: + ModuleType or None: If importing successfully, returns the imported + module, otherwise returns None. + """ + try: + return importlib.import_module(name) + except ImportError: + return None + + +def transpose(x): + """ + transpose a list of list + Args: + x (list[list]): + """ + ret = list(map(list, zip(*x))) + return ret + + +def all_exists(paths): + return all(os.path.exists(path) for path in paths) + + +# ====================================================== +# Profile +# ====================================================== + + +class Timer: + def __init__(self, name, log=False): + self.name = name + self.start_time = None + self.end_time = None + self.log = log + + @property + def elapsed_time(self): + return self.end_time - self.start_time + + def __enter__(self): + torch.cuda.synchronize() + self.start_time = time.time() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + torch.cuda.synchronize() + self.end_time = time.time() + if self.log: + print(f"Elapsed time for {self.name}: {self.elapsed_time:.2f} s") + + +def get_tensor_memory(tensor, human_readable=True): + size = tensor.element_size() * tensor.nelement() + if human_readable: + size = format_numel_str(size) + return size + + +class FeatureSaver: + def __init__(self, save_dir, bin_size=10, start_bin=0): + self.save_dir = save_dir + self.bin_size = bin_size + self.bin_cnt = start_bin + + self.data_list = [] + self.cnt = 0 + + def update(self, data): + self.data_list.append(data) + self.cnt += 1 + + if self.cnt % self.bin_size == 0: + self.save() + + def save(self): + save_path = os.path.join(self.save_dir, f"{self.bin_cnt:08}.bin") + torch.save(self.data_list, save_path) + get_logger().info("Saved to %s", save_path) + self.data_list = [] + self.bin_cnt += 1 diff --git a/src/videogen_hub/pipelines/opensora/requirements.txt b/src/videogen_hub/pipelines/opensora/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e8031a810f0a78ad9431e86297296f8401f5ca86 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/requirements.txt @@ -0,0 +1,17 @@ +colossalai +accelerate +diffusers +ftfy +gdown +mmengine +pandas +pre-commit +pyarrow +pyav +tensorboard +timm +tqdm +transformers +wandb +rotary_embedding_torch +pandarallel diff --git a/src/videogen_hub/pipelines/opensora/scripts/__init__.py b/src/videogen_hub/pipelines/opensora/scripts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/scripts/__init__.py @@ -0,0 +1 @@ + diff --git a/src/videogen_hub/pipelines/opensora/scripts/inference-long.py b/src/videogen_hub/pipelines/opensora/scripts/inference-long.py new file mode 100644 index 0000000000000000000000000000000000000000..617ee9d1ab557b51ac2744ba040383715747bba6 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/scripts/inference-long.py @@ -0,0 +1,325 @@ +import json +import os +import re + + +import torch +import torch.distributed as dist +from mmengine.runner import set_random_seed + +from videogen_hub.pipelines.opensora.opensora.acceleration.parallel_states import set_sequence_parallel_group +from videogen_hub.pipelines.opensora.opensora.datasets import IMG_FPS, save_sample +from videogen_hub.pipelines.opensora.opensora.datasets.utils import read_from_path +from videogen_hub.pipelines.opensora.opensora.models.text_encoder.t5 import text_preprocessing +from videogen_hub.pipelines.opensora.opensora.registry import MODELS, SCHEDULERS, build_module +from videogen_hub.pipelines.opensora.opensora.utils.config_utils import parse_configs +from videogen_hub.pipelines.opensora.opensora.utils.misc import to_torch_dtype + + +def collect_references_batch(reference_paths, vae, image_size): + refs_x = [] + for reference_path in reference_paths: + if reference_path is None: + refs_x.append([]) + continue + ref_path = reference_path.split(";") + ref = [] + for r_path in ref_path: + r = read_from_path(r_path, image_size, transform_name="resize_crop") + r_x = vae.encode(r.unsqueeze(0).to(vae.device, vae.dtype)) + r_x = r_x.squeeze(0) + ref.append(r_x) + refs_x.append(ref) + # refs_x: [batch, ref_num, C, T, H, W] + return refs_x + + +def process_mask_strategy(mask_strategy): + mask_batch = [] + mask_strategy = mask_strategy.split(";") + for mask in mask_strategy: + mask_group = mask.split(",") + assert len(mask_group) >= 1 and len(mask_group) <= 6, f"Invalid mask strategy: {mask}" + if len(mask_group) == 1: + mask_group.extend(["0", "0", "0", "1", "0"]) + elif len(mask_group) == 2: + mask_group.extend(["0", "0", "1", "0"]) + elif len(mask_group) == 3: + mask_group.extend(["0", "1", "0"]) + elif len(mask_group) == 4: + mask_group.extend(["1", "0"]) + elif len(mask_group) == 5: + mask_group.append("0") + mask_batch.append(mask_group) + return mask_batch + + +def apply_mask_strategy(z, refs_x, mask_strategys, loop_i): + masks = [] + for i, mask_strategy in enumerate(mask_strategys): + mask = torch.ones(z.shape[2], dtype=torch.float, device=z.device) + if mask_strategy is None: + masks.append(mask) + continue + mask_strategy = process_mask_strategy(mask_strategy) + for mst in mask_strategy: + loop_id, m_id, m_ref_start, m_target_start, m_length, edit_ratio = mst + loop_id = int(loop_id) + if loop_id != loop_i: + continue + m_id = int(m_id) + m_ref_start = int(m_ref_start) + m_length = int(m_length) + m_target_start = int(m_target_start) + edit_ratio = float(edit_ratio) + ref = refs_x[i][m_id] # [C, T, H, W] + if m_ref_start < 0: + m_ref_start = ref.shape[1] + m_ref_start + if m_target_start < 0: + # z: [B, C, T, H, W] + m_target_start = z.shape[2] + m_target_start + z[i, :, m_target_start : m_target_start + m_length] = ref[:, m_ref_start : m_ref_start + m_length] + mask[m_target_start : m_target_start + m_length] = edit_ratio + masks.append(mask) + masks = torch.stack(masks) + return masks + + +def process_prompts(prompts, num_loop): + ret_prompts = [] + for prompt in prompts: + if prompt.startswith("|0|"): + prompt_list = prompt.split("|")[1:] + text_list = [] + for i in range(0, len(prompt_list), 2): + start_loop = int(prompt_list[i]) + text = prompt_list[i + 1] + text = text_preprocessing(text) + end_loop = int(prompt_list[i + 2]) if i + 2 < len(prompt_list) else num_loop + text_list.extend([text] * (end_loop - start_loop)) + assert len(text_list) == num_loop, f"Prompt loop mismatch: {len(text_list)} != {num_loop}" + ret_prompts.append(text_list) + else: + prompt = text_preprocessing(prompt) + ret_prompts.append([prompt] * num_loop) + return ret_prompts + + +def extract_json_from_prompts(prompts): + additional_infos = [] + ret_prompts = [] + for prompt in prompts: + parts = re.split(r"(?=[{\[])", prompt) + assert len(parts) <= 2, f"Invalid prompt: {prompt}" + ret_prompts.append(parts[0]) + if len(parts) == 1: + additional_infos.append({}) + else: + additional_infos.append(json.loads(parts[1])) + return ret_prompts, additional_infos + + +def main(): + # ====================================================== + # 1. cfg and init distributed env + # ====================================================== + cfg = parse_configs(training=False) + print(cfg) + + has_colossal = False + try: + import colossalai + from colossalai.cluster import DistCoordinator + except: + colossalai = None + has_colossal = False + + # init distributed + if os.environ.get("WORLD_SIZE", None) and has_colossal: + use_dist = True + colossalai.launch_from_torch({}) + coordinator = DistCoordinator() + + if coordinator.world_size > 1: + set_sequence_parallel_group(dist.group.WORLD) + enable_sequence_parallelism = True + else: + enable_sequence_parallelism = False + else: + use_dist = False + enable_sequence_parallelism = False + + # ====================================================== + # 2. runtime variables + # ====================================================== + torch.set_grad_enabled(False) + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + device = "cuda" if torch.cuda.is_available() else "cpu" + dtype = to_torch_dtype(cfg.dtype) + set_random_seed(seed=cfg.seed) + prompts = cfg.prompt + + # ====================================================== + # 3. build model & load weights + # ====================================================== + # 3.1. build model + input_size = (cfg.num_frames, *cfg.image_size) + vae = build_module(cfg.vae, MODELS) + latent_size = vae.get_latent_size(input_size) + text_encoder = build_module(cfg.text_encoder, MODELS, device=device) # T5 must be fp32 + model = build_module( + cfg.model, + MODELS, + input_size=latent_size, + in_channels=vae.out_channels, + caption_channels=text_encoder.output_dim, + model_max_length=text_encoder.model_max_length, + dtype=dtype, + enable_sequence_parallelism=enable_sequence_parallelism, + ) + text_encoder.y_embedder = model.y_embedder # hack for classifier-free guidance + + # 3.2. move to device & eval + vae = vae.to(device, dtype).eval() + model = model.to(device, dtype).eval() + + # 3.3. build scheduler + scheduler = build_module(cfg.scheduler, SCHEDULERS) + + # 3.4. support for multi-resolution + model_args = dict() + if cfg.multi_resolution == "PixArtMS": + image_size = cfg.image_size + hw = torch.tensor([image_size], device=device, dtype=dtype).repeat(cfg.batch_size, 1) + ar = torch.tensor([[image_size[0] / image_size[1]]], device=device, dtype=dtype).repeat(cfg.batch_size, 1) + model_args["data_info"] = dict(ar=ar, hw=hw) + elif cfg.multi_resolution == "STDiT2": + image_size = cfg.image_size + height = torch.tensor([image_size[0]], device=device, dtype=dtype).repeat(cfg.batch_size) + width = torch.tensor([image_size[1]], device=device, dtype=dtype).repeat(cfg.batch_size) + num_frames = torch.tensor([cfg.num_frames], device=device, dtype=dtype).repeat(cfg.batch_size) + ar = torch.tensor([image_size[0] / image_size[1]], device=device, dtype=dtype).repeat(cfg.batch_size) + if cfg.num_frames == 1: + cfg.fps = IMG_FPS + fps = torch.tensor([cfg.fps], device=device, dtype=dtype).repeat(cfg.batch_size) + model_args["height"] = height + model_args["width"] = width + model_args["num_frames"] = num_frames + model_args["ar"] = ar + model_args["fps"] = fps + + # 3.5 reference + if cfg.reference_path is not None: + assert len(cfg.reference_path) == len( + prompts + ), f"Reference path mismatch: {len(cfg.reference_path)} != {len(prompts)}" + assert len(cfg.reference_path) == len( + cfg.mask_strategy + ), f"Mask strategy mismatch: {len(cfg.mask_strategy)} != {len(prompts)}" + else: + cfg.reference_path = [None] * len(prompts) + cfg.mask_strategy = [None] * len(prompts) + + # ====================================================== + # 4. inference + # ====================================================== + sample_idx = 0 + if cfg.sample_name is not None: + sample_name = cfg.sample_name + elif cfg.prompt_as_path: + sample_name = "" + else: + sample_name = "sample" + save_dir = cfg.save_dir + os.makedirs(save_dir, exist_ok=True) + + # 4.1. batch generation + for i in range(0, len(prompts), cfg.batch_size): + batch_prompts_raw = prompts[i : i + cfg.batch_size] + batch_prompts_raw, additional_infos = extract_json_from_prompts(batch_prompts_raw) + batch_prompts_loops = process_prompts(batch_prompts_raw, cfg.loop) + # handle the last batch + if len(batch_prompts_raw) < cfg.batch_size and cfg.multi_resolution == "STDiT2": + model_args["height"] = model_args["height"][: len(batch_prompts_raw)] + model_args["width"] = model_args["width"][: len(batch_prompts_raw)] + model_args["num_frames"] = model_args["num_frames"][: len(batch_prompts_raw)] + model_args["ar"] = model_args["ar"][: len(batch_prompts_raw)] + model_args["fps"] = model_args["fps"][: len(batch_prompts_raw)] + + # 4.2. load reference videos & images + for j, info in enumerate(additional_infos): + if "reference_path" in info: + cfg.reference_path[i + j] = info["reference_path"] + if "mask_strategy" in info: + cfg.mask_strategy[i + j] = info["mask_strategy"] + refs_x = collect_references_batch(cfg.reference_path[i : i + cfg.batch_size], vae, cfg.image_size) + mask_strategy = cfg.mask_strategy[i : i + cfg.batch_size] + + # 4.3. diffusion sampling + old_sample_idx = sample_idx + # generate multiple samples for each prompt + for k in range(cfg.num_sample): + sample_idx = old_sample_idx + video_clips = [] + + # 4.4. long video generation + for loop_i in range(cfg.loop): + # 4.4 sample in hidden space + batch_prompts = [prompt[loop_i] for prompt in batch_prompts_loops] + + # 4.5. apply mask strategy + masks = None + # if cfg.reference_path is not None: + if loop_i > 0: + ref_x = vae.encode(video_clips[-1]) + for j, refs in enumerate(refs_x): + if refs is None: + refs_x[j] = [ref_x[j]] + else: + refs.append(ref_x[j]) + if mask_strategy[j] is None: + mask_strategy[j] = "" + else: + mask_strategy[j] += ";" + mask_strategy[ + j + ] += f"{loop_i},{len(refs)-1},-{cfg.condition_frame_length},0,{cfg.condition_frame_length}" + + # sampling + z = torch.randn(len(batch_prompts), vae.out_channels, *latent_size, device=device, dtype=dtype) + masks = apply_mask_strategy(z, refs_x, mask_strategy, loop_i) + samples = scheduler.sample( + model, + text_encoder, + z=z, + prompts=batch_prompts, + device=device, + additional_args=model_args, + mask=masks, # scheduler must support mask + ) + samples = vae.decode(samples.to(dtype)) + video_clips.append(samples) + + # 4.7. save video + if loop_i == cfg.loop - 1: + if not use_dist or coordinator.is_master(): + for idx in range(len(video_clips[0])): + video_clips_i = [video_clips[0][idx]] + [ + video_clips[i][idx][:, cfg.condition_frame_length :] for i in range(1, cfg.loop) + ] + video = torch.cat(video_clips_i, dim=1) + print(f"Prompt: {batch_prompts_raw[idx]}") + if cfg.prompt_as_path: + sample_name_suffix = batch_prompts_raw[idx] + else: + sample_name_suffix = f"_{sample_idx}" + save_path = os.path.join(save_dir, f"{sample_name}{sample_name_suffix}") + if cfg.num_sample != 1: + save_path = f"{save_path}-{k}" + save_sample(video, fps=cfg.fps // cfg.frame_interval, save_path=save_path) + sample_idx += 1 + + +if __name__ == "__main__": + main() diff --git a/src/videogen_hub/pipelines/opensora/scripts/inference.py b/src/videogen_hub/pipelines/opensora/scripts/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..023bb945e4f5e30ed8db5e39a3e9e695ad57892d --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/scripts/inference.py @@ -0,0 +1,213 @@ +import os + +import torch +import torch.distributed as dist +from mmengine.runner import set_random_seed + +from videogen_hub.pipelines.opensora.opensora.acceleration.parallel_states import set_sequence_parallel_group +from videogen_hub.pipelines.opensora.opensora.datasets import IMG_FPS, save_sample +from videogen_hub.pipelines.opensora.opensora.models.text_encoder.t5 import text_preprocessing +from videogen_hub.pipelines.opensora.opensora.registry import MODELS, SCHEDULERS, build_module +from videogen_hub.pipelines.opensora.opensora.utils.config_utils import parse_configs +from videogen_hub.pipelines.opensora.opensora.utils.misc import to_torch_dtype + +try: + import colossalai + from colossalai.cluster import DistCoordinator +except ImportError: + colossalai = None + + +def main(config=None): + # ====================================================== + # 1. cfg and init distributed env + # ====================================================== + cfg = config + if cfg is None: + cfg = parse_configs(training=False) + print(cfg) + + # init distributed + if os.environ.get("WORLD_SIZE", None) and colossalai is not None: + use_dist = True + colossalai.launch_from_torch({}) + coordinator = DistCoordinator() + + if coordinator.world_size > 1: + set_sequence_parallel_group(dist.group.WORLD) + enable_sequence_parallelism = True + else: + enable_sequence_parallelism = False + else: + use_dist = False + enable_sequence_parallelism = False + + # ====================================================== + # 2. runtime variables + # ====================================================== + torch.set_grad_enabled(False) + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + device = "cuda" if torch.cuda.is_available() else "cpu" + dtype = to_torch_dtype(cfg.dtype) + set_random_seed(seed=cfg.seed) + prompts = cfg.prompt + + # ====================================================== + # 3. build model & load weights + # ====================================================== + # 3.1. build model + input_size = (cfg.num_frames, *cfg.image_size) + vae = build_module(cfg.vae, MODELS) + latent_size = vae.get_latent_size(input_size) + text_encoder = build_module( + cfg.text_encoder, MODELS, device=device + ) # T5 must be fp32 + model = build_module( + cfg.model, + MODELS, + input_size=latent_size, + in_channels=vae.out_channels, + caption_channels=text_encoder.output_dim, + model_max_length=text_encoder.model_max_length, + dtype=dtype, + enable_sequence_parallelism=enable_sequence_parallelism, + ) + text_encoder.y_embedder = model.y_embedder # hack for classifier-free guidance + + # 3.2. move to device & eval + vae = vae.to(device, dtype).eval() + model = model.to(device, dtype).eval() + + # 3.3. build scheduler + scheduler = build_module(cfg.scheduler, SCHEDULERS) + + # 3.4. support for multi-resolution + model_args = dict() + if cfg.multi_resolution == "PixArtMS": + image_size = cfg.image_size + hw = torch.tensor([image_size], device=device, dtype=dtype).repeat( + cfg.batch_size, 1 + ) + ar = torch.tensor( + [[image_size[0] / image_size[1]]], device=device, dtype=dtype + ).repeat(cfg.batch_size, 1) + model_args["data_info"] = dict(ar=ar, hw=hw) + elif cfg.multi_resolution == "STDiT2": + image_size = cfg.image_size + height = torch.tensor([image_size[0]], device=device, dtype=dtype).repeat( + cfg.batch_size + ) + width = torch.tensor([image_size[1]], device=device, dtype=dtype).repeat( + cfg.batch_size + ) + num_frames = torch.tensor([cfg.num_frames], device=device, dtype=dtype).repeat( + cfg.batch_size + ) + ar = torch.tensor( + [image_size[0] / image_size[1]], device=device, dtype=dtype + ).repeat(cfg.batch_size) + if cfg.num_frames == 1: + cfg.fps = IMG_FPS + fps = torch.tensor([cfg.fps], device=device, dtype=dtype).repeat(cfg.batch_size) + model_args["height"] = height + model_args["width"] = width + model_args["num_frames"] = num_frames + model_args["ar"] = ar + model_args["fps"] = fps + + # ====================================================== + # 4. inference + # ====================================================== + sample_idx = 0 + if cfg.sample_name is not None: + sample_name = cfg.sample_name + elif cfg.prompt_as_path: + sample_name = "" + else: + sample_name = "sample" + save_dir = cfg.save_dir + os.makedirs(save_dir, exist_ok=True) + + all_batch_samples = [] + # 4.1. batch generation + for i in range(0, len(prompts), cfg.batch_size): + # 4.2 sample in hidden space + batch_prompts_raw = prompts[i: i + cfg.batch_size] + batch_prompts = [text_preprocessing(prompt) for prompt in batch_prompts_raw] + # handle the last batch + if len(batch_prompts_raw) < cfg.batch_size and cfg.multi_resolution == "STDiT2": + model_args["height"] = model_args["height"][: len(batch_prompts_raw)] + model_args["width"] = model_args["width"][: len(batch_prompts_raw)] + model_args["num_frames"] = model_args["num_frames"][ + : len(batch_prompts_raw) + ] + model_args["ar"] = model_args["ar"][: len(batch_prompts_raw)] + model_args["fps"] = model_args["fps"][: len(batch_prompts_raw)] + + all_samples = [] + # 4.3. diffusion sampling + old_sample_idx = sample_idx + # generate multiple samples for each prompt + for k in range(cfg.num_sample): + sample_idx = old_sample_idx + + # Skip if the sample already exists + # This is useful for resuming sampling VBench + if cfg.prompt_as_path: + skip = True + for batch_prompt in batch_prompts_raw: + path = os.path.join(save_dir, f"{sample_name}{batch_prompt}") + if cfg.num_sample != 1: + path = f"{path}-{k}" + path = f"{path}.mp4" + if not os.path.exists(path): + skip = False + break + if skip: + continue + + # sampling + z = torch.randn( + len(batch_prompts), + vae.out_channels, + *latent_size, + device=device, + dtype=dtype, + ) + samples = scheduler.sample( + model, + text_encoder, + z=z, + prompts=batch_prompts, + device=device, + additional_args=model_args, + ) + samples = vae.decode(samples.to(dtype), model_args["num_frames"]) + + # 4.4. save samples + if not use_dist or coordinator.is_master(): + for idx, sample in enumerate(samples): + print(f"Prompt: {batch_prompts_raw[idx]}") + if cfg.prompt_as_path: + sample_name_suffix = batch_prompts_raw[idx] + else: + sample_name_suffix = f"_{sample_idx}" + save_path = os.path.join( + save_dir, f"{sample_name}{sample_name_suffix}" + ) + if cfg.num_sample != 1: + save_path = f"{save_path}-{k}" + # save_sample( + # sample, fps=cfg.fps, save_path=save_path + # ) + sample_idx += 1 + + all_samples.append(samples) + all_batch_samples.append(all_samples) + + return all_batch_samples + + +if __name__ == "__main__": + main() diff --git a/src/videogen_hub/pipelines/opensora/tools/__init__.py b/src/videogen_hub/pipelines/opensora/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora/tools/caption/README.md b/src/videogen_hub/pipelines/opensora/tools/caption/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b30856cbd56ca96c630b0d96ba114235ea73a5d2 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/caption/README.md @@ -0,0 +1,95 @@ +# Video Captioning + +Human labeling of videos is expensive and time-consuming. We adopt powerful image captioning models to generate captions for videos. Although GPT-4V achieves a better performance, its 20s/sample speed is too slow for us. LLaVA is the second best open-source model in [MMMU](https://mmmu-benchmark.github.io/) and accepts any resolution. We find the quality of 34B model is comparable. + +![Caption](https://i0.imgs.ovh/2024/03/16/eXdvC.png) + +## LLaVA Captioning + +We extract three frames from the video for captioning. With batch inference, we can achieve 10 times speedup. With approximatly 720p resolution and 3 frames, the speed is 2~3 videos/s on 8 GPUs. If we resize the smaller side to 336, the speed can be 8 videos/s. + +### Requirement + +```bash +# create conda env +conda create -n llava python=3.10 -y +conda activate llava + +# install torch +pip install torch torchvision + +# clone llava +git clone https://github.com/haotian-liu/LLaVA.git +cd LLaVA +# CAUTION: This line is to remove torch dependency in pyproject.toml, which is: +# "torch==2.1.2", "torchvision==0.16.2", +# It is better manually remove it in your local pyproject.toml +sed -i '16d' pyproject.toml + +# install llava +pip install --upgrade pip # enable PEP 660 support +pip install -e . + +# install flash attention +pip install flash-attn --no-build-isolation +# install colossalai and decord +pip install colossalai decord +``` + +Since only the 34B model's performance is comparable to GPT-4V, we only provide the usage of the 34B model. The 34B model is available [here](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b), or run our script and it will be downloaded automatically. + +### Usage + +Prepare a csv file for processing. The csv file can be generated by `convert_dataset.py` according to its [documentation](/tools/datasets/README.md). Then, run the following command to generate captions for videos/images with LLaVA: + +```bash +# we run this on 8xH800 GPUs +torchrun --nproc_per_node 8 --standalone -m tools.caption.caption_llava DATA.csv --tp-size 2 --dp-size 4 --bs 16 + +# at least two 80G GPUs are required +torchrun --nproc_per_node 2 --standalone -m tools.caption.caption_llava DATA.csv --tp-size 2 --dp-size 1 --bs 16 + +# can also caption images +torchrun --nproc_per_node 2 --standalone -m tools.caption.caption_llava DATA.csv --tp-size 2 --dp-size 1 --bs 16 --prompt image-3ex + +# caption with llava-34B +# NOTE: remember to enable flash attention for this model +torchrun --nproc_per_node 8 --standalone -m tools.caption.caption_llava DATA.csv --dp-size 4 --tp-size 2 --model-path liuhaotian/llava-v1.6-34b --prompt image-3ex --flash-attention + +# caption with mistral-7B +torchrun --nproc_per_node 8 --standalone -m tools.caption.caption_llava DATA.csv --dp-size 8 --tp-size 1 --model-path liuhaotian/llava-v1.6-mistral-7b --prompt video +# bs can be 48 +``` + +Please note that you should add the `--flash-attention` flag when running with Llama-based Llava models as it provides speedup but do turn it off for mistral-based ones. Reasons can be found in [this issue](/static-proxy?url=https%3A%2F%2Fdiscuss.huggingface.co%2Ft%2Fflash-attention-has-no-effect-on-inference%2F73453). + +After running the script, with `dp-size=N`, you will get `N` parts of csv files. Run the following command to merge them: + +```bash +python -m tools.datasets.datautil DATA_caption_part*.csv --output DATA_caption.csv +``` + +### Resume + +Sometimes the process may be interrupted. We can resume the process by running the following command: + +```bash +# merge generated results +python -m tools.datasets.datautil DATA_caption_part*.csv --output DATA_caption.csv + +# get the remaining videos +python -m tools.datasets.datautil DATA.csv --difference DATA_caption.csv --output DATA_remaining.csv +``` + +Then use the output csv file to resume the process. + +## GPT-4V Captioning + +Run the following command to generate captions for videos with GPT-4V: + +```bash +# output: DATA_caption.csv +python -m tools.caption.caption_gpt4 DATA.csv --key $OPENAI_API_KEY +``` + +The cost is approximately $0.01 per video (3 frames per video). diff --git a/src/videogen_hub/pipelines/opensora/tools/caption/__init__.py b/src/videogen_hub/pipelines/opensora/tools/caption/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora/tools/caption/acceleration/__init__.py b/src/videogen_hub/pipelines/opensora/tools/caption/acceleration/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora/tools/caption/acceleration/llava/__init__.py b/src/videogen_hub/pipelines/opensora/tools/caption/acceleration/llava/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora/tools/caption/acceleration/llava/policies/__init__.py b/src/videogen_hub/pipelines/opensora/tools/caption/acceleration/llava/policies/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..35998d404993d8c5073a3f6796c161402fdd26c4 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/caption/acceleration/llava/policies/__init__.py @@ -0,0 +1,2 @@ +from .llama import LlavaLlamaForCausalLMPolicy +from .mistral import LlavaMistralForCausalLMPolicy diff --git a/src/videogen_hub/pipelines/opensora/tools/caption/acceleration/llava/policies/llama.py b/src/videogen_hub/pipelines/opensora/tools/caption/acceleration/llava/policies/llama.py new file mode 100644 index 0000000000000000000000000000000000000000..dff8f01d68ffd672384f61f2b9c2ce0011b3e556 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/caption/acceleration/llava/policies/llama.py @@ -0,0 +1,98 @@ +from typing import Dict, Union + +import torch.nn as nn +from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = ["LlavaLlamaPolicy", "LlavaLlamaForCausalLMPolicy"] + + +class LlavaLlamaPolicy(Policy): + def config_sanity_check(self): + pass + + def preprocess(self): + if self.shard_config.enable_tensor_parallelism: + # Resize embedding + self.model.config.vocab_size + self.shard_config.tensor_parallel_size + + # if vocab_size % world_size != 0: + # new_vocab_size = vocab_size + world_size - vocab_size % world_size + # self.model.resize_token_embeddings(new_vocab_size) + + return self.model + + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + from transformers.models.llama.modeling_llama import LlamaDecoderLayer + + policy = {} + + if self.shard_config.enable_tensor_parallelism: + decoder_attribute_replacement = { + "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + } + if getattr(self.model.config, "num_key_value_heads", False): + decoder_attribute_replacement["self_attn.num_key_value_heads"] = ( + self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size + ) + + policy[LlamaDecoderLayer] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="mlp.gate_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.up_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.down_proj", + target_module=Linear1D_Row, + ), + ], + ) + + return policy + + def postprocess(self): + return self.model + + +class LlavaLlamaForCausalLMPolicy(LlavaLlamaPolicy): + def module_policy(self): + from transformers import LlamaForCausalLM + + policy = super().module_policy() + if self.shard_config.enable_tensor_parallelism: + # add a new item for casual lm + new_item = { + LlamaForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", target_module=Linear1D_Col, kwargs={"gather_output": True} + ) + ], + ) + } + policy.update(new_item) + return policy diff --git a/src/videogen_hub/pipelines/opensora/tools/caption/acceleration/llava/policies/mistral.py b/src/videogen_hub/pipelines/opensora/tools/caption/acceleration/llava/policies/mistral.py new file mode 100644 index 0000000000000000000000000000000000000000..0afea570af861d170f4334529e14c43f9a32b542 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/caption/acceleration/llava/policies/mistral.py @@ -0,0 +1,113 @@ +import warnings +from typing import Dict, Union + +import torch.nn as nn +from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = ["LlavaMistralPolicy", "LlavaMistralForCausalLMPolicy"] + + +class LlavaMistralPolicy(Policy): + def config_sanity_check(self): + pass + + def preprocess(self): + if self.shard_config.enable_tensor_parallelism: + # Resize embedding + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + + return self.model + + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralModel + + policy = {} + + if self.shard_config.enable_sequence_parallelism: + self.shard_config.enable_sequence_parallelism = False + warnings.warn( + "Mistral doesn't support sequence parallelism now, will ignore the sequence parallelism flag." + ) + + if self.shard_config.enable_tensor_parallelism: + decoder_attribute_replacement = { + "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + "self_attn.num_key_value_heads": self.model.config.num_key_value_heads + // self.shard_config.tensor_parallel_size, + } + + policy[MistralDecoderLayer] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="mlp.gate_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.up_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.down_proj", + target_module=Linear1D_Row, + ), + ], + ) + + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=VocabParallelEmbedding1D, + ), + policy=policy, + target_key=MistralModel, + ) + + return policy + + def postprocess(self): + return self.model + + +class LlavaMistralForCausalLMPolicy(LlavaMistralPolicy): + def module_policy(self): + from transformers import MistralForCausalLM + + policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + # add a new item for casual lm + new_item = { + MistralForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + ) + ] + ) + } + policy.update(new_item) + return policy diff --git a/src/videogen_hub/pipelines/opensora/tools/caption/camera_motion_detect.py b/src/videogen_hub/pipelines/opensora/tools/caption/camera_motion_detect.py new file mode 100644 index 0000000000000000000000000000000000000000..cc0077c65c254bf0f1a73b11883f9eccd681e792 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/caption/camera_motion_detect.py @@ -0,0 +1,132 @@ +# ref: https://github.com/antiboredom/camera-motion-detector + +import argparse + +import cv2 +import numpy as np +import pandas as pd +from tqdm import tqdm + +tqdm.pandas() + + +def apply(df, func, **kwargs): + if pandas_has_parallel: + return df.parallel_apply(func, **kwargs) + return df.progress_apply(func, **kwargs) + + +try: + from pandarallel import pandarallel + + pandarallel.initialize(progress_bar=True) + pandas_has_parallel = True +except ImportError: + pandas_has_parallel = False + + +def make_empty(new_w, new_h): + empty = [] + for y in range(new_h): + xvals = [] + for x in range(new_w): + xvals.append([x, y]) + empty.append(xvals) + + empty = np.array(empty) + return empty + + +def get_type(mag, ang, zoom_in, tau_static=1.0, tau_zoom=(0.4, 0.6)): + if mag < tau_static: + return "static" + if zoom_in < tau_zoom[0]: + return "zoom out" + if zoom_in > tau_zoom[1]: + return "zoom in" + if ang < 45 or ang >= 315: + return "pan left" + if 45 <= ang < 135: + return "tilt up" + if 135 <= ang < 225: + return "pan right" + if 225 <= ang < 315: + return "tilt down" + return "unknown" + + +def get_video_type(frame_types): + # count the number of each type + counts = {} + max_count = 0 + max_type = None + for frame_type in frame_types: + if frame_type not in counts: + counts[frame_type] = 0 + counts[frame_type] += 1 + if counts[frame_type] > max_count: + max_count = counts[frame_type] + max_type = frame_type + if max_count > len(frame_types) / 2: + return max_type + if "static" in counts: + return "unknown" + if "zoom in" not in counts and "zoom out" not in counts: + return "pan/tilt" + return "dynamic" + + +def process(path: str, frame_interval=15) -> str: + cap = cv2.VideoCapture(path) + count = 0 + prvs = None + frame_types = [] + while cap.isOpened(): + ret, frame = cap.read() + if ret: + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + if count == 0: + prvs = frame + h, w = frame.shape + empty = make_empty(w, h) + empty_dists = np.sqrt( + np.square(empty.ravel()[::2] - (w / 2)) + np.square(empty.ravel()[1::2] - (h / 2)) + ) + else: + flow = cv2.calcOpticalFlowFarneback(prvs, frame, None, 0.5, 3, 15, 3, 5, 1.2, 0) + mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1], angleInDegrees=True) + mean_mag = np.median(mag) + mean_ang = np.median(ang) + + flow_coords = flow + empty + xvals = flow_coords.ravel()[::2] - (w / 2) + yvals = flow_coords.ravel()[1::2] - (h / 2) + dists = np.sqrt(np.square(xvals) + np.square(yvals)) + dist_diff = dists >= empty_dists + zoom_in_factor = np.count_nonzero(dist_diff) / len(dist_diff) + frame_types.append(get_type(mean_mag, mean_ang, zoom_in_factor)) + count += frame_interval + cap.set(cv2.CAP_PROP_POS_FRAMES, count) + else: + cap.release() + break + video_type = get_video_type(frame_types) + return video_type + + +def main(args): + output_file = args.input.replace(".csv", "_cmotion.csv") + data = pd.read_csv(args.input) + data["cmotion"] = apply(data["path"], process) + data.to_csv(output_file, index=False) + print(f"Output saved to {output_file}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("input", type=str) + parser.add_argument("--disable-parallel", action="store_true") + args = parser.parse_args() + if args.disable_parallel: + pandas_has_parallel = False + main(args) diff --git a/src/videogen_hub/pipelines/opensora/tools/caption/caption_gpt4.py b/src/videogen_hub/pipelines/opensora/tools/caption/caption_gpt4.py new file mode 100644 index 0000000000000000000000000000000000000000..f22c296ea5130a1ba2606bc885109c1851c56363 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/caption/caption_gpt4.py @@ -0,0 +1,91 @@ +import argparse +import base64 +import csv +import os +from io import BytesIO + +import requests +import tqdm + +from .utils import IMG_EXTENSIONS, PROMPTS, VID_EXTENSIONS, VideoTextDataset + + +def to_base64(image): + buffer = BytesIO() + image.save(buffer, format="JPEG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") + + +def get_caption(frame, prompt, api_key): + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} + payload = { + "model": "gpt-4-vision-preview", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": prompt, + }, + {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{frame[0]}"}}, + {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{frame[1]}"}}, + {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{frame[2]}"}}, + ], + } + ], + "max_tokens": 300, + } + response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload, timeout=60) + caption = response.json()["choices"][0]["message"]["content"] + caption = caption.replace("\n", " ") + return caption + + +def main(args): + # ====================================================== + # 1. read video list + # ====================================================== + dataset = VideoTextDataset(args.input) + output_file = os.path.splitext(args.input)[0] + "_caption.csv" + f = open(output_file, "w") + writer = csv.writer(f) + writer.writerow(["video", "text"]) + + # make sure that the prompt type matches the data type + data_extension = "." + dataset.data["path"].iloc[0].split(".")[-1] + prompt_type = PROMPTS[args.prompt]["type"] + if prompt_type == "image": + assert ( + data_extension.lower() in IMG_EXTENSIONS + ), "The prompt is suitable for an image dataset but the data is not image." + elif prompt_type == "video": + assert ( + data_extension.lower() in VID_EXTENSIONS + ), "The prompt is suitable for a video dataset but the data is not video." + else: + raise ValueError(f"Found invalid prompt type {prompt_type}") + + # ====================================================== + # 2. generate captions + # ====================================================== + for sample in tqdm.tqdm(dataset): + prompt = PROMPTS[args.prompt]["text"] + if "text" in args.prompt: + prompt = prompt.format(sample["text"]) + frames = sample["image"] + frames = [to_base64(frame) for frame in frames] + caption = get_caption(frames, prompt, args.key) + + writer.writerow((sample["path"], caption)) + f.close() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("input", type=str, help="Path to the input CSV file") + parser.add_argument("--prompt", type=str, default="video-f3-detail-3ex") + parser.add_argument("--key", type=str) + args = parser.parse_args() + + main(args) diff --git a/src/videogen_hub/pipelines/opensora/tools/caption/caption_llava.py b/src/videogen_hub/pipelines/opensora/tools/caption/caption_llava.py new file mode 100644 index 0000000000000000000000000000000000000000..0134ec51ff13e847625616755f47cc0455c6d9a8 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/caption/caption_llava.py @@ -0,0 +1,344 @@ +import argparse +import csv +import time +import warnings +from datetime import timedelta + +import torch +import torch.distributed as dist +from colossalai.cluster import DistCoordinator, ProcessGroupMesh +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.utils import get_current_device, set_seed +from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX +from llava.conversation import conv_templates +from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token +from llava.model.builder import load_pretrained_model +from llava.utils import disable_torch_init +from torch.utils.data.distributed import DistributedSampler +from tqdm import tqdm + +from .acceleration.llava.policies import LlavaLlamaForCausalLMPolicy, LlavaMistralForCausalLMPolicy +from .utils import IMG_EXTENSIONS, PROMPTS, VID_EXTENSIONS, Timer, VideoTextDataset, collate_fn + +disable_torch_init() + + +class NoPaddingDistributedSampler(DistributedSampler): + def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False): + super().__init__( + dataset=dataset, num_replicas=num_replicas, rank=rank, seed=seed, shuffle=False, drop_last=False + ) + remainder = len(self.dataset) % self.num_replicas + if remainder > 0 and (self.rank + 1) - remainder <= 0: + # if the dataset is not divisible by num_replicas + # the remaining items will be allocated to the first n ranks + self.num_samples = len(self.dataset) // self.num_replicas + 1 + else: + self.num_samples = len(self.dataset) // self.num_replicas + self.total_size = len(dataset) + + def __iter__(self): + if self.shuffle: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] + else: + indices = list(range(len(self.dataset))) # type: ignore[arg-type] + + # remove tail of data to make it evenly divisible. + indices = indices[: self.total_size] + + # subsample + indices = indices[self.rank : self.total_size : self.num_replicas] + assert len(indices) == self.num_samples + return iter(indices) + + +@torch.inference_mode() +def main(args): + # ====================================================== + # 1. init environment + # ====================================================== + # we set a very large timeout to avoid some processes exit early + dist.init_process_group(backend="nccl", timeout=timedelta(hours=24)) + torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count()) + set_seed(1024) + coordinator = DistCoordinator() + + # prepare the dp and tp groups + assert ( + args.dp_size * args.tp_size == coordinator.world_size + ), f"DP size {args.dp_size} * TP size {args.tp_size} must equal to world size {coordinator.world_size}" + mesh = ProcessGroupMesh(args.dp_size, args.tp_size) + dp_group = mesh.get_group_along_axis(0) + tp_group = mesh.get_group_along_axis(1) + + # ====================================================== + # 2. load model + # ====================================================== + model_path = args.model_path + with warnings.catch_warnings(): + warnings.simplefilter("ignore") # Pytorch non-meta copying warning fills out the console + tokenizer, model, image_processor, context_len = load_pretrained_model( + model_path=model_path, + model_base=None, + model_name=get_model_name_from_path(model_path), + device=get_current_device(), + torch_dtype=torch.float16, + attn_implementation="flash_attention_2" if args.flash_attention else "eager", + ) + dist.barrier() + + # ====================================================== + # 3. Apply system optimization + # ====================================================== + tp_size = dist.get_world_size(tp_group) + shard_config = ShardConfig( + tensor_parallel_process_group=tp_group if tp_size > 1 else None, + enable_tensor_parallelism=True if tp_size > 1 else False, + ) + shard_former = ShardFormer(shard_config=shard_config) + + # check the model type + model_name = model.__class__.__name__ + print(model_name) + if model_name == "LlavaLlamaForCausalLM": + model = shard_former.optimize(model, policy=LlavaLlamaForCausalLMPolicy())[0].cuda() + elif model_name == "LlavaMistralForCausalLM": + model = shard_former.optimize(model, policy=LlavaMistralForCausalLMPolicy())[0].cuda() + else: + print(f"The shardformer policy for {model_name} is not implemented, skip") + torch.cuda.empty_cache() + + # ====================================================== + # 4. Prepare dataloader + # ====================================================== + # prepare prompt + query = PROMPTS[args.prompt]["text"] + if dist.get_rank() == 0: + print(f"Prompt: {query}") + + if "text" in args.prompt: + + def get_text_input_ids(text): + conv = conv_templates["chatml_direct"].copy() + query_text = query.format(text) + conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + query_text) + prompt = conv.get_prompt() + # add num_frames images + t = prompt.split("") + prompt = t[0] + "" * args.num_frames + t[1] + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt") + input_ids = input_ids.unsqueeze(0) + return input_ids + + else: + conv = conv_templates["chatml_direct"].copy() + conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + query) + prompt = conv.get_prompt() + # add num_frames images + t = prompt.split("") + prompt = t[0] + "" * args.num_frames + t[1] + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt") + input_ids = input_ids.unsqueeze(0) + + def get_text_input_ids(*args): + return input_ids + + # build dataset + def transform(imgs): + imgs = process_images(imgs, image_processor, model.config) + imgs = imgs.to(dtype=torch.float16) + return imgs + + dataset = VideoTextDataset( + args.input, + transform=transform, + num_frames=args.num_frames, + get_text_input_ids=get_text_input_ids, + resize=args.resize, + ) + + # make sure that the prompt type matches the data type + data_extension = "." + dataset.data["path"].iloc[0].split(".")[-1] + prompt_type = PROMPTS[args.prompt]["type"] + if prompt_type == "image": + assert ( + data_extension.lower() in IMG_EXTENSIONS + ), f"The prompt is suitable for an image dataset but the data is not image. The first data is of format {data_extension}" + elif prompt_type == "video": + assert ( + data_extension.lower() in VID_EXTENSIONS + ), f"The prompt is suitable for a video dataset but the data is not video. The first data is of format {data_extension}" + else: + raise ValueError(f"Found invalid prompt type {prompt_type}") + + total_num_videos = len(dataset) + + # build sampler + dp_rank = dist.get_rank(dp_group) + dp_size = dist.get_world_size(dp_group) + sampler = NoPaddingDistributedSampler(dataset, rank=dp_rank, num_replicas=dp_size) + + # build dataloader + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=args.bs, + shuffle=False, + num_workers=args.num_workers, + pin_memory=True, + prefetch_factor=args.prefetch_factor, + sampler=sampler, + collate_fn=collate_fn, + ) + + # prepare output file reader + output_file = args.input.replace(".csv", "_caption.csv") + + # create csv writer + has_dp_writter = dist.get_rank(tp_group) == 0 + + if has_dp_writter: + # the dp writer takes care of the files processed on the current dp rank + # so we use write mode + output_file_split = output_file.replace(".csv", f"_part{dp_rank}.csv") + dp_file = open(output_file_split, "w") + dp_writer = csv.writer(dp_file) + dp_writer.writerow(["path", "text", "num_frames"]) + + # ====================================================== + # 5. generate captions + # ====================================================== + if dist.get_rank(tp_group) == 0: + pbar = tqdm(dataloader, position=dp_rank, desc=f"Data Parallel Rank {dist.get_rank(dp_group)}") + else: + pbar = dataloader + + if args.profile: + encode_time = [] + generate_time = [] + output_length = [] + total_time = [] + + for i, batch in enumerate(pbar): + # measure time + if args.profile: + torch.cuda.synchronize() + start_time = time.time() + + video_files, frames, video_lengths, img_size_list, texts = batch + + # encode the batch of inputs + with Timer() as encode_timer: + samples = [] + for imgs, imgs_size, input_ids in zip(frames, img_size_list, texts): + imgs = imgs.cuda() + input_ids = input_ids.cuda() + _, _, _, _, inputs_embeds, _ = model.prepare_inputs_labels_for_multimodal( + input_ids, None, None, None, None, images=imgs, image_sizes=imgs_size + ) + samples.append(inputs_embeds) + + # padding + max_len = max([sample.shape[1] for sample in samples]) + attention_mask = torch.tensor( + [[0] * (max_len - samples[i].shape[1]) + [1] * samples[i].shape[1] for i in range(len(samples))] + ).to(model.device) + inputs_embeds = [ + torch.cat( + [ + torch.zeros( + (1, max_len - samples[i].shape[1], samples[i].shape[-1]), + device=model.device, + dtype=torch.float16, + ), + samples[i], + ], + dim=1, + ) + for i in range(len(samples)) + ] + inputs_embeds = torch.cat(inputs_embeds, dim=0) + + # generate outputs + with Timer() as generate_timer: + output_ids = super(type(model), model).generate( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + do_sample=False, # sampling is not deterministic and may cause TP to hang + max_new_tokens=args.max_tokens, + use_cache=True, + ) + + # skip warmup and add profiling data + if args.profile and i >= args.profile_warmup: + output_length.append(output_ids.size(0) * output_ids.size(1)) + + outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True) + outputs = [output.replace("\n", " ").strip() for output in outputs] + + # skip warmup and add profiling data + if args.profile and i >= args.profile_warmup: + # measure time + torch.cuda.synchronize() + time_taken = time.time() - start_time + + total_time.append(time_taken) + encode_time.append(encode_timer.time_taken) + generate_time.append(generate_timer.time_taken) + + # save results + if has_dp_writter: + result = list(zip(video_files, outputs, video_lengths)) + for t in result: + dp_writer.writerow(t) + + # display profiling info + if args.profile: + print(output_length) + num_samples_after_warmup = total_num_videos - args.bs * args.profile_warmup * dp_size + print(f"throughput (samples/s): {num_samples_after_warmup / sum(total_time)}") + print(f"average encode time per sample: {sum(encode_time) / num_samples_after_warmup}") + print(f"average generate time per sample: {sum(generate_time) / num_samples_after_warmup}") + print(f"average number of tokens characters per sample: {sum(output_length) / num_samples_after_warmup}") + print(f"Max GPU allocated / GB: {torch.cuda.max_memory_allocated() / 1024**3}") + print(f"Max GPU reserved / GB: {torch.cuda.max_memory_reserved() / 1024**3}") + + # ====================================================== + # 6. shutdown + # ====================================================== + # close file writing + if has_dp_writter: + dp_file.close() + dist.barrier() + + # terminate distributed env + dist.destroy_process_group() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("input", type=str, help="Path to the input CSV file") + parser.add_argument("--model-path", type=str, default="liuhaotian/llava-v1.6-34b") + parser.add_argument("--prompt", type=str, default="video-f1-detail-3ex") + parser.add_argument("--resize", type=int, default=336) + parser.add_argument("--num-frames", type=int, default=1) + parser.add_argument("--max-tokens", type=int, default=300) + # speed related + parser.add_argument("--bs", type=int, default=16) + parser.add_argument("--tp-size", type=int, default=2) + parser.add_argument("--dp-size", type=int, default=4) + parser.add_argument("--num-workers", type=int, default=8) + parser.add_argument("--prefetch-factor", type=int, default=8, help="Prefetch factor") + parser.add_argument( + "--flash-attention", + action="store_true", + help="Whether to use flash attention. You can turn on this flag for llama model and off for mistral model.", + ) + # debug related + parser.add_argument("--profile", action="store_true") + parser.add_argument("--profile-warmup", type=int, default=1) + + args = parser.parse_args() + main(args) diff --git a/src/videogen_hub/pipelines/opensora/tools/caption/utils.py b/src/videogen_hub/pipelines/opensora/tools/caption/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f50cdc271c036e8fdcb2c69f80fb39634ea94675 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/caption/utils.py @@ -0,0 +1,140 @@ +import time + +import pandas as pd +import torch +import torchvision.transforms as transforms +from torchvision.datasets.folder import pil_loader + +from tools.datasets.utils import extract_frames, is_video + +PROMPTS = { + "image": { + "text": "Describe this image and its style to generate a succinct yet informative description. Pay attention to all objects in the image. The description should be useful for AI to re-generate the image. The description should be no more than five sentences. Remember do not exceed 5 sentences.", + "type": "image", + }, + "image-text": { + "text": "Describe this image and its style in a very detailed manner. Pay attention to all objects in the image. The description should be useful for AI to re-generate the image. The description should be no more than six sentences. Some information about the image is '{}'.", + "type": "image", + }, + "image-3ex": { + "text": "An image is given. Describe this image and its style to generate a succinct yet informative description. Pay attention to all objects in the image. The description should be useful for AI to re-generate the video. The description should be no more than five sentences. Here are some examples of good descriptions: 1. A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick and walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about. 2. Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field. 3. Drone view of waves crashing against the rugged cliffs along Big Sur's garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff's edge. The steep drop from the road down to the beach is a dramatic feat, with the cliff’s edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway.", + "type": "image", + }, + "video": { + "text": "Describe this video and its style in a very detailed manner. Pay attention to all objects in the video. The description should be useful for AI to re-generate the video. The description should be no more than six sentences.", + "type": "video", + }, + "video-text": { + "text": "Describe this video and its style in a very detailed manner. Some information about the image is '{}'. Pay attention to all objects in the video. The description should be useful for AI to re-generate the video. The description should be no more than six sentences.", + "type": "video", + }, + "video-f1-detail-3ex": { + "text": "A video is given by providing the middle frame. Describe this video and its style to generate a description. Pay attention to all objects in the video. The description should be useful for AI to re-generate the video. The description should be no more than six sentences. Here are some examples of good descriptions: 1. A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about. 2. Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field. 3. Drone view of waves crashing against the rugged cliffs along Big Sur's garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff's edge. The steep drop from the road down to the beach is a dramatic feat, with the cliff’s edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway.", + "type": "video", + }, + "video-f1-detail-2ex-text": { + "text": "A video is given by providing the middle frame. Some information about the image is '{}'. Describe this video and its style to generate a description. Pay attention to all objects in the video. Do not describe each frame individually. Do not reply with words like 'first frame'. The description should be useful for AI to re-generate the video. The description should be no more than six sentences. Here are some examples of good descriptions: 1. A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about. 2. Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field.", + "type": "video", + }, + "video-f3-detail-3ex": { + "text": "A video is given by providing three frames in chronological order. Describe this video and its style to generate a description. Pay attention to all objects in the video. Do not describe each frame individually. Do not reply with words like 'first frame'. The description should be useful for AI to re-generate the video. The description should be no more than six sentences. Here are some examples of good descriptions: 1. A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about. 2. Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field. 3. Drone view of waves crashing against the rugged cliffs along Big Sur's garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff's edge. The steep drop from the road down to the beach is a dramatic feat, with the cliff’s edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway.", + "type": "video", + }, + "video-f3-detail-2ex-text": { + "text": "A video is given by providing three frames in chronological order. Some information about the image is '{}'. Describe this video and its style to generate a description. Pay attention to all objects in the video. Do not describe each frame individually. Do not reply with words like 'first frame'. The description should be useful for AI to re-generate the video. The description should be no more than six sentences. Here are some examples of good descriptions: 1. A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about. 2. Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field.", + "type": "video", + }, +} + + +NUM_FRAMES_POINTS = { + 1: (0.5,), + 2: (0.25, 0.75), + 3: (0.1, 0.5, 0.9), +} + + +def read_file(input_path): + if input_path.endswith(".csv"): + return pd.read_csv(input_path) + elif input_path.endswith(".parquet"): + return pd.read_parquet(input_path) + else: + raise NotImplementedError(f"Unsupported file format: {input_path}") + + +class VideoTextDataset(torch.utils.data.Dataset): + def __init__(self, csv_path, transform=None, num_frames=3, get_text_input_ids=None, resize=None): + self.csv_path = csv_path + self.transform = transform + self.data = read_file(csv_path) + self.points = NUM_FRAMES_POINTS[num_frames] + self.get_text_input_ids = get_text_input_ids + self.use_text = False + self.resize_size = resize + self.resize = transforms.Resize(resize, transforms.InterpolationMode.BICUBIC) if resize is not None else None + if "text" in self.data.columns: + self.use_text = True + + def getitem(self, index): + sample = self.data.iloc[index] + path = sample["path"] + if not is_video(path): + images = [pil_loader(path)] + length = 1 + else: + images, length = extract_frames(sample["path"], points=self.points, backend="opencv", return_length=True) + if self.resize_size is not None: + images_r = [] + for img in images: + if img.size[0] > self.resize_size or img.size[1] > self.resize_size: + img = self.resize(img) + images_r.append(img) + images = images_r + imgs_size = [img.size for img in images] + if self.transform is not None: + images = self.transform(images) + + # we put images into a list as pytorch dataloader does not accept Pill + out = dict(path=path, image=images, length=length, img_size=imgs_size) + if self.get_text_input_ids is not None: + if self.use_text: + out["text"] = self.get_text_input_ids(sample["text"]) + else: + out["text"] = self.get_text_input_ids() + else: + if self.use_text: + out["text"] = sample["text"] + else: + out["text"] = "" + return out + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + return self.getitem(index) + + +def collate_fn(batch): + paths = [item["path"] for item in batch] + images = [item["image"] for item in batch] + lengths = [item["length"] for item in batch] + img_sizes = [item["img_size"] for item in batch] + texts = [item["text"] for item in batch] + return paths, images, lengths, img_sizes, texts + + +class Timer: + def __init__(self): + self.time_taken = 0 + self.start_time = 0 + self.end_time = 0 + + def __enter__(self): + self.start_time = time.time() + return self + + def __exit__(self, exc_type, exc_value, exc_tb): + self.end_time = time.time() + self.time_taken = self.end_time - self.start_time diff --git a/src/videogen_hub/pipelines/opensora/tools/datasets/README.md b/src/videogen_hub/pipelines/opensora/tools/datasets/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ac14f3337696b8134428bd193502fd5933d23773 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/datasets/README.md @@ -0,0 +1,281 @@ +# Dataset Management + +- [Dataset Management](#dataset-management) + - [Dataset Format](#dataset-format) + - [Dataset to CSV](#dataset-to-csv) + - [Manage datasets](#manage-datasets) + - [Requirement](#requirement) + - [Basic Usage](#basic-usage) + - [Score filtering](#score-filtering) + - [Documentation](#documentation) + - [Transform datasets](#transform-datasets) + - [Resize](#resize) + - [Frame extraction](#frame-extraction) + - [Crop Midjourney 4 grid](#crop-midjourney-4-grid) + - [Analyze datasets](#analyze-datasets) + - [Data Process Pipeline](#data-process-pipeline) + +After preparing the raw dataset according to the [instructions](/docs/datasets.md), you can use the following commands to manage the dataset. + +## Dataset Format + +All dataset should be provided in a `.csv` file (or `parquet.gzip` to save space), which is used for both training and data preprocessing. The columns should follow the words below: + +- `path`: the relative/absolute path or url to the image or video file. Required. +- `text`: the caption or description of the image or video. Required for training. +- `num_frames`: the number of frames in the video. Required for training. +- `width`: the width of the video frame. Required for dynamic bucket. +- `height`: the height of the video frame. Required for dynamic bucket. +- `aspect_ratio`: the aspect ratio of the video frame (height / width). Required for dynamic bucket. +- `resolution`: height x width. For analysis. +- `text_len`: the number of tokens in the text. For analysis. +- `aes`: aesthetic score calculated by [asethetic scorer](/tools/aesthetic/README.md). For filtering. +- `flow`: optical flow score calculated by [UniMatch](/tools/scoring/README.md). For filtering. +- `match`: matching score of a image-text/video-text pair calculated by [CLIP](/tools/scoring/README.md). For filtering. +- `fps`: the frame rate of the video. Optional. +- `cmotion`: the camera motion. + +An example ready for training: + +```csv +path, text, num_frames, width, height, aspect_ratio +/absolute/path/to/image1.jpg, caption, 1, 720, 1280, 0.5625 +/absolute/path/to/video1.mp4, caption, 120, 720, 1280, 0.5625 +/absolute/path/to/video2.mp4, caption, 20, 256, 256, 1 +``` + +We use pandas to manage the `.csv` or `.parquet` files. The following code is for reading and writing files: + +```python +df = pd.read_csv(input_path) +df = df.to_csv(output_path, index=False) +# or use parquet, which is smaller +df = pd.read_parquet(input_path) +df = df.to_parquet(output_path, index=False) +``` + +## Dataset to CSV + +As a start point, `convert.py` is used to convert the dataset to a CSV file. You can use the following commands to convert the dataset to a CSV file: + +```bash +python -m tools.datasets.convert DATASET-TYPE DATA_FOLDER + +# general video folder +python -m tools.datasets.convert video VIDEO_FOLDER --output video.csv +# general image folder +python -m tools.datasets.convert image IMAGE_FOLDER --output image.csv +# imagenet +python -m tools.datasets.convert imagenet IMAGENET_FOLDER --split train +# ucf101 +python -m tools.datasets.convert ucf101 UCF101_FOLDER --split videos +# vidprom +python -m tools.datasets.convert vidprom VIDPROM_FOLDER --info VidProM_semantic_unique.csv +``` + +## Manage datasets + +Use `datautil` to manage the dataset. + +### Requirement + +To accelerate processing speed, you can install [pandarallel](https://github.com/nalepae/pandarallel): + +```bash +pip install pandarallel +``` + +To get image and video information, you need to install [opencv-python](https://github.com/opencv/opencv-python): + +```bash +pip install opencv-python +# If your videos are in av1 codec instead of h264, you need to +# - install ffmpeg first +# - install via conda to support av1 codec +conda install -c conda-forge opencv +``` + +Or to get video information, you can install ffmpeg and ffmpeg-python: + +```bash +pip install ffmpeg-python +``` + +To filter a specific language, you need to install [lingua](https://github.com/pemistahl/lingua-py): + +```bash +pip install lingua-language-detector +``` + +### Basic Usage + +You can use the following commands to process the `csv` or `parquet` files. The output file will be saved in the same directory as the input, with different suffixes indicating the processed method. + +```bash +# datautil takes multiple CSV files as input and merge them into one CSV file +# output: DATA1+DATA2.csv +python -m tools.datasets.datautil DATA1.csv DATA2.csv + +# shard CSV files into multiple CSV files +# output: DATA1_0.csv, DATA1_1.csv, ... +python -m tools.datasets.datautil DATA1.csv --shard 10 + +# filter frames between 128 and 256, with captions +# output: DATA1_fmin_128_fmax_256.csv +python -m tools.datasets.datautil DATA.csv --fmin 128 --fmax 256 + +# Disable parallel processing +python -m tools.datasets.datautil DATA.csv --fmin 128 --fmax 256 --disable-parallel + +# Compute num_frames, height, width, fps, aspect_ratio for videos or images +# output: IMG_DATA+VID_DATA_vinfo.csv +python -m tools.datasets.datautil IMG_DATA.csv VID_DATA.csv --video-info + +# You can run multiple operations at the same time. +python -m tools.datasets.datautil DATA.csv --video-info --remove-empty-caption --remove-url --lang en +``` + +### Score filtering + +To examine and filter the quality of the dataset by aesthetic score and clip score, you can use the following commands: + +```bash +# sort the dataset by aesthetic score +# output: DATA_sort.csv +python -m tools.datasets.datautil DATA.csv --sort aesthetic_score +# View examples of high aesthetic score +head -n 10 DATA_sort.csv +# View examples of low aesthetic score +tail -n 10 DATA_sort.csv + +# sort the dataset by clip score +# output: DATA_sort.csv +python -m tools.datasets.datautil DATA.csv --sort clip_score + +# filter the dataset by aesthetic score +# output: DATA_aesmin_0.5.csv +python -m tools.datasets.datautil DATA.csv --aesmin 0.5 +# filter the dataset by clip score +# output: DATA_matchmin_0.5.csv +python -m tools.datasets.datautil DATA.csv --matchmin 0.5 +``` + +### Documentation + +You can also use `python -m tools.datasets.datautil --help` to see usage. + +| Args | File suffix | Description | +| --------------------------- | -------------- | ------------------------------------------------------------- | +| `--output OUTPUT` | | Output path | +| `--format FORMAT` | | Output format (csv, parquet, parquet.gzip) | +| `--disable-parallel` | | Disable `pandarallel` | +| `--seed SEED` | | Random seed | +| `--shard SHARD` | `_0`,`_1`, ... | Shard the dataset | +| `--sort KEY` | `_sort` | Sort the dataset by KEY | +| `--sort-descending KEY` | `_sort` | Sort the dataset by KEY in descending order | +| `--difference DATA.csv` | | Remove the paths in DATA.csv from the dataset | +| `--intersection DATA.csv` | | Keep the paths in DATA.csv from the dataset and merge columns | +| `--info` | `_info` | Get the basic information of each video and image (cv2) | +| `--ext` | `_ext` | Remove rows if the file does not exist | +| `--relpath` | `_relpath` | Modify the path to relative path by root given | +| `--abspath` | `_abspath` | Modify the path to absolute path by root given | +| `--remove-empty-caption` | `_noempty` | Remove rows with empty caption | +| `--remove-url` | `_nourl` | Remove rows with url in caption | +| `--lang LANG` | `_lang` | Remove rows with other language | +| `--remove-path-duplication` | `_noduppath` | Remove rows with duplicated path | +| `--remove-text-duplication` | `_noduptext` | Remove rows with duplicated caption | +| `--refine-llm-caption` | `_llm` | Modify the caption generated by LLM | +| `--clean-caption MODEL` | `_clean` | Modify the caption according to T5 pipeline to suit training | +| `--unescape` | `_unescape` | Unescape the caption | +| `--merge-cmotion` | `_cmotion` | Merge the camera motion to the caption | +| `--count-num-token` | `_ntoken` | Count the number of tokens in the caption | +| `--load-caption EXT` | `_load` | Load the caption from the file | +| `--fmin FMIN` | `_fmin` | Filter the dataset by minimum number of frames | +| `--fmax FMAX` | `_fmax` | Filter the dataset by maximum number of frames | +| `--hwmax HWMAX` | `_hwmax` | Filter the dataset by maximum height x width | +| `--aesmin AESMIN` | `_aesmin` | Filter the dataset by minimum aesthetic score | +| `--matchmin MATCHMIN` | `_matchmin` | Filter the dataset by minimum clip score | +| `--flowmin FLOWMIN` | `_flowmin` | Filter the dataset by minimum optical flow score | + +## Transform datasets + +The `tools.datasets.transform` module provides a set of tools to transform the dataset. The general usage is as follows: + +```bash +python -m tools.datasets.transform TRANSFORM_TYPE META.csv ORIGINAL_DATA_FOLDER DATA_FOLDER_TO_SAVE_RESULTS --additional-args +``` + +### Resize + +Sometimes you may need to resize the images or videos to a specific resolution. You can use the following commands to resize the dataset: + +```bash +python -m tools.datasets.transform meta.csv /path/to/raw/data /path/to/new/data --length 2160 +``` + +### Frame extraction + +To extract frames from videos, you can use the following commands: + +```bash +python -m tools.datasets.transform vid_frame_extract meta.csv /path/to/raw/data /path/to/new/data --points 0.1 0.5 0.9 +``` + +### Crop Midjourney 4 grid + +Randomly select one of the 4 images in the 4 grid generated by Midjourney. + +```bash +python -m tools.datasets.transform img_rand_crop meta.csv /path/to/raw/data /path/to/new/data +``` + +## Analyze datasets + +You can easily get basic information about a `.csv` dataset by using the following commands: + +```bash +# examine the first 10 rows of the CSV file +head -n 10 DATA1.csv +# count the number of data in the CSV file (approximately) +wc -l DATA1.csv +``` + +For the dataset provided in a `.csv` or `.parquet` file, you can easily analyze the dataset using the following commands. Plots will be automatically saved. + +```python +pyhton -m tools.datasets.analyze DATA_info.csv +``` + +## Data Process Pipeline + +```bash +# Suppose videos and images under ~/dataset/ +# 1. Convert dataset to CSV +python -m tools.datasets.convert video ~/dataset --output meta.csv + +# 2. Get video information +python -m tools.datasets.datautil meta.csv --info --fmin 1 + +# 3. Get caption +# 3.1. generate caption +torchrun --nproc_per_node 8 --standalone -m tools.caption.caption_llava meta_info_fmin1.csv --dp-size 8 --tp-size 1 --model-path liuhaotian/llava-v1.6-mistral-7b --prompt video +# merge generated results +python -m tools.datasets.datautil meta_info_fmin1_caption_part*.csv --output meta_caption.csv +# merge caption and info +python -m tools.datasets.datautil meta_info_fmin1.csv --intersection meta_caption.csv --output meta_caption_info.csv +# clean caption +python -m tools.datasets.datautil meta_caption_info.csv --clean-caption --refine-llm-caption --remove-empty-caption --output meta_caption_processed.csv +# 3.2. extract caption +python -m tools.datasets.datautil meta_info_fmin1.csv --load-caption json --remove-empty-caption --clean-caption + +# 4. Scoring +# aesthetic scoring +torchrun --standalone --nproc_per_node 8 -m tools.scoring.aesthetic.inference meta_caption_processed.csv +python -m tools.datasets.datautil meta_caption_processed_part*.csv --output meta_caption_processed_aes.csv +# optical flow scoring +torchrun --standalone --nproc_per_node 8 -m tools.scoring.optical_flow.inference meta_caption_processed.csv +# matching scoring +torchrun --standalone --nproc_per_node 8 -m tools.scoring.matching.inference meta_caption_processed.csv +# camera motion +python -m tools.caption.camera_motion_detect meta_caption_processed.csv +``` diff --git a/src/videogen_hub/pipelines/opensora/tools/datasets/__init__.py b/src/videogen_hub/pipelines/opensora/tools/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora/tools/datasets/analyze.py b/src/videogen_hub/pipelines/opensora/tools/datasets/analyze.py new file mode 100644 index 0000000000000000000000000000000000000000..7151689a4d309e5516f1a461fe4bec47dbff97e2 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/datasets/analyze.py @@ -0,0 +1,96 @@ +import argparse +import os + +import matplotlib.pyplot as plt +import pandas as pd + + +def read_file(input_path): + if input_path.endswith(".csv"): + return pd.read_csv(input_path) + elif input_path.endswith(".parquet"): + return pd.read_parquet(input_path) + else: + raise NotImplementedError(f"Unsupported file format: {input_path}") + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("input", type=str, help="Path to the input dataset") + parser.add_argument("--save-img", type=str, default="samples/infos/", help="Path to save the image") + return parser.parse_args() + + +def plot_data(data, column, bins, name): + plt.clf() + data.hist(column=column, bins=bins) + os.makedirs(os.path.dirname(name), exist_ok=True) + plt.savefig(name) + print(f"Saved {name}") + + +def plot_categorical_data(data, column, name): + plt.clf() + data[column].value_counts().plot(kind="bar") + os.makedirs(os.path.dirname(name), exist_ok=True) + plt.savefig(name) + print(f"Saved {name}") + + +COLUMNS = { + "num_frames": 100, + "resolution": 100, + "text_len": 100, + "aes": 100, + "match": 100, + "flow": 100, + "cmotion": None, +} + + +def main(args): + data = read_file(args.input) + + # === Image Data Info === + image_index = data["num_frames"] == 1 + if image_index.sum() > 0: + print("=== Image Data Info ===") + img_data = data[image_index] + print(f"Number of images: {len(img_data)}") + print(img_data.head()) + print(img_data.describe()) + if args.save_img: + for column in COLUMNS: + if column in img_data.columns and column not in ["num_frames", "cmotion"]: + if COLUMNS[column] is None: + plot_categorical_data(img_data, column, os.path.join(args.save_img, f"image_{column}.png")) + else: + plot_data(img_data, column, COLUMNS[column], os.path.join(args.save_img, f"image_{column}.png")) + + # === Video Data Info === + if not image_index.all(): + print("=== Video Data Info ===") + video_data = data[~image_index] + print(f"Number of videos: {len(video_data)}") + if "num_frames" in video_data.columns: + total_num_frames = video_data["num_frames"].sum() + print(f"Number of frames: {total_num_frames}") + DEFAULT_FPS = 30 + total_hours = total_num_frames / DEFAULT_FPS / 3600 + print(f"Total hours (30 FPS): {int(total_hours)}") + print(video_data.head()) + print(video_data.describe()) + if args.save_img: + for column in COLUMNS: + if column in video_data.columns: + if COLUMNS[column] is None: + plot_categorical_data(video_data, column, os.path.join(args.save_img, f"video_{column}.png")) + else: + plot_data( + video_data, column, COLUMNS[column], os.path.join(args.save_img, f"video_{column}.png") + ) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/src/videogen_hub/pipelines/opensora/tools/datasets/convert.py b/src/videogen_hub/pipelines/opensora/tools/datasets/convert.py new file mode 100644 index 0000000000000000000000000000000000000000..ef6eee3ffa31cc7e0c2620ae826e69c8e8960631 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/datasets/convert.py @@ -0,0 +1,135 @@ +import argparse +import os +import time + +import pandas as pd +from torchvision.datasets import ImageNet + +IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp") +VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv") + + +def scan_recursively(root): + num = 0 + for entry in os.scandir(root): + if entry.is_file(): + yield entry + elif entry.is_dir(): + num += 1 + if num % 100 == 0: + print(f"Scanned {num} directories.") + yield from scan_recursively(entry.path) + + +def get_filelist(file_path, exts=None): + filelist = [] + time_start = time.time() + + # == OS Walk == + # for home, dirs, files in os.walk(file_path): + # for filename in files: + # ext = os.path.splitext(filename)[-1].lower() + # if exts is None or ext in exts: + # filelist.append(os.path.join(home, filename)) + + # == Scandir == + obj = scan_recursively(file_path) + for entry in obj: + if entry.is_file(): + ext = os.path.splitext(entry.name)[-1].lower() + if exts is None or ext in exts: + filelist.append(entry.path) + + time_end = time.time() + print(f"Scanned {len(filelist)} files in {time_end - time_start:.2f} seconds.") + return filelist + + +def split_by_capital(name): + # BoxingPunchingBag -> Boxing Punching Bag + new_name = "" + for i in range(len(name)): + if name[i].isupper() and i != 0: + new_name += " " + new_name += name[i] + return new_name + + +def process_imagenet(root, split): + root = os.path.expanduser(root) + data = ImageNet(root, split=split) + samples = [(path, data.classes[label][0]) for path, label in data.samples] + output = f"imagenet_{split}.csv" + + df = pd.DataFrame(samples, columns=["path", "text"]) + df.to_csv(output, index=False) + print(f"Saved {len(samples)} samples to {output}.") + + +def process_ucf101(root, split): + root = os.path.expanduser(root) + video_lists = get_filelist(os.path.join(root, split)) + classes = [x.split("/")[-2] for x in video_lists] + classes = [split_by_capital(x) for x in classes] + samples = list(zip(video_lists, classes)) + output = f"ucf101_{split}.csv" + + df = pd.DataFrame(samples, columns=["path", "text"]) + df.to_csv(output, index=False) + print(f"Saved {len(samples)} samples to {output}.") + + +def process_vidprom(root, info): + root = os.path.expanduser(root) + video_lists = get_filelist(root) + video_set = set(video_lists) + # read info csv + infos = pd.read_csv(info) + abs_path = infos["uuid"].apply(lambda x: os.path.join(root, f"pika-{x}.mp4")) + is_exist = abs_path.apply(lambda x: x in video_set) + df = pd.DataFrame(dict(path=abs_path[is_exist], text=infos["prompt"][is_exist])) + df.to_csv("vidprom.csv", index=False) + print(f"Saved {len(df)} samples to vidprom.csv.") + + +def process_general_images(root, output): + root = os.path.expanduser(root) + image_lists = get_filelist(root, IMG_EXTENSIONS) + df = pd.DataFrame(dict(path=image_lists)) + if output is None: + output = "images.csv" + df.to_csv(output, index=False) + print(f"Saved {len(df)} samples to {output}.") + + +def process_general_videos(root, output): + root = os.path.expanduser(root) + video_lists = get_filelist(root, VID_EXTENSIONS) + df = pd.DataFrame(dict(path=video_lists)) + if output is None: + output = "videos.csv" + df.to_csv(output, index=False) + print(f"Saved {len(df)} samples to {output}.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("dataset", type=str, choices=["imagenet", "ucf101", "vidprom", "image", "video"]) + parser.add_argument("root", type=str) + parser.add_argument("--split", type=str, default="train") + parser.add_argument("--info", type=str, default=None) + parser.add_argument("--output", type=str, default=None) + args = parser.parse_args() + + if args.dataset == "imagenet": + process_imagenet(args.root, args.split) + elif args.dataset == "ucf101": + process_ucf101(args.root, args.split) + elif args.dataset == "vidprom": + process_vidprom(args.root, args.info) + elif args.dataset == "image": + process_general_images(args.root, args.output) + elif args.dataset == "video": + process_general_videos(args.root, args.output) + else: + raise ValueError("Invalid dataset") diff --git a/src/videogen_hub/pipelines/opensora/tools/datasets/datautil.py b/src/videogen_hub/pipelines/opensora/tools/datasets/datautil.py new file mode 100644 index 0000000000000000000000000000000000000000..475b847258e384c8c82af65a9e41fe15641d26ad --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/datasets/datautil.py @@ -0,0 +1,695 @@ +import argparse +import html +import json +import os +import random +import re +from functools import partial +from glob import glob + +import cv2 +import numpy as np +import pandas as pd +import torchvision +from tqdm import tqdm + +from .utils import IMG_EXTENSIONS + +tqdm.pandas() + +try: + from pandarallel import pandarallel + + PANDA_USE_PARALLEL = True +except ImportError: + PANDA_USE_PARALLEL = False + + +def apply(df, func, **kwargs): + if PANDA_USE_PARALLEL: + return df.parallel_apply(func, **kwargs) + return df.progress_apply(func, **kwargs) + + +TRAIN_COLUMNS = ["path", "text", "num_frames", "fps", "height", "width", "aspect_ratio", "resolution", "text_len"] + +# ====================================================== +# --info +# ====================================================== + + +def get_video_length(cap, method="header"): + assert method in ["header", "set"] + if method == "header": + length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + else: + cap.set(cv2.CAP_PROP_POS_AVI_RATIO, 1) + length = int(cap.get(cv2.CAP_PROP_POS_FRAMES)) + return length + + +def get_info(path): + try: + ext = os.path.splitext(path)[1].lower() + if ext in IMG_EXTENSIONS: + im = cv2.imread(path) + if im is None: + return 0, 0, 0, np.nan, np.nan + height, width = im.shape[:2] + num_frames, fps = 1, np.nan + else: + cap = cv2.VideoCapture(path) + num_frames, height, width, fps = ( + get_video_length(cap, method="header"), + int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)), + int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), + float(cap.get(cv2.CAP_PROP_FPS)), + ) + hw = height * width + aspect_ratio = height / width if width > 0 else np.nan + return num_frames, height, width, aspect_ratio, fps, hw + except: + return 0, 0, 0, np.nan, np.nan, np.nan + + +def get_video_info(path): + try: + vframes, _, _ = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW") + num_frames, height, width = vframes.shape[0], vframes.shape[2], vframes.shape[3] + aspect_ratio = height / width + fps = np.nan + resolution = height * width + return num_frames, height, width, aspect_ratio, fps, resolution + except: + return 0, 0, 0, np.nan, np.nan, np.nan + + +# ====================================================== +# --refine-llm-caption +# ====================================================== + +LLAVA_PREFIX = [ + "The video shows", + "The video captures", + "The video features", + "The video depicts", + "The video presents", + "The video features", + "The video is ", + "In the video,", + "The image shows", + "The image captures", + "The image features", + "The image depicts", + "The image presents", + "The image features", + "The image is ", + "The image portrays", + "In the image,", +] + + +def remove_caption_prefix(caption): + for prefix in LLAVA_PREFIX: + if caption.startswith(prefix) or caption.startswith(prefix.lower()): + caption = caption[len(prefix) :].strip() + if caption[0].islower(): + caption = caption[0].upper() + caption[1:] + return caption + return caption + + +# ====================================================== +# --merge-cmotion +# ====================================================== + +CMOTION_TEXT = { + "static": "The camera is static.", + "dynamic": "The camera is moving.", + "unknown": None, + "zoom in": "The camera is zooming in.", + "zoom out": "The camera is zooming out.", + "pan left": "The camera is panning left.", + "pan right": "The camera is panning right.", + "tilt up": "The camera is tilting up.", + "tilt down": "The camera is tilting down.", + "pan/tilt": "The camera is panning.", +} +CMOTION_PROBS = { + # hard-coded probabilities + "static": 1.0, + "dynamic": 1.0, + "unknown": 0.0, + "zoom in": 1.0, + "zoom out": 1.0, + "pan left": 1.0, + "pan right": 1.0, + "tilt up": 1.0, + "tilt down": 1.0, + "pan/tilt": 1.0, +} + + +def merge_cmotion(caption, cmotion): + text = CMOTION_TEXT[cmotion] + prob = CMOTION_PROBS[cmotion] + if text is not None and random.random() < prob: + caption = f"{caption} {text}" + return caption + + +# ====================================================== +# --lang +# ====================================================== + + +def build_lang_detector(lang_to_detect): + from lingua import Language, LanguageDetectorBuilder + + lang_dict = dict(en=Language.ENGLISH) + assert lang_to_detect in lang_dict + valid_lang = lang_dict[lang_to_detect] + detector = LanguageDetectorBuilder.from_all_spoken_languages().with_low_accuracy_mode().build() + + def detect_lang(caption): + confidence_values = detector.compute_language_confidence_values(caption) + confidence = [x.language for x in confidence_values[:5]] + if valid_lang not in confidence: + return False + return True + + return detect_lang + + +# ====================================================== +# --clean-caption +# ====================================================== + + +def basic_clean(text): + import ftfy + + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +BAD_PUNCT_REGEX = re.compile( + r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}" +) # noqa + + +def clean_caption(caption): + import urllib.parse as ul + + from bs4 import BeautifulSoup + + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(BAD_PUNCT_REGEX, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = basic_clean(caption) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + +def text_preprocessing(text, use_text_preprocessing: bool = True): + if use_text_preprocessing: + # The exact text cleaning as was in the training stage: + text = clean_caption(text) + text = clean_caption(text) + return text + else: + return text.lower().strip() + + +# ====================================================== +# load caption +# ====================================================== + + +def load_caption(path, ext): + try: + assert ext in ["json"] + json_path = path.split(".")[0] + ".json" + with open(json_path, "r") as f: + data = json.load(f) + caption = data["caption"] + return caption + except: + return "" + + +# ====================================================== +# read & write +# ====================================================== + + +def read_file(input_path): + if input_path.endswith(".csv"): + return pd.read_csv(input_path) + elif input_path.endswith(".parquet"): + return pd.read_parquet(input_path) + else: + raise NotImplementedError(f"Unsupported file format: {input_path}") + + +def save_file(data, output_path): + output_dir = os.path.dirname(output_path) + if not os.path.exists(output_dir) and output_dir != "": + os.makedirs(output_dir) + if output_path.endswith(".csv"): + return data.to_csv(output_path, index=False) + elif output_path.endswith(".parquet"): + return data.to_parquet(output_path, index=False) + else: + raise NotImplementedError(f"Unsupported file format: {output_path}") + + +def read_data(input_paths): + data = [] + input_name = "" + input_list = [] + for input_path in input_paths: + input_list.extend(glob(input_path)) + print("Input files:", input_list) + for i, input_path in enumerate(input_list): + assert os.path.exists(input_path) + data.append(read_file(input_path)) + input_name += os.path.basename(input_path).split(".")[0] + if i != len(input_list) - 1: + input_name += "+" + print(f"Loaded {len(data[-1])} samples from {input_path}.") + data = pd.concat(data, ignore_index=True, sort=False) + print(f"Total number of samples: {len(data)}.") + return data, input_name + + +# ====================================================== +# main +# ====================================================== +# To add a new method, register it in the main, parse_args, and get_output_path functions, and update the doc at /tools/datasets/README.md#documentation + + +def main(args): + # reading data + data, input_name = read_data(args.input) + + # make difference + if args.difference is not None: + data_diff = pd.read_csv(args.difference) + print(f"Difference csv contains {len(data_diff)} samples.") + data = data[~data["path"].isin(data_diff["path"])] + input_name += f"-{os.path.basename(args.difference).split('.')[0]}" + print(f"Filtered number of samples: {len(data)}.") + + # make intersection + if args.intersection is not None: + data_new = pd.read_csv(args.intersection) + print(f"Intersection csv contains {len(data_new)} samples.") + cols_to_use = data_new.columns.difference(data.columns) + cols_to_use = cols_to_use.insert(0, "path") + data = pd.merge(data, data_new[cols_to_use], on="path", how="inner") + print(f"Intersection number of samples: {len(data)}.") + + # train columns + if args.train_column: + all_columns = data.columns + columns_to_drop = all_columns.difference(TRAIN_COLUMNS) + data = data.drop(columns=columns_to_drop) + + # get output path + output_path = get_output_path(args, input_name) + + # preparation + if args.lang is not None: + detect_lang = build_lang_detector(args.lang) + if args.count_num_token == "t5": + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained("DeepFloyd/t5-v1_1-xxl") + + # IO-related + if args.load_caption is not None: + assert "path" in data.columns + data["text"] = apply(data["path"], load_caption, ext=args.load_caption) + if args.info: + info = apply(data["path"], get_info) + ( + data["num_frames"], + data["height"], + data["width"], + data["aspect_ratio"], + data["fps"], + data["resolution"], + ) = zip(*info) + if args.video_info: + info = apply(data["path"], get_video_info) + ( + data["num_frames"], + data["height"], + data["width"], + data["aspect_ratio"], + data["fps"], + data["resolution"], + ) = zip(*info) + if args.ext: + assert "path" in data.columns + data = data[apply(data["path"], os.path.exists)] + + # filtering + if args.remove_url: + assert "text" in data.columns + data = data[~data["text"].str.contains(r"(?Phttps?://[^\s]+)", regex=True)] + if args.lang is not None: + assert "text" in data.columns + data = data[data["text"].progress_apply(detect_lang)] # cannot parallelize + if args.remove_empty_caption: + assert "text" in data.columns + data = data[data["text"].str.len() > 0] + data = data[~data["text"].isna()] + if args.remove_path_duplication: + assert "path" in data.columns + data = data.drop_duplicates(subset=["path"]) + + # processing + if args.relpath is not None: + data["path"] = apply(data["path"], lambda x: os.path.relpath(x, args.relpath)) + if args.abspath is not None: + data["path"] = apply(data["path"], lambda x: os.path.join(args.abspath, x)) + if args.merge_cmotion: + data["text"] = apply(data, lambda x: merge_cmotion(x["text"], x["cmotion"]), axis=1) + if args.refine_llm_caption: + assert "text" in data.columns + data["text"] = apply(data["text"], remove_caption_prefix) + if args.clean_caption: + assert "text" in data.columns + data["text"] = apply( + data["text"], + partial(text_preprocessing, use_text_preprocessing=True), + ) + + if args.count_num_token is not None: + assert "text" in data.columns + data["text_len"] = apply(data["text"], lambda x: len(tokenizer(x)["input_ids"])) + + # sort + if args.sort is not None: + data = data.sort_values(by=args.sort, ascending=False) + if args.sort_ascending is not None: + data = data.sort_values(by=args.sort_ascending, ascending=True) + + # filtering + if args.remove_empty_caption: + assert "text" in data.columns + data = data[data["text"].str.len() > 0] + data = data[~data["text"].isna()] + if args.fmin is not None: + assert "num_frames" in data.columns + data = data[data["num_frames"] >= args.fmin] + if args.fmax is not None: + assert "num_frames" in data.columns + data = data[data["num_frames"] <= args.fmax] + if args.hwmax is not None: + if "resolution" not in data.columns: + height = data["height"] + width = data["width"] + data["resolution"] = height * width + data = data[data["resolution"] <= args.hwmax] + if args.aesmin is not None: + assert "aes" in data.columns + data = data[data["aes"] >= args.aesmin] + if args.matchmin is not None: + assert "match" in data.columns + data = data[data["match"] >= args.matchmin] + if args.flowmin is not None: + assert "flow" in data.columns + data = data[data["flow"] >= args.flowmin] + if args.remove_text_duplication: + data = data.drop_duplicates(subset=["text"], keep="first") + print(f"Filtered number of samples: {len(data)}.") + + # shard data + if args.shard is not None: + sharded_data = np.array_split(data, args.shard) + for i in range(args.shard): + output_path_part = output_path.split(".") + output_path_s = ".".join(output_path_part[:-1]) + f"_{i}." + output_path_part[-1] + save_file(sharded_data[i], output_path_s) + print(f"Saved {len(sharded_data[i])} samples to {output_path_s}.") + else: + save_file(data, output_path) + print(f"Saved {len(data)} samples to {output_path}.") + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("input", type=str, nargs="+", help="path to the input dataset") + parser.add_argument("--output", type=str, default=None, help="output path") + parser.add_argument("--format", type=str, default="csv", help="output format", choices=["csv", "parquet"]) + parser.add_argument("--disable-parallel", action="store_true", help="disable parallel processing") + parser.add_argument("--num-workers", type=int, default=None, help="number of workers") + parser.add_argument("--seed", type=int, default=None, help="random seed") + + # special case + parser.add_argument("--shard", type=int, default=None, help="shard the dataset") + parser.add_argument("--sort", type=str, default=None, help="sort by column") + parser.add_argument("--sort-ascending", type=str, default=None, help="sort by column (ascending order)") + parser.add_argument("--difference", type=str, default=None, help="get difference from the dataset") + parser.add_argument( + "--intersection", type=str, default=None, help="keep the paths in csv from the dataset and merge columns" + ) + parser.add_argument("--train-column", action="store_true", help="only keep the train column") + + # IO-related + parser.add_argument("--info", action="store_true", help="get the basic information of each video and image") + parser.add_argument("--video-info", action="store_true", help="get the basic information of each video") + parser.add_argument("--ext", action="store_true", help="check if the file exists") + parser.add_argument( + "--load-caption", type=str, default=None, choices=["json", "txt"], help="load the caption from json or txt" + ) + + # path processing + parser.add_argument("--relpath", type=str, default=None, help="modify the path to relative path by root given") + parser.add_argument("--abspath", type=str, default=None, help="modify the path to absolute path by root given") + + # caption filtering + parser.add_argument( + "--remove-empty-caption", + action="store_true", + help="remove rows with empty caption", + ) + parser.add_argument("--remove-url", action="store_true", help="remove rows with url in caption") + parser.add_argument("--lang", type=str, default=None, help="remove rows with other language") + parser.add_argument("--remove-path-duplication", action="store_true", help="remove rows with duplicated path") + parser.add_argument("--remove-text-duplication", action="store_true", help="remove rows with duplicated caption") + + # caption processing + parser.add_argument("--refine-llm-caption", action="store_true", help="modify the caption generated by LLM") + parser.add_argument( + "--clean-caption", action="store_true", help="modify the caption according to T5 pipeline to suit training" + ) + parser.add_argument("--merge-cmotion", action="store_true", help="merge the camera motion to the caption") + parser.add_argument( + "--count-num-token", type=str, choices=["t5"], default=None, help="Count the number of tokens in the caption" + ) + + # score filtering + parser.add_argument("--fmin", type=int, default=None, help="filter the dataset by minimum number of frames") + parser.add_argument("--fmax", type=int, default=None, help="filter the dataset by maximum number of frames") + parser.add_argument("--hwmax", type=int, default=None, help="filter the dataset by maximum resolution") + parser.add_argument("--aesmin", type=float, default=None, help="filter the dataset by minimum aes score") + parser.add_argument("--matchmin", type=float, default=None, help="filter the dataset by minimum match score") + parser.add_argument("--flowmin", type=float, default=None, help="filter the dataset by minimum flow score") + + return parser.parse_args() + + +def get_output_path(args, input_name): + if args.output is not None: + return args.output + name = input_name + dir_path = os.path.dirname(args.input[0]) + + # sort + if args.sort is not None: + assert args.sort_ascending is None + name += "_sort" + if args.sort_ascending is not None: + assert args.sort is None + name += "_sort" + + # IO-related + # for IO-related, the function must be wrapped in try-except + if args.info: + name += "_info" + if args.video_info: + name += "_vinfo" + if args.ext: + name += "_ext" + if args.load_caption: + name += f"_load{args.load_caption}" + + # path processing + if args.relpath is not None: + name += "_relpath" + if args.abspath is not None: + name += "_abspath" + + # caption filtering + if args.remove_empty_caption: + name += "_noempty" + if args.remove_url: + name += "_nourl" + if args.lang is not None: + name += f"_{args.lang}" + if args.remove_path_duplication: + name += "_noduppath" + if args.remove_text_duplication: + name += "_noduptext" + + # caption processing + if args.refine_llm_caption: + name += "_llm" + if args.clean_caption: + name += "_clean" + if args.merge_cmotion: + name += "_cmcaption" + if args.count_num_token: + name += "_ntoken" + + # score filtering + if args.fmin is not None: + name += f"_fmin{args.fmin}" + if args.fmax is not None: + name += f"_fmax{args.fmax}" + if args.hwmax is not None: + name += f"_hwmax{args.hwmax}" + if args.aesmin is not None: + name += f"_aesmin{args.aesmin}" + if args.matchmin is not None: + name += f"_matchmin{args.matchmin}" + if args.flowmin is not None: + name += f"_flowmin{args.flowmin}" + + output_path = os.path.join(dir_path, f"{name}.{args.format}") + return output_path + + +if __name__ == "__main__": + args = parse_args() + if args.disable_parallel: + PANDA_USE_PARALLEL = False + if PANDA_USE_PARALLEL: + if args.num_workers is not None: + pandarallel.initialize(nb_workers=args.num_workers, progress_bar=True) + else: + pandarallel.initialize(progress_bar=True) + if args.seed is not None: + random.seed(args.seed) + np.random.seed(args.seed) + main(args) diff --git a/src/videogen_hub/pipelines/opensora/tools/datasets/filter_panda10m.py b/src/videogen_hub/pipelines/opensora/tools/datasets/filter_panda10m.py new file mode 100644 index 0000000000000000000000000000000000000000..c6c3ef34e1acb28d2d0b0c789631d00560c8251b --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/datasets/filter_panda10m.py @@ -0,0 +1,262 @@ +# TODO: remove this file before releasing + +import argparse +import os +import pandas as pd +import json +import html +from tqdm import tqdm +import re + +tqdm.pandas() + +try: + from pandarallel import pandarallel + + pandarallel.initialize(progress_bar=True) + pandas_has_parallel = True +except ImportError: + pandas_has_parallel = False + + +def apply(df, func, **kwargs): + if pandas_has_parallel: + return df.parallel_apply(func, **kwargs) + return df.progress_apply(func, **kwargs) + + +def basic_clean(text): + import ftfy + + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +BAD_PUNCT_REGEX = re.compile( + r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}" +) # noqa + + +def clean_caption(caption): + import urllib.parse as ul + + from bs4 import BeautifulSoup + + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(BAD_PUNCT_REGEX, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = basic_clean(caption) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + +def get_10m_set(): + meta_path_10m = '/mnt/hdd/data/Panda-70M/raw/meta/train/panda70m_training_10m.csv' + meta_10m = pd.read_csv(meta_path_10m) + + def process_single_caption(row): + text_list = eval(row['caption']) + clean_list = [clean_caption(x) for x in text_list] + return str(clean_list) + + ret = apply(meta_10m, process_single_caption, axis=1) + # ret = meta_10m.progress_apply(process_single_caption, axis=1) + print('==> text processed.') + + text_list = [] + for x in ret: + text_list += eval(x) + # text_set = text_set.union(set(eval(x))) + text_set = set(text_list) + # meta_10m['caption_new'] = ret + # meta_10m.to_csv('/mnt/hdd/data/Panda-70M/raw/meta/train/panda70m_training_10m_new-cap.csv') + + # video_id_set = set(meta_10m['videoID']) + # id2t = {} + # for idx, row in tqdm(meta_10m.iterrows(), total=len(meta_10m)): + # video_id = row['videoID'] + # text_list = eval(row['caption']) + # id2t[video_id] = set(text_list) + + print(f"==> Loaded meta_10m from '{meta_path_10m}'") + return text_set + + +def filter_panda10m_text(meta_path, text_set): + def process_single_row(row): + # path = row['path'] + t = row['text'] + # fname = os.path.basename(path) + # video_id = fname[:fname.rindex('_')] + if t not in text_set: + return False + return True + + meta = pd.read_csv(meta_path) + ret = apply(meta, process_single_row, axis=1) + # ret = meta.progress_apply(process_single_row, axis=1) + + meta = meta[ret] + wo_ext, ext = os.path.splitext(meta_path) + out_path = f"{wo_ext}_filter-10m{ext}" + meta.to_csv(out_path, index=False) + print(f"New meta (shape={meta.shape}) saved to '{out_path}'.") + + +def filter_panda10m_timestamp(meta_path): + meta_path_10m = '/mnt/hdd/data/Panda-70M/raw/meta/train/panda70m_training_10m.csv' + meta_10m = pd.read_csv(meta_path_10m) + + id2t = {} + for idx, row in tqdm(meta_10m.iterrows(), total=len(meta_10m)): + video_id = row['videoID'] + timestamp = eval(row['timestamp']) + timestamp = [str(tuple(x)) for x in timestamp] + id2t[video_id] = timestamp + + # video_id_set_10m = set(meta_10m['videoID']) + print(f"==> Loaded meta_10m from '{meta_path_10m}'") + + def process_single_row(row): + path = row['path'] + t = row['timestamp'] + fname = os.path.basename(path) + video_id = fname[:fname.rindex('_')] + if video_id not in id2t: + return False + if t not in id2t[video_id]: + return False + return True + # return video_id in video_id_set_10m + + meta = pd.read_csv(meta_path) + ret = apply(meta, process_single_row, axis=1) + + meta = meta[ret] + wo_ext, ext = os.path.splitext(meta_path) + out_path = f"{wo_ext}_filter-10m{ext}" + meta.to_csv(out_path, index=False) + print(f"New meta (shape={meta.shape}) saved to '{out_path}'.") + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--meta_path', type=str, nargs='+') + parser.add_argument('--num_workers', default=5, type=int) + + args = parser.parse_args() + return args + + +if __name__ == '__main__': + args = parse_args() + + text_set = get_10m_set() + for x in args.meta_path: + filter_panda10m_text(x, text_set) diff --git a/src/videogen_hub/pipelines/opensora/tools/datasets/split.py b/src/videogen_hub/pipelines/opensora/tools/datasets/split.py new file mode 100644 index 0000000000000000000000000000000000000000..35107a051f5874765fca6925cb07cc0239f7d375 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/datasets/split.py @@ -0,0 +1,72 @@ +import argparse +from typing import List + +import pandas as pd +from mmengine.config import Config + +from videogen_hub.pipelines.opensora.opensora.datasets.bucket import Bucket + + +def split_by_bucket( + bucket: Bucket, + input_files: List[str], + output_path: str, + limit: int, + frame_interval: int, +): + print(f"Split {len(input_files)} files into {len(bucket)} buckets") + total_limit = len(bucket) * limit + bucket_cnt = {} + # get all bucket id + for hw_id, d in bucket.ar_criteria.items(): + for t_id, v in d.items(): + for ar_id in v.keys(): + bucket_id = (hw_id, t_id, ar_id) + bucket_cnt[bucket_id] = 0 + output_df = None + # split files + for path in input_files: + df = pd.read_csv(path) + if output_df is None: + output_df = pd.DataFrame(columns=df.columns) + for i in range(len(df)): + row = df.iloc[i] + t, h, w = row["num_frames"], row["height"], row["width"] + bucket_id = bucket.get_bucket_id(t, h, w, frame_interval) + if bucket_id is None: + continue + if bucket_cnt[bucket_id] < limit: + bucket_cnt[bucket_id] += 1 + output_df = pd.concat([output_df, pd.DataFrame([row])], ignore_index=True) + if len(output_df) >= total_limit: + break + if len(output_df) >= total_limit: + break + assert len(output_df) <= total_limit + if len(output_df) == total_limit: + print(f"All buckets are full ({total_limit} samples)") + else: + print(f"Only {len(output_df)} files are used") + output_df.to_csv(output_path, index=False) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("input", type=str, nargs="+") + parser.add_argument("-o", "--output", required=True) + parser.add_argument("-c", "--config", required=True) + parser.add_argument("-l", "--limit", default=200, type=int) + args = parser.parse_args() + assert args.limit > 0 + + cfg = Config.fromfile(args.config) + bucket_config = cfg.bucket_config + # rewrite bucket_config + for ar, d in bucket_config.items(): + for frames, t in d.items(): + p, bs = t + if p > 0.0: + p = 1.0 + d[frames] = (p, bs) + bucket = Bucket(bucket_config) + split_by_bucket(bucket, args.input, args.output, args.limit, cfg.dataset.frame_interval) diff --git a/src/videogen_hub/pipelines/opensora/tools/datasets/transform.py b/src/videogen_hub/pipelines/opensora/tools/datasets/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..94a42766d7178e5dec5a0906a3252cde7af8d815 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/datasets/transform.py @@ -0,0 +1,116 @@ +import argparse +import os +import random + +import cv2 +import numpy as np +import pandas as pd +from tqdm import tqdm + +from .utils import IMG_EXTENSIONS, extract_frames + +tqdm.pandas() + +try: + from pandarallel import pandarallel + + pandarallel.initialize(progress_bar=True) + pandas_has_parallel = True +except ImportError: + pandas_has_parallel = False + + +def apply(df, func, **kwargs): + if pandas_has_parallel: + return df.parallel_apply(func, **kwargs) + return df.progress_apply(func, **kwargs) + + +def get_new_path(path, input_dir, output): + path_new = os.path.join(output, os.path.relpath(path, input_dir)) + os.makedirs(os.path.dirname(path_new), exist_ok=True) + return path_new + + +def resize(path, length, input_dir, output): + path_new = get_new_path(path, input_dir, output) + ext = os.path.splitext(path)[1].lower() + assert ext in IMG_EXTENSIONS + img = cv2.imread(path) + h, w = img.shape[:2] + if min(h, w) > length: + if h > w: + new_h = length + new_w = int(w * new_h / h) + else: + new_w = length + new_h = int(h * new_w / w) + img = cv2.resize(img, (new_w, new_h)) + cv2.imwrite(path_new, img) + return path_new + + +def rand_crop(path, input_dir, output): + ext = os.path.splitext(path)[1].lower() + path_new = get_new_path(path, input_dir, output) + assert ext in IMG_EXTENSIONS + img = cv2.imread(path) + h, w = img.shape[:2] + width, height, _ = img.shape + pos = random.randint(0, 3) + if pos == 0: + img_cropped = img[: width // 2, : height // 2] + elif pos == 1: + img_cropped = img[width // 2 :, : height // 2] + elif pos == 2: + img_cropped = img[: width // 2, height // 2 :] + else: + img_cropped = img[width // 2 :, height // 2 :] + cv2.imwrite(path_new, img_cropped) + return path_new + + +def main(args): + data = pd.read_csv(args.input) + if args.method == "img_rand_crop": + data["path"] = apply(data["path"], lambda x: rand_crop(x, args.input_dir, args.output)) + elif args.method == "img_resize": + data["path"] = apply(data["path"], lambda x: resize(x, args.length, args.input_dir, args.output)) + elif args.method == "vid_frame_extract": + points = args.points if args.points is not None else args.points_index + data = pd.DataFrame(np.repeat(data.values, 3, axis=0), columns=data.columns) + num_points = len(points) + data["point"] = np.nan + for i, point in enumerate(points): + if isinstance(point, int): + data.loc[i::num_points, "point"] = point + else: + data.loc[i::num_points, "point"] = data.loc[i::num_points, "num_frames"] * point + data["path"] = apply(data, lambda x: extract_frames(x["path"], args.input_dir, args.output, x["point"]), axis=1) + + output_csv = args.input.replace(".csv", f"_resized{args.length}.csv") + data.to_csv(output_csv, index=False) + print(f"Saved to {output_csv}") + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("method", type=str, choices=["img_resize", "img_rand_crop", "vid_frame_extract"]) + parser.add_argument("input", type=str) + parser.add_argument("input_dir", type=str) + parser.add_argument("output", type=str) + parser.add_argument("--disable-parallel", action="store_true") + parser.add_argument("--length", type=int, default=2160) + parser.add_argument("--seed", type=int, default=42, help="seed for random") + parser.add_argument("--points", nargs="+", type=float, default=None) + parser.add_argument("--points_index", nargs="+", type=int, default=None) + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = parse_args() + random.seed(args.seed) + if args.disable_parallel: + pandas_has_parallel = False + main(args) diff --git a/src/videogen_hub/pipelines/opensora/tools/datasets/utils.py b/src/videogen_hub/pipelines/opensora/tools/datasets/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c91691b0a988d820de099fdb400f5e956b58562b --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/datasets/utils.py @@ -0,0 +1,117 @@ +import os + +import cv2 +from PIL import Image + +IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp") +VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv") + + +def is_video(filename): + ext = os.path.splitext(filename)[-1].lower() + return ext in VID_EXTENSIONS + + +def extract_frames( + video_path, + frame_inds=None, + points=None, + backend="opencv", + return_length=False, + num_frames=None, +): + """ + Args: + video_path (str): path to video + frame_inds (List[int]): indices of frames to extract + points (List[float]): values within [0, 1); multiply #frames to get frame indices + Return: + List[PIL.Image] + """ + assert backend in ["av", "opencv", "decord"] + assert (frame_inds is None) or (points is None) + + if backend == "av": + import av + + container = av.open(video_path) + if num_frames is not None: + total_frames = num_frames + else: + total_frames = container.streams.video[0].frames + + if points is not None: + frame_inds = [int(p * total_frames) for p in points] + + frames = [] + for idx in frame_inds: + if idx >= total_frames: + idx = total_frames - 1 + target_timestamp = int(idx * av.time_base / container.streams.video[0].average_rate) + container.seek(target_timestamp) + frame = next(container.decode(video=0)).to_image() + frames.append(frame) + + if return_length: + return frames, total_frames + return frames + + elif backend == "decord": + import decord + + container = decord.VideoReader(video_path, num_threads=1) + if num_frames is not None: + total_frames = num_frames + else: + total_frames = len(container) + + if points is not None: + frame_inds = [int(p * total_frames) for p in points] + + frame_inds = np.array(frame_inds).astype(np.int32) + frame_inds[frame_inds >= total_frames] = total_frames - 1 + frames = container.get_batch(frame_inds).asnumpy() # [N, H, W, C] + frames = [Image.fromarray(x) for x in frames] + + if return_length: + return frames, total_frames + return frames + + elif backend == "opencv": + cap = cv2.VideoCapture(video_path) + if num_frames is not None: + total_frames = num_frames + else: + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + + if points is not None: + frame_inds = [int(p * total_frames) for p in points] + + frames = [] + for idx in frame_inds: + if idx >= total_frames: + idx = total_frames - 1 + + cap.set(cv2.CAP_PROP_POS_FRAMES, idx) + # HACK: sometimes OpenCV fails to read frames, return a black frame instead + try: + ret, frame = cap.read() + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frame = Image.fromarray(frame) + except Exception as e: + print(f"Error reading frame {video_path}: {e}") + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + frame = Image.new("RGB", (width, height), (0, 0, 0)) + # HACK: if height or width is 0, return a black frame instead + if frame.height == 0 or frame.width == 0: + height = width = 256 + frame = Image.new("RGB", (width, height), (0, 0, 0)) + + frames.append(frame) + + if return_length: + return frames, total_frames + return frames + else: + raise ValueError diff --git a/src/videogen_hub/pipelines/opensora/tools/frame_interpolation/README.md b/src/videogen_hub/pipelines/opensora/tools/frame_interpolation/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8418e6679834459f1c63a425206bcfcd97667d53 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/frame_interpolation/README.md @@ -0,0 +1,42 @@ +# Frame Interpolation + +For current version, we sample 1 frame out of 3 frames in the video. Although we are going to use VAE to avoid frame loss, we provide a frame interpolation tool to interpolate the video now. The frame interpolation tool is based on [AMT](https://github.com/MCG-NKU/AMT). + +Interpolation can be useful for scenery videos, but it may not be suitable for videos with fast motion. + +## Requirement + +```bash +conda install -c conda-forge opencv +pip install imageio +``` + +## Model + +We use **AMT** as our frame interpolation model. After sampling, you can use frame interpolation model to interpolate your video smoothly. + +## Usage + +The ckpt file will be automatically downloaded in user's `.cache` directory. You can use frame interpolation to your video file or a video folder. + +1. Process a video file + +```python +python -m tools.frame_interpolation.interpolation your_video.mp4 +``` + +2. Process all video file in target directory + +```python +python -m tools.frame_interpolation.interpolation your_video_dir --output_path samples/interpolation +``` + +The output video will be stored at `output_path` and its duration time is equal `the total number of frames after frame interpolation / the frame rate` + +### Command Line Arguments + +* `input`: Path of the input video. **Video path** or **Folder path(with --folder)** +* `--ckpt`: Pretrained model of [AMT](https://github.com/MCG-NKU/AMT). Default path: `~/.cache/amt-g.pth`. +* `--niter`: Iterations of interpolation. With $m$ input frames, `[N_ITER]` $=n$ corresponds to $2^n\times (m-1)+1$ output frames. +* `--fps`: Frame rate of the input video. (Default: 8) +* `--output_path`: **Folder Path** of the output video. diff --git a/src/videogen_hub/pipelines/opensora/tools/frame_interpolation/__init__.py b/src/videogen_hub/pipelines/opensora/tools/frame_interpolation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora/tools/frame_interpolation/interpolation.py b/src/videogen_hub/pipelines/opensora/tools/frame_interpolation/interpolation.py new file mode 100644 index 0000000000000000000000000000000000000000..836e6f41e5bf95f90fd0ee3111c5c78147a79ebb --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/frame_interpolation/interpolation.py @@ -0,0 +1,219 @@ +# this script is modified from https://github.com/MCG-NKU/AMT/blob/main/demos/demo_2x.py +import argparse +import os +import os.path as osp + +import cv2 +import numpy as np +import torch + +from videogen_hub.pipelines.opensora.opensora.utils.ckpt_utils import download_model + +from .networks.amt_g import Model +from .utils.utils import InputPadder, img2tensor, tensor2img + +hf_endpoint = os.environ.get("HF_ENDPOINT") +if hf_endpoint is None: + hf_endpoint = "https://huggingface.co" +VID_EXT = [".mp4", ".avi", ".mov", ".mkv", ".flv", ".wmv", ".webm"] +network_cfg = { + "params": { + "corr_radius": 3, + "corr_lvls": 4, + "num_flows": 5, + }, +} +device = "cuda" if torch.cuda.is_available() else "cpu" + + +def init(): + """ + initialize the device and the anchor resolution. + """ + + if device == "cuda": + anchor_resolution = 1024 * 512 + anchor_memory = 1500 * 1024**2 + anchor_memory_bias = 2500 * 1024**2 + vram_avail = torch.cuda.get_device_properties(device).total_memory + print("VRAM available: {:.1f} MB".format(vram_avail / 1024**2)) + else: + # Do not resize in cpu mode + anchor_resolution = 8192 * 8192 + anchor_memory = 1 + anchor_memory_bias = 0 + vram_avail = 1 + + return anchor_resolution, anchor_memory, anchor_memory_bias, vram_avail + + +def get_input_video_from_path(input_path): + """ + Get the input video from the input_path. + + params: + input_path: str, the path of the input video. + devices: str, the device to run the model. + returns: + inputs: list, the list of the input frames. + scale: float, the scale of the input frames. + padder: InputPadder, the padder to pad the input frames. + """ + + anchor_resolution, anchor_memory, anchor_memory_bias, vram_avail = init() + + if osp.splitext(input_path)[-1].lower() in VID_EXT: + vcap = cv2.VideoCapture(input_path) + + inputs = [] + w = int(vcap.get(cv2.CAP_PROP_FRAME_WIDTH)) + h = int(vcap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + scale = anchor_resolution / (h * w) * np.sqrt((vram_avail - anchor_memory_bias) / anchor_memory) + scale = 1 if scale > 1 else scale + scale = 1 / np.floor(1 / np.sqrt(scale) * 16) * 16 + if scale < 1: + print(f"Due to the limited VRAM, the video will be scaled by {scale:.2f}") + padding = int(16 / scale) + padder = InputPadder((h, w), padding) + while True: + ret, frame = vcap.read() + if ret is False: + break + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frame_t = img2tensor(frame).to(device) + frame_t = padder.pad(frame_t) + inputs.append(frame_t) + print(f"Loading the [video] from {input_path}, the number of frames [{len(inputs)}]") + else: + raise TypeError("Input should be a video.") + + return inputs, scale, padder + + +def load_model(ckpt): + """ + load the frame interpolation model. + """ + params = network_cfg.get("params", {}) + model = Model(**params) + model.load_state_dict(ckpt["state_dict"]) + model = model.to(device) + model.eval() + return model + + +def interpolater(model, inputs, scale, padder, iters=1): + """ + interpolating with the interpolation model. + + params: + model: nn.Module, the frame interpolation model. + inputs: list, the list of the input frames. + scale: float, the scale of the input frames. + iters: int, the number of iterations of interpolation. The final frames model generating is 2 ** iters * (m - 1) + 1 and m is input frames. + returns: + outputs: list, the list of the output frames. + """ + + print("Start frame interpolation:") + embt = torch.tensor(1 / 2).float().view(1, 1, 1, 1).to(device) + + for i in range(iters): + print(f"Iter {i+1}. input_frames={len(inputs)} output_frames={2*len(inputs)-1}") + outputs = [inputs[0]] + for in_0, in_1 in zip(inputs[:-1], inputs[1:]): + in_0 = in_0.to(device) + in_1 = in_1.to(device) + with torch.no_grad(): + imgt_pred = model(in_0, in_1, embt, scale_factor=scale, eval=True)["imgt_pred"] + outputs += [imgt_pred.cpu(), in_1.cpu()] + inputs = outputs + + outputs = padder.unpad(*outputs) + return outputs + + +def write(outputs, input_path, output_path, fps=30): + """ + write results to the output_path. + """ + + if osp.exists(output_path) is False: + os.makedirs(output_path) + + size = outputs[0].shape[2:][::-1] + + _, file_name_with_extension = os.path.split(input_path) + file_name, _ = os.path.splitext(file_name_with_extension) + + save_video_path = f"{output_path}/fps{fps}_{file_name}.mp4" + fourcc = cv2.VideoWriter_fourcc(*"avc1") + writer = cv2.VideoWriter(save_video_path, fourcc, fps, size) + + for i, imgt_pred in enumerate(outputs): + imgt_pred = tensor2img(imgt_pred) + imgt_pred = cv2.cvtColor(imgt_pred, cv2.COLOR_RGB2BGR) + writer.write(imgt_pred) + print(f"Demo video is saved to [{save_video_path}]") + + writer.release() + + +def process( + model, + image_path, + output_path, + fps, + iters, +): + inputs, scale, padder = get_input_video_from_path(image_path) + outputs = interpolater(model, inputs, scale, padder, iters) + write(outputs, image_path, output_path, fps) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("input", help="Input video.") + parser.add_argument("--ckpt", type=str, default="./pretrained_models/amt-g.pth", help="The pretrained model.") + parser.add_argument( + "--niters", + type=int, + default=1, + help="Iter of Interpolation. The number of frames will be double after per iter.", + ) + parser.add_argument("--output_path", type=str, default="samples", help="Output path.") + parser.add_argument("--fps", type=int, default=8, help="Frames rate of the output video.") + parser.add_argument("--folder", action="store_true", help="If the input is a folder, set this flag.") + args = parser.parse_args() + + times_frame = 2**args.niters + old_fps = args.fps + args.fps = args.fps * times_frame + print(f"Interpolation will turn {old_fps}fps video to {args.fps}fps video.") + args.input = os.path.expanduser(args.input) + args.ckpt = os.path.expanduser(args.ckpt) + args.folder = osp.splitext(args.input)[-1].lower() not in VID_EXT + args.ckpt = download_model(local_path=args.ckpt, url=hf_endpoint + "/lalala125/AMT/resolve/main/amt-g.pth") + return args + + +if __name__ == "__main__": + args = parse_args() + ckpt_path = args.ckpt + input_path = args.input + output_path = args.output_path + iters = int(args.niters) + fps = int(args.fps) + + model = load_model(ckpt_path) + + if args.folder: + for file in os.listdir(input_path): + if osp.splitext(file)[-1].lower() in VID_EXT: + vid_path = os.path.join(input_path, file) + process(model, vid_path, output_path, fps, iters) + else: + process(model, input_path, output_path, fps, iters) + + print("Interpolation is done.") + print(f"Output path: {output_path}") diff --git a/src/videogen_hub/pipelines/opensora/tools/frame_interpolation/networks/__init__.py b/src/videogen_hub/pipelines/opensora/tools/frame_interpolation/networks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4db0516c70c506c454be74855adffa9ba686e0fe --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/frame_interpolation/networks/__init__.py @@ -0,0 +1 @@ +from .amt_g import Model diff --git a/src/videogen_hub/pipelines/opensora/tools/frame_interpolation/networks/amt_g.py b/src/videogen_hub/pipelines/opensora/tools/frame_interpolation/networks/amt_g.py new file mode 100644 index 0000000000000000000000000000000000000000..84b28cbfabfd469be5ff47815babc49cd7ddbe12 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/frame_interpolation/networks/amt_g.py @@ -0,0 +1,156 @@ +import torch +import torch.nn as nn + +from .blocks.feat_enc import LargeEncoder +from .blocks.ifrnet import Encoder, InitDecoder, IntermediateDecoder, resize +from .blocks.multi_flow import MultiFlowDecoder, multi_flow_combine +from .blocks.raft import BasicUpdateBlock, BidirCorrBlock, coords_grid + + +class Model(nn.Module): + def __init__(self, corr_radius=3, corr_lvls=4, num_flows=5, channels=[84, 96, 112, 128], skip_channels=84): + super(Model, self).__init__() + self.radius = corr_radius + self.corr_levels = corr_lvls + self.num_flows = num_flows + + self.feat_encoder = LargeEncoder(output_dim=128, norm_fn="instance", dropout=0.0) + self.encoder = Encoder(channels, large=True) + self.decoder4 = InitDecoder(channels[3], channels[2], skip_channels) + self.decoder3 = IntermediateDecoder(channels[2], channels[1], skip_channels) + self.decoder2 = IntermediateDecoder(channels[1], channels[0], skip_channels) + self.decoder1 = MultiFlowDecoder(channels[0], skip_channels, num_flows) + + self.update4 = self._get_updateblock(112, None) + self.update3_low = self._get_updateblock(96, 2.0) + self.update2_low = self._get_updateblock(84, 4.0) + + self.update3_high = self._get_updateblock(96, None) + self.update2_high = self._get_updateblock(84, None) + + self.comb_block = nn.Sequential( + nn.Conv2d(3 * self.num_flows, 6 * self.num_flows, 7, 1, 3), + nn.PReLU(6 * self.num_flows), + nn.Conv2d(6 * self.num_flows, 3, 7, 1, 3), + ) + + def _get_updateblock(self, cdim, scale_factor=None): + return BasicUpdateBlock( + cdim=cdim, + hidden_dim=192, + flow_dim=64, + corr_dim=256, + corr_dim2=192, + fc_dim=188, + scale_factor=scale_factor, + corr_levels=self.corr_levels, + radius=self.radius, + ) + + def _corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1): + # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0 + # based on linear assumption + t1_scale = 1.0 / embt + t0_scale = 1.0 / (1.0 - embt) + if downsample != 1: + inv = 1 / downsample + flow0 = inv * resize(flow0, scale_factor=inv) + flow1 = inv * resize(flow1, scale_factor=inv) + + corr0, corr1 = corr_fn(coord + flow1 * t1_scale, coord + flow0 * t0_scale) + corr = torch.cat([corr0, corr1], dim=1) + flow = torch.cat([flow0, flow1], dim=1) + return corr, flow + + def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs): + mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) + img0 = img0 - mean_ + img1 = img1 - mean_ + img0_ = resize(img0, scale_factor) if scale_factor != 1.0 else img0 + img1_ = resize(img1, scale_factor) if scale_factor != 1.0 else img1 + b, _, h, w = img0_.shape + coord = coords_grid(b, h // 8, w // 8, img0.device) + + fmap0, fmap1 = self.feat_encoder([img0_, img1_]) # [1, 128, H//8, W//8] + corr_fn = BidirCorrBlock(fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels) + + # f0_1: [1, c0, H//2, W//2] | f0_2: [1, c1, H//4, W//4] + # f0_3: [1, c2, H//8, W//8] | f0_4: [1, c3, H//16, W//16] + f0_1, f0_2, f0_3, f0_4 = self.encoder(img0_) + f1_1, f1_2, f1_3, f1_4 = self.encoder(img1_) + + ######################################### the 4th decoder ######################################### + up_flow0_4, up_flow1_4, ft_3_ = self.decoder4(f0_4, f1_4, embt) + corr_4, flow_4 = self._corr_scale_lookup(corr_fn, coord, up_flow0_4, up_flow1_4, embt, downsample=1) + + # residue update with lookup corr + delta_ft_3_, delta_flow_4 = self.update4(ft_3_, flow_4, corr_4) + delta_flow0_4, delta_flow1_4 = torch.chunk(delta_flow_4, 2, 1) + up_flow0_4 = up_flow0_4 + delta_flow0_4 + up_flow1_4 = up_flow1_4 + delta_flow1_4 + ft_3_ = ft_3_ + delta_ft_3_ + + ######################################### the 3rd decoder ######################################### + up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4) + corr_3, flow_3 = self._corr_scale_lookup(corr_fn, coord, up_flow0_3, up_flow1_3, embt, downsample=2) + + # residue update with lookup corr + delta_ft_2_, delta_flow_3 = self.update3_low(ft_2_, flow_3, corr_3) + delta_flow0_3, delta_flow1_3 = torch.chunk(delta_flow_3, 2, 1) + up_flow0_3 = up_flow0_3 + delta_flow0_3 + up_flow1_3 = up_flow1_3 + delta_flow1_3 + ft_2_ = ft_2_ + delta_ft_2_ + + # residue update with lookup corr (hr) + corr_3 = resize(corr_3, scale_factor=2.0) + up_flow_3 = torch.cat([up_flow0_3, up_flow1_3], dim=1) + delta_ft_2_, delta_up_flow_3 = self.update3_high(ft_2_, up_flow_3, corr_3) + ft_2_ += delta_ft_2_ + up_flow0_3 += delta_up_flow_3[:, 0:2] + up_flow1_3 += delta_up_flow_3[:, 2:4] + + ######################################### the 2nd decoder ######################################### + up_flow0_2, up_flow1_2, ft_1_ = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3) + corr_2, flow_2 = self._corr_scale_lookup(corr_fn, coord, up_flow0_2, up_flow1_2, embt, downsample=4) + + # residue update with lookup corr + delta_ft_1_, delta_flow_2 = self.update2_low(ft_1_, flow_2, corr_2) + delta_flow0_2, delta_flow1_2 = torch.chunk(delta_flow_2, 2, 1) + up_flow0_2 = up_flow0_2 + delta_flow0_2 + up_flow1_2 = up_flow1_2 + delta_flow1_2 + ft_1_ = ft_1_ + delta_ft_1_ + + # residue update with lookup corr (hr) + corr_2 = resize(corr_2, scale_factor=4.0) + up_flow_2 = torch.cat([up_flow0_2, up_flow1_2], dim=1) + delta_ft_1_, delta_up_flow_2 = self.update2_high(ft_1_, up_flow_2, corr_2) + ft_1_ += delta_ft_1_ + up_flow0_2 += delta_up_flow_2[:, 0:2] + up_flow1_2 += delta_up_flow_2[:, 2:4] + + ######################################### the 1st decoder ######################################### + up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2) + + if scale_factor != 1.0: + up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0 / scale_factor)) * (1.0 / scale_factor) + up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0 / scale_factor)) * (1.0 / scale_factor) + mask = resize(mask, scale_factor=(1.0 / scale_factor)) + img_res = resize(img_res, scale_factor=(1.0 / scale_factor)) + + # Merge multiple predictions + imgt_pred = multi_flow_combine(self.comb_block, img0, img1, up_flow0_1, up_flow1_1, mask, img_res, mean_) + imgt_pred = torch.clamp(imgt_pred, 0, 1) + + if eval: + return { + "imgt_pred": imgt_pred, + } + else: + up_flow0_1 = up_flow0_1.reshape(b, self.num_flows, 2, h, w) + up_flow1_1 = up_flow1_1.reshape(b, self.num_flows, 2, h, w) + return { + "imgt_pred": imgt_pred, + "flow0_pred": [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4], + "flow1_pred": [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4], + "ft_pred": [ft_1_, ft_2_, ft_3_], + } diff --git a/src/videogen_hub/pipelines/opensora/tools/frame_interpolation/networks/blocks/__init__.py b/src/videogen_hub/pipelines/opensora/tools/frame_interpolation/networks/blocks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora/tools/frame_interpolation/networks/blocks/feat_enc.py b/src/videogen_hub/pipelines/opensora/tools/frame_interpolation/networks/blocks/feat_enc.py new file mode 100644 index 0000000000000000000000000000000000000000..479833824b8b2da7e9e3ba05c84b0359b8c79c37 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/frame_interpolation/networks/blocks/feat_enc.py @@ -0,0 +1,335 @@ +import torch +import torch.nn as nn + + +class BottleneckBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn="group", stride=1): + super(BottleneckBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0) + self.conv2 = nn.Conv2d(planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride) + self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4) + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(planes // 4) + self.norm2 = nn.BatchNorm2d(planes // 4) + self.norm3 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm4 = nn.BatchNorm2d(planes) + + elif norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(planes // 4) + self.norm2 = nn.InstanceNorm2d(planes // 4) + self.norm3 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm4 = nn.InstanceNorm2d(planes) + + elif norm_fn == "none": + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + self.norm3 = nn.Sequential() + if not stride == 1: + self.norm4 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential(nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + y = self.relu(self.norm3(self.conv3(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn="group", stride=1): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == "none": + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential(nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class SmallEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0): + super(SmallEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) + + elif self.norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(32) + + elif self.norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(32) + + elif self.norm_fn == "none": + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 32 + self.layer1 = self._make_layer(32, stride=1) + self.layer2 = self._make_layer(64, stride=2) + self.layer3 = self._make_layer(96, stride=2) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + + +class BasicEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0): + super(BasicEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + + elif self.norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(64) + + elif self.norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(64) + + elif self.norm_fn == "none": + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(72, stride=2) + self.layer3 = self._make_layer(128, stride=2) + + # output convolution + self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + + +class LargeEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0): + super(LargeEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + + elif self.norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(64) + + elif self.norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(64) + + elif self.norm_fn == "none": + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(112, stride=2) + self.layer3 = self._make_layer(160, stride=2) + self.layer3_2 = self._make_layer(160, stride=1) + + # output convolution + self.conv2 = nn.Conv2d(self.in_planes, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer3_2(x) + + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x diff --git a/src/videogen_hub/pipelines/opensora/tools/frame_interpolation/networks/blocks/ifrnet.py b/src/videogen_hub/pipelines/opensora/tools/frame_interpolation/networks/blocks/ifrnet.py new file mode 100644 index 0000000000000000000000000000000000000000..5719a040e102c36a417925e78f5acb4cf4402725 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/frame_interpolation/networks/blocks/ifrnet.py @@ -0,0 +1,115 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from tools.frame_interpolation.utils.flow_utils import warp + + +def resize(x, scale_factor): + return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False) + + +def convrelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True): + return nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=bias), + nn.PReLU(out_channels), + ) + + +class ResBlock(nn.Module): + def __init__(self, in_channels, side_channels, bias=True): + super(ResBlock, self).__init__() + self.side_channels = side_channels + self.conv1 = nn.Sequential( + nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias), nn.PReLU(in_channels) + ) + self.conv2 = nn.Sequential( + nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(side_channels), + ) + self.conv3 = nn.Sequential( + nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias), nn.PReLU(in_channels) + ) + self.conv4 = nn.Sequential( + nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(side_channels), + ) + self.conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias) + self.prelu = nn.PReLU(in_channels) + + def forward(self, x): + out = self.conv1(x) + + res_feat = out[:, : -self.side_channels, ...] + side_feat = out[:, -self.side_channels :, :, :] + side_feat = self.conv2(side_feat) + out = self.conv3(torch.cat([res_feat, side_feat], 1)) + + res_feat = out[:, : -self.side_channels, ...] + side_feat = out[:, -self.side_channels :, :, :] + side_feat = self.conv4(side_feat) + out = self.conv5(torch.cat([res_feat, side_feat], 1)) + + out = self.prelu(x + out) + return out + + +class Encoder(nn.Module): + def __init__(self, channels, large=False): + super(Encoder, self).__init__() + self.channels = channels + prev_ch = 3 + for idx, ch in enumerate(channels, 1): + k = 7 if large and idx == 1 else 3 + p = 3 if k == 7 else 1 + self.register_module( + f"pyramid{idx}", nn.Sequential(convrelu(prev_ch, ch, k, 2, p), convrelu(ch, ch, 3, 1, 1)) + ) + prev_ch = ch + + def forward(self, in_x): + fs = [] + for idx in range(len(self.channels)): + out_x = getattr(self, f"pyramid{idx+1}")(in_x) + fs.append(out_x) + in_x = out_x + return fs + + +class InitDecoder(nn.Module): + def __init__(self, in_ch, out_ch, skip_ch) -> None: + super().__init__() + self.convblock = nn.Sequential( + convrelu(in_ch * 2 + 1, in_ch * 2), + ResBlock(in_ch * 2, skip_ch), + nn.ConvTranspose2d(in_ch * 2, out_ch + 4, 4, 2, 1, bias=True), + ) + + def forward(self, f0, f1, embt): + h, w = f0.shape[2:] + embt = embt.repeat(1, 1, h, w) + out = self.convblock(torch.cat([f0, f1, embt], 1)) + flow0, flow1 = torch.chunk(out[:, :4, ...], 2, 1) + ft_ = out[:, 4:, ...] + return flow0, flow1, ft_ + + +class IntermediateDecoder(nn.Module): + def __init__(self, in_ch, out_ch, skip_ch) -> None: + super().__init__() + self.convblock = nn.Sequential( + convrelu(in_ch * 3 + 4, in_ch * 3), + ResBlock(in_ch * 3, skip_ch), + nn.ConvTranspose2d(in_ch * 3, out_ch + 4, 4, 2, 1, bias=True), + ) + + def forward(self, ft_, f0, f1, flow0_in, flow1_in): + f0_warp = warp(f0, flow0_in) + f1_warp = warp(f1, flow1_in) + f_in = torch.cat([ft_, f0_warp, f1_warp, flow0_in, flow1_in], 1) + out = self.convblock(f_in) + flow0, flow1 = torch.chunk(out[:, :4, ...], 2, 1) + ft_ = out[:, 4:, ...] + flow0 = flow0 + 2.0 * resize(flow0_in, scale_factor=2.0) + flow1 = flow1 + 2.0 * resize(flow1_in, scale_factor=2.0) + return flow0, flow1, ft_ diff --git a/src/videogen_hub/pipelines/opensora/tools/frame_interpolation/networks/blocks/multi_flow.py b/src/videogen_hub/pipelines/opensora/tools/frame_interpolation/networks/blocks/multi_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..cbb96a9ef6bcee99627e7c844e45987bfb2d9308 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/frame_interpolation/networks/blocks/multi_flow.py @@ -0,0 +1,62 @@ +import torch +import torch.nn as nn + +from tools.frame_interpolation.utils.flow_utils import warp + +from .ifrnet import ResBlock, convrelu, resize + + +def multi_flow_combine(comb_block, img0, img1, flow0, flow1, mask=None, img_res=None, mean=None): + """ + A parallel implementation of multiple flow field warping + comb_block: An nn.Seqential object. + img shape: [b, c, h, w] + flow shape: [b, 2*num_flows, h, w] + mask (opt): + If 'mask' is None, the function conduct a simple average. + img_res (opt): + If 'img_res' is None, the function adds zero instead. + mean (opt): + If 'mean' is None, the function adds zero instead. + """ + b, c, h, w = flow0.shape + num_flows = c // 2 + flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w) + flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w) + + mask = mask.reshape(b, num_flows, 1, h, w).reshape(-1, 1, h, w) if mask is not None else None + img_res = img_res.reshape(b, num_flows, 3, h, w).reshape(-1, 3, h, w) if img_res is not None else 0 + img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w) + img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w) + mean = torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1) if mean is not None else 0 + + img0_warp = warp(img0, flow0) + img1_warp = warp(img1, flow1) + img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res + img_warps = img_warps.reshape(b, num_flows, 3, h, w) + imgt_pred = img_warps.mean(1) + comb_block(img_warps.view(b, -1, h, w)) + return imgt_pred + + +class MultiFlowDecoder(nn.Module): + def __init__(self, in_ch, skip_ch, num_flows=3): + super(MultiFlowDecoder, self).__init__() + self.num_flows = num_flows + self.convblock = nn.Sequential( + convrelu(in_ch * 3 + 4, in_ch * 3), + ResBlock(in_ch * 3, skip_ch), + nn.ConvTranspose2d(in_ch * 3, 8 * num_flows, 4, 2, 1, bias=True), + ) + + def forward(self, ft_, f0, f1, flow0, flow1): + n = self.num_flows + f0_warp = warp(f0, flow0) + f1_warp = warp(f1, flow1) + out = self.convblock(torch.cat([ft_, f0_warp, f1_warp, flow0, flow1], 1)) + delta_flow0, delta_flow1, mask, img_res = torch.split(out, [2 * n, 2 * n, n, 3 * n], 1) + mask = torch.sigmoid(mask) + + flow0 = delta_flow0 + 2.0 * resize(flow0, scale_factor=2.0).repeat(1, self.num_flows, 1, 1) + flow1 = delta_flow1 + 2.0 * resize(flow1, scale_factor=2.0).repeat(1, self.num_flows, 1, 1) + + return flow0, flow1, mask, img_res diff --git a/src/videogen_hub/pipelines/opensora/tools/frame_interpolation/networks/blocks/raft.py b/src/videogen_hub/pipelines/opensora/tools/frame_interpolation/networks/blocks/raft.py new file mode 100644 index 0000000000000000000000000000000000000000..1576889201c49614224450c9a223b871e8031f2d --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/frame_interpolation/networks/blocks/raft.py @@ -0,0 +1,213 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def resize(x, scale_factor): + return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False) + + +def bilinear_sampler(img, coords, mask=False): + """Wrapper for grid_sample, uses pixel coordinates""" + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1, 1], dim=-1) + xgrid = 2 * xgrid / (W - 1) - 1 + ygrid = 2 * ygrid / (H - 1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True) + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + + +def coords_grid(batch, ht, wd, device): + coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device), indexing="ij") + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + + +class SmallUpdateBlock(nn.Module): + def __init__(self, cdim, hidden_dim, flow_dim, corr_dim, fc_dim, corr_levels=4, radius=3, scale_factor=None): + super(SmallUpdateBlock, self).__init__() + cor_planes = corr_levels * (2 * radius + 1) ** 2 + self.scale_factor = scale_factor + + self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0) + self.convf1 = nn.Conv2d(4, flow_dim * 2, 7, padding=3) + self.convf2 = nn.Conv2d(flow_dim * 2, flow_dim, 3, padding=1) + self.conv = nn.Conv2d(corr_dim + flow_dim, fc_dim, 3, padding=1) + + self.gru = nn.Sequential( + nn.Conv2d(fc_dim + 4 + cdim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + ) + + self.feat_head = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, cdim, 3, padding=1), + ) + + self.flow_head = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, 4, 3, padding=1), + ) + + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + def forward(self, net, flow, corr): + net = resize(net, 1 / self.scale_factor) if self.scale_factor is not None else net + cor = self.lrelu(self.convc1(corr)) + flo = self.lrelu(self.convf1(flow)) + flo = self.lrelu(self.convf2(flo)) + cor_flo = torch.cat([cor, flo], dim=1) + inp = self.lrelu(self.conv(cor_flo)) + inp = torch.cat([inp, flow, net], dim=1) + + out = self.gru(inp) + delta_net = self.feat_head(out) + delta_flow = self.flow_head(out) + + if self.scale_factor is not None: + delta_net = resize(delta_net, scale_factor=self.scale_factor) + delta_flow = self.scale_factor * resize(delta_flow, scale_factor=self.scale_factor) + + return delta_net, delta_flow + + +class BasicUpdateBlock(nn.Module): + def __init__( + self, + cdim, + hidden_dim, + flow_dim, + corr_dim, + corr_dim2, + fc_dim, + corr_levels=4, + radius=3, + scale_factor=None, + out_num=1, + ): + super(BasicUpdateBlock, self).__init__() + cor_planes = corr_levels * (2 * radius + 1) ** 2 + + self.scale_factor = scale_factor + self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0) + self.convc2 = nn.Conv2d(corr_dim, corr_dim2, 3, padding=1) + self.convf1 = nn.Conv2d(4, flow_dim * 2, 7, padding=3) + self.convf2 = nn.Conv2d(flow_dim * 2, flow_dim, 3, padding=1) + self.conv = nn.Conv2d(flow_dim + corr_dim2, fc_dim, 3, padding=1) + + self.gru = nn.Sequential( + nn.Conv2d(fc_dim + 4 + cdim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + ) + + self.feat_head = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, cdim, 3, padding=1), + ) + + self.flow_head = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, 4 * out_num, 3, padding=1), + ) + + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + def forward(self, net, flow, corr): + net = resize(net, 1 / self.scale_factor) if self.scale_factor is not None else net + cor = self.lrelu(self.convc1(corr)) + cor = self.lrelu(self.convc2(cor)) + flo = self.lrelu(self.convf1(flow)) + flo = self.lrelu(self.convf2(flo)) + cor_flo = torch.cat([cor, flo], dim=1) + inp = self.lrelu(self.conv(cor_flo)) + inp = torch.cat([inp, flow, net], dim=1) + + out = self.gru(inp) + delta_net = self.feat_head(out) + delta_flow = self.flow_head(out) + + if self.scale_factor is not None: + delta_net = resize(delta_net, scale_factor=self.scale_factor) + delta_flow = self.scale_factor * resize(delta_flow, scale_factor=self.scale_factor) + return delta_net, delta_flow + + +class BidirCorrBlock: + def __init__(self, fmap1, fmap2, num_levels=4, radius=4): + self.num_levels = num_levels + self.radius = radius + self.corr_pyramid = [] + self.corr_pyramid_T = [] + + corr = BidirCorrBlock.corr(fmap1, fmap2) + batch, h1, w1, dim, h2, w2 = corr.shape + corr_T = corr.clone().permute(0, 4, 5, 3, 1, 2) + + corr = corr.reshape(batch * h1 * w1, dim, h2, w2) + corr_T = corr_T.reshape(batch * h2 * w2, dim, h1, w1) + + self.corr_pyramid.append(corr) + self.corr_pyramid_T.append(corr_T) + + for _ in range(self.num_levels - 1): + corr = F.avg_pool2d(corr, 2, stride=2) + corr_T = F.avg_pool2d(corr_T, 2, stride=2) + self.corr_pyramid.append(corr) + self.corr_pyramid_T.append(corr_T) + + def __call__(self, coords0, coords1): + r = self.radius + coords0 = coords0.permute(0, 2, 3, 1) + coords1 = coords1.permute(0, 2, 3, 1) + assert coords0.shape == coords1.shape, f"coords0 shape: [{coords0.shape}] is not equal to [{coords1.shape}]" + batch, h1, w1, _ = coords0.shape + + out_pyramid = [] + out_pyramid_T = [] + for i in range(self.num_levels): + corr = self.corr_pyramid[i] + corr_T = self.corr_pyramid_T[i] + + dx = torch.linspace(-r, r, 2 * r + 1, device=coords0.device) + dy = torch.linspace(-r, r, 2 * r + 1, device=coords0.device) + delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1) + delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) + + centroid_lvl_0 = coords0.reshape(batch * h1 * w1, 1, 1, 2) / 2**i + centroid_lvl_1 = coords1.reshape(batch * h1 * w1, 1, 1, 2) / 2**i + coords_lvl_0 = centroid_lvl_0 + delta_lvl + coords_lvl_1 = centroid_lvl_1 + delta_lvl + + corr = bilinear_sampler(corr, coords_lvl_0) + corr_T = bilinear_sampler(corr_T, coords_lvl_1) + corr = corr.view(batch, h1, w1, -1) + corr_T = corr_T.view(batch, h1, w1, -1) + out_pyramid.append(corr) + out_pyramid_T.append(corr_T) + + out = torch.cat(out_pyramid, dim=-1) + out_T = torch.cat(out_pyramid_T, dim=-1) + return out.permute(0, 3, 1, 2).contiguous().float(), out_T.permute(0, 3, 1, 2).contiguous().float() + + @staticmethod + def corr(fmap1, fmap2): + batch, dim, ht, wd = fmap1.shape + fmap1 = fmap1.view(batch, dim, ht * wd) + fmap2 = fmap2.view(batch, dim, ht * wd) + + corr = torch.matmul(fmap1.transpose(1, 2), fmap2) + corr = corr.view(batch, ht, wd, 1, ht, wd) + return corr / torch.sqrt(torch.tensor(dim).float()) diff --git a/src/videogen_hub/pipelines/opensora/tools/frame_interpolation/utils/__init__.py b/src/videogen_hub/pipelines/opensora/tools/frame_interpolation/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora/tools/frame_interpolation/utils/dist_utils.py b/src/videogen_hub/pipelines/opensora/tools/frame_interpolation/utils/dist_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d754d4fc7a6ed1a9bae246b2f895456218d815ea --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/frame_interpolation/utils/dist_utils.py @@ -0,0 +1,48 @@ +import os + +import torch + + +def get_world_size(): + """Find OMPI world size without calling mpi functions + :rtype: int + """ + if os.environ.get("PMI_SIZE") is not None: + return int(os.environ.get("PMI_SIZE") or 1) + elif os.environ.get("OMPI_COMM_WORLD_SIZE") is not None: + return int(os.environ.get("OMPI_COMM_WORLD_SIZE") or 1) + else: + return torch.cuda.device_count() + + +def get_global_rank(): + """Find OMPI world rank without calling mpi functions + :rtype: int + """ + if os.environ.get("PMI_RANK") is not None: + return int(os.environ.get("PMI_RANK") or 0) + elif os.environ.get("OMPI_COMM_WORLD_RANK") is not None: + return int(os.environ.get("OMPI_COMM_WORLD_RANK") or 0) + else: + return 0 + + +def get_local_rank(): + """Find OMPI local rank without calling mpi functions + :rtype: int + """ + if os.environ.get("MPI_LOCALRANKID") is not None: + return int(os.environ.get("MPI_LOCALRANKID") or 0) + elif os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK") is not None: + return int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK") or 0) + else: + return 0 + + +def get_master_ip(): + if os.environ.get("AZ_BATCH_MASTER_NODE") is not None: + return os.environ.get("AZ_BATCH_MASTER_NODE").split(":")[0] + elif os.environ.get("AZ_BATCHAI_MPI_MASTER_NODE") is not None: + return os.environ.get("AZ_BATCHAI_MPI_MASTER_NODE") + else: + return "127.0.0.1" diff --git a/src/videogen_hub/pipelines/opensora/tools/frame_interpolation/utils/flow_utils.py b/src/videogen_hub/pipelines/opensora/tools/frame_interpolation/utils/flow_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4edee465ab5e16459358c3c4c2a1ac20b468d90e --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/frame_interpolation/utils/flow_utils.py @@ -0,0 +1,125 @@ +import numpy as np +import torch +import torch.nn.functional as F +from PIL import ImageFile + +ImageFile.LOAD_TRUNCATED_IMAGES = True + + +def warp(img, flow): + B, _, H, W = flow.shape + xx = torch.linspace(-1.0, 1.0, W).view(1, 1, 1, W).expand(B, -1, H, -1) + yy = torch.linspace(-1.0, 1.0, H).view(1, 1, H, 1).expand(B, -1, -1, W) + grid = torch.cat([xx, yy], 1).to(img) + flow_ = torch.cat([flow[:, 0:1, :, :] / ((W - 1.0) / 2.0), flow[:, 1:2, :, :] / ((H - 1.0) / 2.0)], 1) + grid_ = (grid + flow_).permute(0, 2, 3, 1) + output = F.grid_sample(input=img, grid=grid_, mode="bilinear", padding_mode="border", align_corners=True) + return output + + +def make_colorwheel(): + """ + Generates a color wheel for optical flow visualization as presented in: + Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) + URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf + Code follows the original C++ source code of Daniel Scharstein. + Code follows the the Matlab source code of Deqing Sun. + Returns: + np.ndarray: Color wheel + """ + + RY = 15 + YG = 6 + GC = 4 + CB = 11 + BM = 13 + MR = 6 + + ncols = RY + YG + GC + CB + BM + MR + colorwheel = np.zeros((ncols, 3)) + col = 0 + + # RY + colorwheel[0:RY, 0] = 255 + colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY) + col = col + RY + # YG + colorwheel[col : col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG) + colorwheel[col : col + YG, 1] = 255 + col = col + YG + # GC + colorwheel[col : col + GC, 1] = 255 + colorwheel[col : col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC) + col = col + GC + # CB + colorwheel[col : col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB) + colorwheel[col : col + CB, 2] = 255 + col = col + CB + # BM + colorwheel[col : col + BM, 2] = 255 + colorwheel[col : col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM) + col = col + BM + # MR + colorwheel[col : col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR) + colorwheel[col : col + MR, 0] = 255 + return colorwheel + + +def flow_uv_to_colors(u, v, convert_to_bgr=False): + """ + Applies the flow color wheel to (possibly clipped) flow components u and v. + According to the C++ source code of Daniel Scharstein + According to the Matlab source code of Deqing Sun + Args: + u (np.ndarray): Input horizontal flow of shape [H,W] + v (np.ndarray): Input vertical flow of shape [H,W] + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) + colorwheel = make_colorwheel() # shape [55x3] + ncols = colorwheel.shape[0] + rad = np.sqrt(np.square(u) + np.square(v)) + a = np.arctan2(-v, -u) / np.pi + fk = (a + 1) / 2 * (ncols - 1) + k0 = np.floor(fk).astype(np.int32) + k1 = k0 + 1 + k1[k1 == ncols] = 0 + f = fk - k0 + for i in range(colorwheel.shape[1]): + tmp = colorwheel[:, i] + col0 = tmp[k0] / 255.0 + col1 = tmp[k1] / 255.0 + col = (1 - f) * col0 + f * col1 + idx = rad <= 1 + col[idx] = 1 - rad[idx] * (1 - col[idx]) + col[~idx] = col[~idx] * 0.75 # out of range + # Note the 2-i => BGR instead of RGB + ch_idx = 2 - i if convert_to_bgr else i + flow_image[:, :, ch_idx] = np.floor(255 * col) + return flow_image + + +def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): + """ + Expects a two dimensional flow image of shape. + Args: + flow_uv (np.ndarray): Flow UV image of shape [H,W,2] + clip_flow (float, optional): Clip maximum of flow values. Defaults to None. + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + assert flow_uv.ndim == 3, "input flow must have three dimensions" + assert flow_uv.shape[2] == 2, "input flow must have shape [H,W,2]" + if clip_flow is not None: + flow_uv = np.clip(flow_uv, 0, clip_flow) + u = flow_uv[:, :, 0] + v = flow_uv[:, :, 1] + rad = np.sqrt(np.square(u) + np.square(v)) + rad_max = np.max(rad) + epsilon = 1e-5 + u = u / (rad_max + epsilon) + v = v / (rad_max + epsilon) + return flow_uv_to_colors(u, v, convert_to_bgr) diff --git a/src/videogen_hub/pipelines/opensora/tools/frame_interpolation/utils/utils.py b/src/videogen_hub/pipelines/opensora/tools/frame_interpolation/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..285a65fd454e034ce672dcea82d1449bc77ef953 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/frame_interpolation/utils/utils.py @@ -0,0 +1,314 @@ +import random +import re +import sys + +import numpy as np +import torch +import torch.nn.functional as F +from imageio import imread, imwrite +from PIL import ImageFile + +ImageFile.LOAD_TRUNCATED_IMAGES = True + + +class AverageMeter: + def __init__(self): + self.reset() + + def reset(self): + self.val = 0.0 + self.avg = 0.0 + self.sum = 0.0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +class AverageMeterGroups: + def __init__(self) -> None: + self.meter_dict = dict() + + def update(self, dict, n=1): + for name, val in dict.items(): + if self.meter_dict.get(name) is None: + self.meter_dict[name] = AverageMeter() + self.meter_dict[name].update(val, n) + + def reset(self, name=None): + if name is None: + for v in self.meter_dict.values(): + v.reset() + else: + meter = self.meter_dict.get(name) + if meter is not None: + meter.reset() + + def avg(self, name): + meter = self.meter_dict.get(name) + if meter is not None: + return meter.avg + + +class InputPadder: + """Pads images such that dimensions are divisible by divisor""" + + def __init__(self, dims, divisor=16): + self.ht, self.wd = dims[-2:] + pad_ht = (((self.ht // divisor) + 1) * divisor - self.ht) % divisor + pad_wd = (((self.wd // divisor) + 1) * divisor - self.wd) % divisor + self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2] + + def pad(self, *inputs): + if len(inputs) == 1: + return F.pad(inputs[0], self._pad, mode="replicate") + else: + return [F.pad(x, self._pad, mode="replicate") for x in inputs] + + def unpad(self, *inputs): + if len(inputs) == 1: + return self._unpad(inputs[0]) + else: + return [self._unpad(x) for x in inputs] + + def _unpad(self, x): + ht, wd = x.shape[-2:] + c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]] + return x[..., c[0] : c[1], c[2] : c[3]] + + +def img2tensor(img): + if img.shape[-1] > 3: + img = img[:, :, :3] + return torch.tensor(img).permute(2, 0, 1).unsqueeze(0) / 255.0 + + +def tensor2img(img_t): + return (img_t * 255.0).detach().squeeze(0).permute(1, 2, 0).cpu().numpy().clip(0, 255).astype(np.uint8) + + +def seed_all(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def read(file): + if file.endswith(".float3"): + return readFloat(file) + elif file.endswith(".flo"): + return readFlow(file) + elif file.endswith(".ppm"): + return readImage(file) + elif file.endswith(".pgm"): + return readImage(file) + elif file.endswith(".png"): + return readImage(file) + elif file.endswith(".jpg"): + return readImage(file) + elif file.endswith(".pfm"): + return readPFM(file)[0] + else: + raise Exception("don't know how to read %s" % file) + + +def write(file, data): + if file.endswith(".float3"): + return writeFloat(file, data) + elif file.endswith(".flo"): + return writeFlow(file, data) + elif file.endswith(".ppm"): + return writeImage(file, data) + elif file.endswith(".pgm"): + return writeImage(file, data) + elif file.endswith(".png"): + return writeImage(file, data) + elif file.endswith(".jpg"): + return writeImage(file, data) + elif file.endswith(".pfm"): + return writePFM(file, data) + else: + raise Exception("don't know how to write %s" % file) + + +def readPFM(file): + file = open(file, "rb") + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header.decode("ascii") == "PF": + color = True + elif header.decode("ascii") == "Pf": + color = False + else: + raise Exception("Not a PFM file.") + + dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) + if dim_match: + width, height = list(map(int, dim_match.groups())) + else: + raise Exception("Malformed PFM header.") + + scale = float(file.readline().decode("ascii").rstrip()) + if scale < 0: + endian = "<" + scale = -scale + else: + endian = ">" + + data = np.fromfile(file, endian + "f") + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + return data, scale + + +def writePFM(file, image, scale=1): + file = open(file, "wb") + + color = None + + if image.dtype.name != "float32": + raise Exception("Image dtype must be float32.") + + image = np.flipud(image) + + if len(image.shape) == 3 and image.shape[2] == 3: + color = True + elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: + color = False + else: + raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") + + file.write("PF\n" if color else "Pf\n".encode()) + file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) + + endian = image.dtype.byteorder + + if endian == "<" or endian == "=" and sys.byteorder == "little": + scale = -scale + + file.write("%f\n".encode() % scale) + + image.tofile(file) + + +def readFlow(name): + if name.endswith(".pfm") or name.endswith(".PFM"): + return readPFM(name)[0][:, :, 0:2] + + f = open(name, "rb") + + header = f.read(4) + if header.decode("utf-8") != "PIEH": + raise Exception("Flow file header does not contain PIEH") + + width = np.fromfile(f, np.int32, 1).squeeze() + height = np.fromfile(f, np.int32, 1).squeeze() + + flow = np.fromfile(f, np.float32, width * height * 2).reshape((height, width, 2)) + + return flow.astype(np.float32) + + +def readImage(name): + if name.endswith(".pfm") or name.endswith(".PFM"): + data = readPFM(name)[0] + if len(data.shape) == 3: + return data[:, :, 0:3] + else: + return data + return imread(name) + + +def writeImage(name, data): + if name.endswith(".pfm") or name.endswith(".PFM"): + return writePFM(name, data, 1) + return imwrite(name, data) + + +def writeFlow(name, flow): + f = open(name, "wb") + f.write("PIEH".encode("utf-8")) + np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f) + flow = flow.astype(np.float32) + flow.tofile(f) + + +def readFloat(name): + f = open(name, "rb") + + if (f.readline().decode("utf-8")) != "float\n": + raise Exception("float file %s did not contain keyword" % name) + + dim = int(f.readline()) + + dims = [] + count = 1 + for i in range(0, dim): + d = int(f.readline()) + dims.append(d) + count *= d + + dims = list(reversed(dims)) + + data = np.fromfile(f, np.float32, count).reshape(dims) + if dim > 2: + data = np.transpose(data, (2, 1, 0)) + data = np.transpose(data, (1, 0, 2)) + + return data + + +def writeFloat(name, data): + f = open(name, "wb") + + dim = len(data.shape) + if dim > 3: + raise Exception("bad float file dimension: %d" % dim) + + f.write(("float\n").encode("ascii")) + f.write(("%d\n" % dim).encode("ascii")) + + if dim == 1: + f.write(("%d\n" % data.shape[0]).encode("ascii")) + else: + f.write(("%d\n" % data.shape[1]).encode("ascii")) + f.write(("%d\n" % data.shape[0]).encode("ascii")) + for i in range(2, dim): + f.write(("%d\n" % data.shape[i]).encode("ascii")) + + data = data.astype(np.float32) + if dim == 2: + data.tofile(f) + + else: + np.transpose(data, (2, 0, 1)).tofile(f) + + +def check_dim_and_resize(tensor_list): + shape_list = [] + for t in tensor_list: + shape_list.append(t.shape[2:]) + + if len(set(shape_list)) > 1: + desired_shape = shape_list[0] + print(f"Inconsistent size of input video frames. All frames will be resized to {desired_shape}") + + resize_tensor_list = [] + for t in tensor_list: + resize_tensor_list.append(torch.nn.functional.interpolate(t, size=tuple(desired_shape), mode="bilinear")) + + tensor_list = resize_tensor_list + + return tensor_list diff --git a/src/videogen_hub/pipelines/opensora/tools/scene_cut/README.md b/src/videogen_hub/pipelines/opensora/tools/scene_cut/README.md new file mode 100644 index 0000000000000000000000000000000000000000..bb8c254384291ad8e6a1cadc1688f8ec22056a83 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/scene_cut/README.md @@ -0,0 +1,56 @@ +# Scene Detection and Video Splitting + +- [Scene Detection and Video Splitting](#scene-detection-and-video-splitting) + - [Prepare Meta Files](#prepare-meta-files) + - [Scene Detection](#scene-detection) + - [Video Splitting](#video-splitting) + +In many cases, raw videos contain several scenes and are too long for training. Thus, it is essential to split them into shorter +clips based on scenes. Here, we provide code for scene detection and video splitting. + +## Prepare Meta Files +At this step, you should have a raw video dataset prepared. A meta file of the dataset information is needed for data processing. To create a meta file from a folder, run: + +```bash +python -m tools.datasets.convert video /path/to/video/folder --output /path/to/save/meta.csv +``` +This should output a `.csv` file with column `path`. + +If you already have a meta file for the videos and want to keep the information. +**Make sure** the meta file has column `id`, which is the id for each video, and the video is named as `{id}.mp4`. +The following command will add a new column `path` to the meta file. + +```bash +python tools/scene_cut/convert_id_to_path.py /path/to/meta.csv --folder_path /path/to/video/folder +``` +This should output +- `{prefix}_path-filtered.csv` with column `path` (broken videos filtered) +- `{prefix}_path_intact.csv` with column `path` and `intact` (`intact` indicating a video is intact or not) + + +## Scene Detection +The next step is to detect scenes in a video. +We use [`PySceneDetect`](https://github.com/Breakthrough/PySceneDetect) for this job. +**Make sure** the input meta file has column `path`, which is the path of a video. + +```bash +python tools/scene_cut/scene_detect.py /path/to/meta.csv +``` +The output is `{prefix}_timestamp.csv` with column `timestamp`. Each cell in column `timestamp` is a list of tuples, +with each tuple indicating the start and end timestamp of a scene +(e.g., `[('00:00:01.234', '00:00:02.345'), ('00:00:03.456', '00:00:04.567')]`). + +## Video Splitting +After obtaining timestamps for scenes, we conduct video splitting (cutting). +**Make sure** the meta file contains column `timestamp`. + +```bash +python tools/scene_cut/cut.py /path/to/meta.csv --save_dir /path/to/output/dir +``` + +This will save video clips to `/path/to/output/dir`. The video clips are named as `{video_id}_scene-{scene_id}.mp4` + +To create a new meta file for the generated clips, run: +```bash +python -m tools.datasets.convert video /path/to/video/folder --output /path/to/save/meta.csv +``` diff --git a/src/videogen_hub/pipelines/opensora/tools/scene_cut/__init__.py b/src/videogen_hub/pipelines/opensora/tools/scene_cut/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora/tools/scene_cut/convert_id_to_path.py b/src/videogen_hub/pipelines/opensora/tools/scene_cut/convert_id_to_path.py new file mode 100644 index 0000000000000000000000000000000000000000..eb7b1cb4e27e1de54738545e880544fac55a889f --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/scene_cut/convert_id_to_path.py @@ -0,0 +1,128 @@ +import os + +import argparse +import json +from functools import partial + +import numpy as np +import pandas as pd +from pandarallel import pandarallel +import cv2 +from mmengine.logging import print_log +from moviepy.editor import VideoFileClip +from tqdm import tqdm + +tqdm.pandas() + + +def is_intact_video(video_path, mode="moviepy", verbose=False, logger=None): + if not os.path.exists(video_path): + if verbose: + print_log(f"Could not find '{video_path}'", logger=logger) + return False + + if mode == "moviepy": + try: + VideoFileClip(video_path) + if verbose: + print_log(f"The video file '{video_path}' is intact.", logger=logger) + return True + except Exception as e: + if verbose: + print_log(f"Error: {e}", logger=logger) + print_log(f"The video file '{video_path}' is not intact.", logger=logger) + return False + elif mode == "cv2": + try: + cap = cv2.VideoCapture(video_path) + if cap.isOpened(): + if verbose: + print_log(f"The video file '{video_path}' is intact.", logger=logger) + return True + except Exception as e: + if verbose: + print_log(f"Error: {e}", logger=logger) + print_log(f"The video file '{video_path}' is not intact.", logger=logger) + return False + else: + raise ValueError + + +def has_downloaded_success(json_path): + if not os.path.exists(json_path): + return False + + try: + with open(json_path, "r") as f: + data = json.load(f) + if "success" not in data or isinstance(data["success"], bool) is False or data["success"] is False: + return False + except Exception: + return False + + return True + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("meta_path", type=str) + parser.add_argument("--folder_path", type=str, required=True) + parser.add_argument("--mode", type=str, default=None) + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + meta_path = args.meta_path + folder_path = args.folder_path + mode = args.mode + + def is_intact(row, mode=None): + video_id = row["id"] + video_path = os.path.join(folder_path, f"{video_id}.mp4") + row["path"] = video_path + + if mode == ".mp4": + if is_intact_video(video_path): + return True, video_path + return False, video_path + elif mode == ".json": + # json_path = os.path.join(root_raw, f"data/{split}/{video_id}.json") + json_path = os.path.join(folder_path, f"{video_id}.json") + if has_downloaded_success(json_path): + return True, video_path + return False, video_path + elif mode is None: + return True, video_path + else: + raise ValueError + + meta_dirpath = os.path.dirname(meta_path) + meta_fname = os.path.basename(meta_path) + wo_ext, ext = os.path.splitext(meta_fname) + + pandarallel.initialize(progress_bar=True) + is_intact_partial = partial(is_intact, mode=mode) + + meta = pd.read_csv(meta_path) + ret = meta.parallel_apply(is_intact_partial, axis=1) + intact, paths = list(zip(*ret)) + + meta["intact"] = intact + meta["path"] = paths + out_path = os.path.join(meta_dirpath, f"{wo_ext}_path_intact.csv") + meta.to_csv(out_path, index=False) + print(f"New meta (shape={meta.shape}) with intact info saved to '{out_path}'") + + meta_format = meta[np.array(intact)] + meta_format.drop("intact", axis=1, inplace=True) + out_path = os.path.join(meta_dirpath, f"{wo_ext}_path-filtered.csv") + meta_format.to_csv(out_path, index=False) + print(f"New meta (shape={meta_format.shape}) with format info saved to '{out_path}'") + + +if __name__ == "__main__": + main() diff --git a/src/videogen_hub/pipelines/opensora/tools/scene_cut/cut.py b/src/videogen_hub/pipelines/opensora/tools/scene_cut/cut.py new file mode 100644 index 0000000000000000000000000000000000000000..b3ecbe00b87c905d269c629bbd1b4fe7ada1e5ae --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/scene_cut/cut.py @@ -0,0 +1,164 @@ +import argparse +import os +import subprocess +import time +from functools import partial + +import pandas as pd +from imageio_ffmpeg import get_ffmpeg_exe +from mmengine.logging import MMLogger, print_log +from pandarallel import pandarallel +from scenedetect import FrameTimecode +from tqdm import tqdm + +tqdm.pandas() + + +def process_single_row(row, args, log_name=None): + video_path = row["path"] + + logger = None + if log_name is not None: + logger = MMLogger.get_instance(log_name) + + # check mp4 integrity + # if not is_intact_video(video_path, logger=logger): + # return False + + timestamp = row["timestamp"] + if not (timestamp.startswith("[") and timestamp.endswith("]")): + return False + scene_list = eval(timestamp) + scene_list = [(FrameTimecode(s, fps=1), FrameTimecode(t, fps=1)) for s, t in scene_list] + split_video( + video_path, + scene_list, + save_dir=args.save_dir, + min_seconds=args.min_seconds, + max_seconds=args.max_seconds, + target_fps=args.target_fps, + shorter_size=args.shorter_size, + logger=logger, + ) + + +def split_video( + video_path, + scene_list, + save_dir, + min_seconds=2.0, + max_seconds=15.0, + target_fps=30, + shorter_size=720, + verbose=False, + logger=None, +): + """ + scenes shorter than min_seconds will be ignored; + scenes longer than max_seconds will be cut to save the beginning max_seconds. + Currently, the saved file name pattern is f'{fname}_scene-{idx}'.mp4 + + Args: + scene_list (List[Tuple[FrameTimecode, FrameTimecode]]): each element is (s, t): start and end of a scene. + min_seconds (float | None) + max_seconds (float | None) + target_fps (int | None) + shorter_size (int | None) + """ + FFMPEG_PATH = get_ffmpeg_exe() + + save_path_list = [] + for idx, scene in enumerate(scene_list): + s, t = scene # FrameTimecode + if min_seconds is not None: + if (t - s).get_seconds() < min_seconds: + continue + + duration = t - s + if max_seconds is not None: + fps = s.framerate + max_duration = FrameTimecode(timecode="00:00:00", fps=fps) + max_duration.frame_num = round(fps * max_seconds) + duration = min(max_duration, duration) + + # save path + fname = os.path.basename(video_path) + fname_wo_ext = os.path.splitext(fname)[0] + # TODO: fname pattern + save_path = os.path.join(save_dir, f"{fname_wo_ext}_scene-{idx}.mp4") + + # ffmpeg cmd + cmd = [FFMPEG_PATH] + + # Only show ffmpeg output for the first call, which will display any + # errors if it fails, and then break the loop. We only show error messages + # for the remaining calls. + # cmd += ['-v', 'error'] + + # clip to cut + # -ss after -i is very slow; put -ss before -i + cmd += ["-nostdin", "-y", "-ss", str(s.get_seconds()), "-i", video_path, "-t", str(duration.get_seconds())] + + # target fps + if target_fps is not None: + cmd += ["-r", f"{target_fps}"] + + # aspect ratio + if shorter_size is not None: + cmd += ["-vf", f"scale='if(gt(iw,ih),-2,{shorter_size})':'if(gt(iw,ih),{shorter_size},-2)'"] + # cmd += ['-vf', f"scale='if(gt(iw,ih),{shorter_size},trunc(ow/a/2)*2)':-2"] + + cmd += ["-map", "0", save_path] + + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + stdout, stderr = proc.communicate() + # stdout = stdout.decode("utf-8") + # print_log(stdout, logger=logger) + + save_path_list.append(video_path) + if verbose: + print_log(f"Video clip saved to '{save_path}'", logger=logger) + + return save_path_list + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("meta_path", type=str) + parser.add_argument("--save_dir", type=str) + parser.add_argument("--min_seconds", type=float, default=None, + help='if not None, clip shorter than min_seconds is ignored') + parser.add_argument("--max_seconds", type=float, default=None, + help='if not None, clip longer than max_seconds is truncated') + parser.add_argument("--target_fps", type=int, default=30, help='target fps of clips') + parser.add_argument("--shorter_size", type=int, default=720, help='resize the shorter size by keeping ratio') + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + save_dir = args.save_dir + os.makedirs(save_dir, exist_ok=True) + + # create logger + log_dir = os.path.dirname(save_dir) + log_name = os.path.basename(save_dir) + timestamp = time.strftime("%Y%m%d-%H%M%S", time.localtime(time.time())) + log_path = os.path.join(log_dir, f"{log_name}_{timestamp}.log") + logger = MMLogger.get_instance(log_name, log_file=log_path) + # logger = None + + # initialize pandarallel + pandarallel.initialize(progress_bar=True) + process_single_row_partial = partial(process_single_row, args=args, log_name=log_name) + + # process + meta = pd.read_csv(args.meta_path) + meta.parallel_apply(process_single_row_partial, axis=1) + + +if __name__ == "__main__": + main() diff --git a/src/videogen_hub/pipelines/opensora/tools/scene_cut/scene_detect.py b/src/videogen_hub/pipelines/opensora/tools/scene_cut/scene_detect.py new file mode 100644 index 0000000000000000000000000000000000000000..eb7b003b5ebf932840fa1f038d0c3050273e81b3 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/scene_cut/scene_detect.py @@ -0,0 +1,62 @@ +import argparse +import os + +import numpy as np +import pandas as pd +from pandarallel import pandarallel +from scenedetect import AdaptiveDetector, detect +from tqdm import tqdm + +tqdm.pandas() + + +def process_single_row(row): + # windows + # from scenedetect import detect, ContentDetector, AdaptiveDetector + + video_path = row["path"] + + detector = AdaptiveDetector( + adaptive_threshold=3.0, + # luma_only=True, + ) + # detector = ContentDetector() + # TODO: catch error here + try: + scene_list = detect(video_path, detector, start_in_scene=True) + timestamp = [(s.get_timecode(), t.get_timecode()) for s, t in scene_list] + return True, str(timestamp) + except Exception as e: + print(f"Video '{video_path}' with error {e}") + return False, "" + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("meta_path", type=str) + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + meta_path = args.meta_path + + pandarallel.initialize(progress_bar=True) + + meta = pd.read_csv(meta_path) + ret = meta.parallel_apply(process_single_row, axis=1) + + succ, timestamps = list(zip(*ret)) + meta["timestamp"] = timestamps + meta = meta[np.array(succ)] + + wo_ext, ext = os.path.splitext(meta_path) + out_path = f"{wo_ext}_timestamp{ext}" + meta.to_csv(out_path, index=False) + print(f"New meta (shape={meta.shape}) with timestamp saved to '{out_path}'.") + + +if __name__ == "__main__": + main() diff --git a/src/videogen_hub/pipelines/opensora/tools/scoring/README.md b/src/videogen_hub/pipelines/opensora/tools/scoring/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a944d5cf79b993245a0b77108f4a83ce5ff2e8ac --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/scoring/README.md @@ -0,0 +1,100 @@ +# Scoring and Filtering + +- [Scoring and Filtering](#scoring-and-filtering) + - [Aesthetic Score](#aesthetic-score) + - [Optical Flow Score](#optical-flow-score) + - [OCR](#ocr) + - [Matching Score](#matching-score) + - [Filtering](#filtering) + +## Aesthetic Score + +To evaluate the aesthetic quality of videos, we use the scoring model from [CLIP+MLP Aesthetic Score Predictor](https://github.com/christophschuhmann/improved-aesthetic-predictor). This model is trained on 176K SAC (Simulacra Aesthetic Captions) pairs, 15K LAION-Logos (Logos) pairs, and 250K AVA (The Aesthetic Visual Analysis) image-text pairs. + +The aesthetic score is between 1 and 10, where 5.5 can be considered as the threshold for fair aesthetics, and 6.5 for high aesthetics. Good text-to-image models can achieve a score of 7.0 or higher. + +For videos, we extract the first, last, and the middle frames for evaluation. The script also supports images as input. +The throughput of our code is ~1K videos/s on a single H800 GPU. It also supports running on multiple GPUs for further acceleration. + +First, install the required packages and download the scoring model to `./pretrained_models/aesthetic.pth`. +```bash +# pip install +pip install git+https://github.com/openai/CLIP.git +pip install decord + +# get pretrained model +wget https://github.com/christophschuhmann/improved-aesthetic-predictor/raw/main/sac+logos+ava1-l14-linearMSE.pth -O pretrained_models/aesthetic.pth +``` + +Then, run the following command. **Make sure** the meta file has column `path` (path to the sample). +```bash +torchrun --nproc_per_node 8 -m tools.scoring.aesthetic.inference /path/to/meta.csv --bs 1024 --num_workers 16 +``` +This will generate multiple part files, each corresponding to a node . Run `python -m tools.datasets.datautil /path/to/meta_aes_part*.csv --output /path/to/meta_aes.csv` to merge them. + +## Optical Flow Score + +Optical flow scores are used to assess the motion of a video. Higher optical flow scores indicate larger movement. +We use the [UniMatch](https://github.com/autonomousvision/unimatch) model for this task. + +First, download the pretrained model to `./pretrained_model/unimatch/` +```bash +wget https://s3.eu-central-1.amazonaws.com/avg-projects/unimatch/pretrained/gmflow-scale2-regrefine6-mixdata-train320x576-4e7b215d.pth -P ./pretrained_models/unimatch/ +``` + +Then, run the following command. **Make sure** the meta file has column `path` (path to the sample). +```bash +torchrun --standalone --nproc_per_node 8 tools/scoring/optical_flow/inference.py /path/to/meta.csv +``` + +This should output `/path/to/meta_flow.csv` with column `flow`. + +## OCR +Some videos are of dense text scenes like news broadcast and advertisement, which are not desired for training. +We apply Optical Character Recognition (OCR) to detect texts and drop samples with dense texts. Here, we use +the [DBNet++](https://arxiv.org/abs/2202.10304) model implemented by [MMOCR](https://github.com/open-mmlab/mmocr/). + +First, install [MMOCR](https://mmocr.readthedocs.io/en/dev-1.x/get_started/install.html). +For reference, we install packages of these versions. +``` +torch==2.0.1 +mmcv==2.0.1 +mmdet==3.1.0 +mmocr==1.0.1 +``` + +Then, run the following command. **Make sure** the meta file has column `path` (path to the sample). +```bash +torchrun --standalone --nproc_per_node 8 tools/scoring/ocr/inference.py /path/to/meta.csv +``` +This should output `/path/to/meta_ocr.csv` with column `ocr`, indicating the number of text regions with detection confidence > 0.3. + + +## Matching Score + +Matching scores are calculated to evaluate the alignment between an image/video and its caption. +Here, we use the [CLIP](https://github.com/openai/CLIP) model, which is trained on image-text pairs. +We simply use the cosine similarity as the matching score. +For videos, we extract the middle frame and compare it with the caption. + +First, install OpenAI CLIP. +```bash +pip install git+https://github.com/openai/CLIP.git +``` + +Then, run the following command. **Make sure** the meta file has column `path` (path to the sample) and `text` (caption of the sample). + +```bash +torchrun --standalone --nproc_per_node 8 tools/scoring/matching/inference.py /path/to/meta.csv +``` + +This should output `/path/to/meta_match.csv` with column `match`. Higher matching scores indicate better image-text/video-text alignment. + + +## Filtering +Once scores are obtained, it is simple to filter samples based on these scores. Here is an example to remove +samples of aesthetic score < 5.0. +``` +python -m tools.datasets.datautil /path/to/meta.csv --aesmin 5.0 +``` +This should output `/path/to/meta_aesmin5.0.csv` with column `aes` >= 5.0 \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora/tools/scoring/__init__.py b/src/videogen_hub/pipelines/opensora/tools/scoring/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora/tools/scoring/aesthetic/__init__.py b/src/videogen_hub/pipelines/opensora/tools/scoring/aesthetic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora/tools/scoring/aesthetic/inference.py b/src/videogen_hub/pipelines/opensora/tools/scoring/aesthetic/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..a527859c3eb238195378297f6c7772851b579b75 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/scoring/aesthetic/inference.py @@ -0,0 +1,168 @@ +# adapted from https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/main/simple_inference.py +import argparse +from datetime import timedelta + +import clip +import numpy as np +import pandas as pd +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from colossalai.utils import set_seed +from einops import rearrange +from PIL import Image +from torchvision.datasets.folder import pil_loader +from tqdm import tqdm + +from tools.datasets.utils import extract_frames, is_video + +try: + from torchvision.transforms import InterpolationMode + + BICUBIC = InterpolationMode.BICUBIC +except ImportError: + BICUBIC = Image.BICUBIC + + +NUM_FRAMES_POINTS = { + 1: (0.5,), + 2: (0.25, 0.5), + 3: (0.1, 0.5, 0.9), +} + + +class VideoTextDataset(torch.utils.data.Dataset): + def __init__(self, csv_path, transform=None, num_frames=3): + self.csv_path = csv_path + self.data = pd.read_csv(csv_path) + self.transform = transform + self.points = NUM_FRAMES_POINTS[num_frames] + + def getitem(self, index): + sample = self.data.iloc[index] + path = sample["path"] + if not is_video(path): + images = [pil_loader(path)] + else: + num_frames = None + if "num_frames" in sample: + num_frames = sample["num_frames"] + images = extract_frames(sample["path"], points=self.points, backend="opencv", num_frames=num_frames) + images = [self.transform(img) for img in images] + images = torch.stack(images) + ret = dict(index=index, images=images) + return ret + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + return self.getitem(index) + + +class MLP(nn.Module): + def __init__(self, input_size): + super().__init__() + self.input_size = input_size + self.layers = nn.Sequential( + nn.Linear(self.input_size, 1024), + nn.Dropout(0.2), + nn.Linear(1024, 128), + nn.Dropout(0.2), + nn.Linear(128, 64), + nn.Dropout(0.1), + nn.Linear(64, 16), + nn.Linear(16, 1), + ) + + def forward(self, x): + return self.layers(x) + + +class AestheticScorer(nn.Module): + def __init__(self, input_size, device): + super().__init__() + self.mlp = MLP(input_size) + self.mlp.load_state_dict(torch.load("pretrained_models/aesthetic.pth")) + self.clip, self.preprocess = clip.load("ViT-L/14", device=device) + + self.eval() + self.to(device) + + def forward(self, x): + image_features = self.clip.encode_image(x) + image_features = F.normalize(image_features, p=2, dim=-1).float() + return self.mlp(image_features) + + +@torch.inference_mode() +def main(args): + dist.init_process_group(backend="nccl", timeout=timedelta(hours=24)) + torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count()) + set_seed(1024) + rank = dist.get_rank() + world_size = dist.get_world_size() + + output_file = args.input.replace(".csv", f"_aes_part{rank}.csv") + + # build model + device = "cuda" if torch.cuda.is_available() else "cpu" + model = AestheticScorer(768, device) + preprocess = model.preprocess + + # build dataset + dataset = VideoTextDataset(args.input, transform=preprocess, num_frames=args.num_frames) + sampler = torch.utils.data.distributed.DistributedSampler( + dataset=dataset, num_replicas=world_size, rank=rank, shuffle=False + ) + + dataloader = torch.utils.data.DataLoader( + dataset, + sampler=sampler, + batch_size=args.bs, + shuffle=False, + num_workers=args.num_workers, + pin_memory=True, + prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None, + ) + + # compute aesthetic scores + dataset.data["aes"] = np.nan + + with tqdm(dataloader, position=rank, desc=f"Data Parallel Rank {rank}") as t: + for idx, batch in enumerate(t): + image_indices = batch["index"] + images = batch["images"].to(device, non_blocking=True) + B = images.shape[0] + images = rearrange(images, "b p c h w -> (b p) c h w") + + # compute score + scores = model(images) + scores = rearrange(scores, "(b p) 1 -> b p", b=B) + scores = scores.mean(dim=1) + scores_np = scores.to(torch.float32).cpu().numpy() + + # assign the score + dataset.data.loc[image_indices, "aes"] = scores_np + + # wait for all ranks to finish data processing + dist.barrier() + + # exclude rows whose aes is nan and save file + dataset.data = dataset.data[dataset.data["aes"] > 0] + dataset.data.to_csv(output_file, index=False) + print(f"New meta with aesthetic scores saved to '{output_file}'.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("input", type=str, help="Path to the input CSV file") + parser.add_argument("--bs", type=int, default=1024, help="Batch size") + parser.add_argument("--num_workers", type=int, default=16, help="Number of workers") + parser.add_argument("--accumulate", type=int, default=1, help="batch to accumulate") + parser.add_argument("--prefetch_factor", type=int, default=2, help="Prefetch factor") + parser.add_argument("--num_frames", type=int, default=3, help="Number of frames to extract") + args = parser.parse_args() + + main(args) diff --git a/src/videogen_hub/pipelines/opensora/tools/scoring/matching/__init__.py b/src/videogen_hub/pipelines/opensora/tools/scoring/matching/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora/tools/scoring/matching/inference.py b/src/videogen_hub/pipelines/opensora/tools/scoring/matching/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..7bedef1cfe96f4baa95bce7dee109739aaab9a90 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/scoring/matching/inference.py @@ -0,0 +1,127 @@ +import argparse +import os + +import clip +import colossalai +import numpy as np +import pandas as pd +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch.utils.data import DataLoader, DistributedSampler +from torchvision.datasets.folder import pil_loader +from tqdm import tqdm + +from tools.datasets.utils import extract_frames, is_video + + +class VideoTextDataset(torch.utils.data.Dataset): + def __init__(self, meta_path, transform): + self.meta_path = meta_path + self.meta = pd.read_csv(meta_path) + self.transform = transform + + def __getitem__(self, index): + row = self.meta.iloc[index] + path = row["path"] + + if is_video(path): + img = extract_frames(path, points=[0.5], backend="opencv")[0] + else: + img = pil_loader(path) + + img = self.transform(img) + + text = row["text"] + text = clip.tokenize(text, truncate=True).squeeze() + + return img, text, index + + def __len__(self): + return len(self.meta) + + +def merge_scores(gathered_list: list, meta: pd.DataFrame): + # reorder + indices_list = list(map(lambda x: x[0], gathered_list)) + scores_list = list(map(lambda x: x[1], gathered_list)) + flat_indices = [] + for x in zip(*indices_list): + flat_indices.extend(x) + flat_scores = [] + for x in zip(*scores_list): + flat_scores.extend(x) + flat_indices = np.array(flat_indices) + flat_scores = np.array(flat_scores) + # filter duplicates + unique_indices, unique_indices_idx = np.unique(flat_indices, return_index=True) + meta.loc[unique_indices, "match"] = flat_scores[unique_indices_idx] + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("meta_path", type=str, help="Path to the input CSV file") + parser.add_argument("--bs", type=int, default=16, help="Batch size") + parser.add_argument("--num_workers", type=int, default=16, help="Number of workers") + args = parser.parse_args() + return args + + +def main(): + colossalai.launch_from_torch({}) + args = parse_args() + + meta_path = args.meta_path + wo_ext, ext = os.path.splitext(meta_path) + out_path = f"{wo_ext}_match{ext}" + + # build model + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model, preprocess = clip.load("ViT-L/14", device=device) + logit_scale = model.logit_scale.exp().item() + + # build dataset + dataset = VideoTextDataset(meta_path=meta_path, transform=preprocess) + dataloader = DataLoader( + dataset, + batch_size=args.bs, + num_workers=args.num_workers, + sampler=DistributedSampler( + dataset, + num_replicas=dist.get_world_size(), + rank=dist.get_rank(), + shuffle=False, + drop_last=False, + ), + ) + + # compute scores + dataset.meta["match"] = np.nan + indices_list = [] + scores_list = [] + model.eval() + for imgs, text, indices in tqdm(dataloader, disable=dist.get_rank() != 0): + imgs = imgs.to(device) + text = text.to(device) + + with torch.no_grad(): + feat_img = model.encode_image(imgs) + feat_text = model.encode_text(text) + + feat_img = F.normalize(feat_img, dim=1) + feat_text = F.normalize(feat_text, dim=1) + clip_scores = logit_scale * (feat_img * feat_text).sum(dim=1) + clip_scores = clip_scores.cpu().tolist() + indices_list.extend(indices) + scores_list.extend(clip_scores) + + gathered_list = [None] * dist.get_world_size() + dist.all_gather_object(gathered_list, (indices_list, scores_list)) + if dist.get_rank() == 0: + merge_scores(gathered_list, dataset.meta) + dataset.meta.to_csv(out_path, index=False) + print(f"New meta with matching scores saved to '{out_path}'.") + + +if __name__ == "__main__": + main() diff --git a/src/videogen_hub/pipelines/opensora/tools/scoring/ocr/__init__.py b/src/videogen_hub/pipelines/opensora/tools/scoring/ocr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora/tools/scoring/ocr/dbnetpp.py b/src/videogen_hub/pipelines/opensora/tools/scoring/ocr/dbnetpp.py new file mode 100644 index 0000000000000000000000000000000000000000..e313fd4a5fc9ed8c073dd879a849b078966366f0 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/scoring/ocr/dbnetpp.py @@ -0,0 +1,64 @@ +model = dict( + type='DBNet', + backbone=dict( + type='CLIPResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=False, + style='pytorch', + dcn=dict(type='DCNv2', deform_groups=1, fallback_on_stride=False), + # init_cfg=dict( + # type='Pretrained', + # checkpoint='https://download.openmmlab.com/mmocr/backbone/resnet50-oclip-7ba0c533.pth'), + stage_with_dcn=(False, True, True, True), + ), + neck=dict( + type='FPNC', + in_channels=[256, 512, 1024, 2048], + lateral_channels=256, + asf_cfg=dict(attention_type='ScaleChannelSpatial'), + ), + det_head=dict( + type='DBHead', + in_channels=256, + module_loss=dict(type='DBModuleLoss'), + postprocessor=dict( + type='DBPostprocessor', text_repr_type='quad', + epsilon_ratio=0.002, + ), + ), + data_preprocessor=dict( + type='TextDetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32, + ), + init_cfg=dict( + type='Pretrained', + checkpoint='https://download.openmmlab.com/mmocr/textdet/dbnetpp/' + 'dbnetpp_resnet50-oclip_fpnc_1200e_icdar2015/' + 'dbnetpp_resnet50-oclip_fpnc_1200e_icdar2015_20221101_124139-4ecb39ac.pth', + ) +) + +test_pipeline = [ + # dict(type='LoadImageFromFile', color_type='color_ignore_orientation'), + dict(type='Resize', scale=(4068, 1024), keep_ratio=True), + dict( + type='PackTextDetInputs', + # meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor'), + meta_keys=('img_shape', 'scale_factor'), + ) +] + +# Visualization +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='TextDetLocalVisualizer', + name='visualizer', + vis_backends=vis_backends, +) diff --git a/src/videogen_hub/pipelines/opensora/tools/scoring/ocr/inference.py b/src/videogen_hub/pipelines/opensora/tools/scoring/ocr/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..3b1560534aec2385f0406dbd1c1cdbaf48d27c50 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/scoring/ocr/inference.py @@ -0,0 +1,150 @@ +import argparse +import os + +import numpy as np +import pandas as pd +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torchvision.transforms import Resize, CenterCrop, Compose +from torch.utils.data import DataLoader, DistributedSampler +from torchvision.datasets.folder import pil_loader +from tqdm import tqdm + +import colossalai +from mmengine import Config +from mmengine.registry import DefaultScope +from mmengine.dataset import Compose, default_collate +from mmocr.registry import MODELS +from mmocr.datasets import PackTextDetInputs + +from tools.datasets.utils import extract_frames, is_video + + +def merge_scores(gathered_list: list, meta: pd.DataFrame): + # reorder + indices_list = list(map(lambda x: x[0], gathered_list)) + scores_list = list(map(lambda x: x[1], gathered_list)) + flat_indices = [] + for x in zip(*indices_list): + flat_indices.extend(x) + flat_scores = [] + for x in zip(*scores_list): + flat_scores.extend(x) + flat_indices = np.array(flat_indices) + flat_scores = np.array(flat_scores) + # filter duplicates + unique_indices, unique_indices_idx = np.unique(flat_indices, return_index=True) + meta.loc[unique_indices, "ocr"] = flat_scores[unique_indices_idx] + + +class VideoTextDataset(torch.utils.data.Dataset): + def __init__(self, meta_path, transform): + self.meta_path = meta_path + self.meta = pd.read_csv(meta_path) + self.transform = transform + self.transform = Compose([ + Resize(1024), + CenterCrop(1024), + ]) + self.formatting = PackTextDetInputs(meta_keys=['scale_factor']) + + def __getitem__(self, index): + row = self.meta.iloc[index] + path = row["path"] + + if is_video(path): + img = extract_frames(path, frame_inds=[10], backend="opencv")[0] + else: + img = pil_loader(path) + + img = self.transform(img) + img_array = np.array(img)[:, :, ::-1].copy() # bgr + results = { + 'img': img_array, + 'scale_factor': 1.0, + # 'img_shape': img_array.shape[-2], + # 'ori_shape': img_array.shape[-2], + } + results = self.formatting(results) + results['index'] = index + + return results + + def __len__(self): + return len(self.meta) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("meta_path", type=str, help="Path to the input CSV file") + parser.add_argument("--bs", type=int, default=16, help="Batch size") + parser.add_argument("--num_workers", type=int, default=16, help="Number of workers") + args = parser.parse_args() + + return args + + +def main(): + args = parse_args() + cfg = Config.fromfile('./tools/scoring/ocr/dbnetpp.py') + + meta_path = args.meta_path + + colossalai.launch_from_torch({}) + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + DefaultScope.get_instance('ocr', scope_name='mmocr') # use mmocr Registry as default + + # build model + model = MODELS.build(cfg.model) + model.init_weights() + model.to(device) # set data_preprocessor._device + print('==> Model built.') + + # build dataset + transform = Compose(cfg.test_pipeline) + dataset = VideoTextDataset(meta_path=meta_path, transform=transform) + dataloader = DataLoader( + dataset, + batch_size=args.bs, + num_workers=args.num_workers, + sampler=DistributedSampler( + dataset, + num_replicas=dist.get_world_size(), + rank=dist.get_rank(), + shuffle=False, + drop_last=False, + ), + collate_fn=default_collate, + ) + print('==> Dataloader built.') + + # compute scores + dataset.meta["ocr"] = np.nan + indices_list = [] + scores_list = [] + model.eval() + for data in tqdm(dataloader, disable=dist.get_rank() != 0): + indices_i = data['index'] + indices_list.extend(indices_i.tolist()) + del data['index'] + + pred = model.test_step(data) # this line will cast data to device + + num_texts_i = [(x.pred_instances.scores > 0.3).sum().item() for x in pred] + scores_list.extend(num_texts_i) + + gathered_list = [None] * dist.get_world_size() + dist.all_gather_object(gathered_list, (indices_list, scores_list)) + + if dist.get_rank() == 0: + merge_scores(gathered_list, dataset.meta) + + wo_ext, ext = os.path.splitext(meta_path) + out_path = f"{wo_ext}_ocr{ext}" + dataset.meta.to_csv(out_path, index=False) + print(f"New meta (shape={dataset.meta.shape}) with ocr results saved to '{out_path}'.") + + +if __name__ == '__main__': + main() diff --git a/src/videogen_hub/pipelines/opensora/tools/scoring/optical_flow/__init__.py b/src/videogen_hub/pipelines/opensora/tools/scoring/optical_flow/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora/tools/scoring/optical_flow/inference.py b/src/videogen_hub/pipelines/opensora/tools/scoring/optical_flow/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..170b0766f582f874cf970ff28b3cee216dcea1cb --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/scoring/optical_flow/inference.py @@ -0,0 +1,152 @@ +import argparse +import os + +import colossalai +import numpy as np +import pandas as pd +import torch +import torch.distributed as dist +import torch.nn.functional as F +from einops import rearrange +from torch.utils.data import DataLoader, DistributedSampler +from torchvision.transforms.functional import pil_to_tensor +from tqdm import tqdm + +from tools.datasets.utils import extract_frames + +from .unimatch import UniMatch + + +def merge_scores(gathered_list: list, meta: pd.DataFrame): + # reorder + indices_list = list(map(lambda x: x[0], gathered_list)) + flow_scores_list = list(map(lambda x: x[1], gathered_list)) + flat_indices = [] + for x in zip(*indices_list): + flat_indices.extend(x) + flat_flow_scores = [] + for x in zip(*flow_scores_list): + flat_flow_scores.extend(x) + flat_indices = np.array(flat_indices) + flat_flow_scores = np.array(flat_flow_scores) + # filter duplicates + unique_indices, unique_indices_idx = np.unique(flat_indices, return_index=True) + meta.loc[unique_indices, "flow"] = flat_flow_scores[unique_indices_idx] + + +class VideoTextDataset(torch.utils.data.Dataset): + def __init__(self, meta_path, frame_inds=[0, 10, 20, 30]): + self.meta_path = meta_path + self.meta = pd.read_csv(meta_path) + self.frame_inds = frame_inds + + def __getitem__(self, index): + row = self.meta.iloc[index] + images = extract_frames(row["path"], frame_inds=self.frame_inds, backend="opencv") + + # transform + images = torch.stack([pil_to_tensor(x) for x in images]) # shape: [N, C, H, W]; dtype: torch.uint8 + images = images.float() + H, W = images.shape[-2:] + if H > W: + images = rearrange(images, "N C H W -> N C W H") + images = F.interpolate(images, size=(320, 576), mode="bilinear", align_corners=True) + + return images, index + + def __len__(self): + return len(self.meta) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("meta_path", type=str, help="Path to the input CSV file") + parser.add_argument("--bs", type=int, default=4, help="Batch size") + parser.add_argument("--num_workers", type=int, default=16, help="Number of workers") + args = parser.parse_args() + return args + + +def main(): + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + colossalai.launch_from_torch({}) + args = parse_args() + + meta_path = args.meta_path + wo_ext, ext = os.path.splitext(meta_path) + out_path = f"{wo_ext}_flow{ext}" + + # build model + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = UniMatch( + feature_channels=128, + num_scales=2, + upsample_factor=4, + num_head=1, + ffn_dim_expansion=4, + num_transformer_layers=6, + reg_refine=True, + task="flow", + ).eval() + ckpt = torch.load("./pretrained_models/unimatch/gmflow-scale2-regrefine6-mixdata-train320x576-4e7b215d.pth") + model.load_state_dict(ckpt["model"]) + model = model.to(device) + # model = torch.nn.DataParallel(model) + + # build dataset + dataset = VideoTextDataset(meta_path=meta_path, frame_inds=[0, 10, 20, 30]) + dataloader = DataLoader( + dataset, + batch_size=args.bs, + num_workers=args.num_workers, + sampler=DistributedSampler( + dataset, + num_replicas=dist.get_world_size(), + rank=dist.get_rank(), + shuffle=False, + drop_last=False, + ), + ) + + # compute optical flow scores + dataset.meta["flow"] = np.nan + indices_list = [] + flow_scores_list = [] + for images, indices in tqdm(dataloader, disable=dist.get_rank() != 0): + images = images.to(device) + B = images.shape[0] + + batch_0 = rearrange(images[:, :-1], "B N C H W -> (B N) C H W").contiguous() + batch_1 = rearrange(images[:, 1:], "B N C H W -> (B N) C H W").contiguous() + + with torch.no_grad(): + res = model( + batch_0, + batch_1, + attn_type="swin", + attn_splits_list=[2, 8], + corr_radius_list=[-1, 4], + prop_radius_list=[-1, 1], + num_reg_refine=6, + task="flow", + pred_bidir_flow=False, + ) + flow_maps = res["flow_preds"][-1].cpu() # [B * (N-1), 2, H, W] + flow_maps = rearrange(flow_maps, "(B N) C H W -> B N H W C", B=B) + flow_scores = flow_maps.abs().mean(dim=[1, 2, 3, 4]) + flow_scores = flow_scores.tolist() + + indices_list.extend(indices) + flow_scores_list.extend(flow_scores) + + gathered_list = [None] * dist.get_world_size() + dist.all_gather_object(gathered_list, (indices_list, flow_scores_list)) + if dist.get_rank() == 0: + merge_scores(gathered_list, dataset.meta) + dataset.meta.to_csv(out_path, index=False) + print(f"New meta with optical flow scores saved to '{out_path}'.") + + +if __name__ == "__main__": + main() diff --git a/src/videogen_hub/pipelines/opensora/tools/scoring/optical_flow/unimatch/__init__.py b/src/videogen_hub/pipelines/opensora/tools/scoring/optical_flow/unimatch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c1f4eb2f58e4f32026f301c80331f536918fae7a --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/scoring/optical_flow/unimatch/__init__.py @@ -0,0 +1 @@ +from .unimatch import UniMatch diff --git a/src/videogen_hub/pipelines/opensora/tools/scoring/optical_flow/unimatch/attention.py b/src/videogen_hub/pipelines/opensora/tools/scoring/optical_flow/unimatch/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..23fb9048a07fcbd5228f42de4cca0a0f5ed9b60b --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/scoring/optical_flow/unimatch/attention.py @@ -0,0 +1,280 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .utils import merge_splits, merge_splits_1d, split_feature, split_feature_1d + + +def single_head_full_attention(q, k, v): + # q, k, v: [B, L, C] + assert q.dim() == k.dim() == v.dim() == 3 + + scores = torch.matmul(q, k.permute(0, 2, 1)) / (q.size(2) ** 0.5) # [B, L, L] + attn = torch.softmax(scores, dim=2) # [B, L, L] + out = torch.matmul(attn, v) # [B, L, C] + + return out + + +def single_head_full_attention_1d( + q, + k, + v, + h=None, + w=None, +): + # q, k, v: [B, L, C] + + assert h is not None and w is not None + assert q.size(1) == h * w + + b, _, c = q.size() + + q = q.view(b, h, w, c) # [B, H, W, C] + k = k.view(b, h, w, c) + v = v.view(b, h, w, c) + + scale_factor = c**0.5 + + scores = torch.matmul(q, k.permute(0, 1, 3, 2)) / scale_factor # [B, H, W, W] + + attn = torch.softmax(scores, dim=-1) + + out = torch.matmul(attn, v).view(b, -1, c) # [B, H*W, C] + + return out + + +def single_head_split_window_attention( + q, + k, + v, + num_splits=1, + with_shift=False, + h=None, + w=None, + attn_mask=None, +): + # ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py + # q, k, v: [B, L, C] + assert q.dim() == k.dim() == v.dim() == 3 + + assert h is not None and w is not None + assert q.size(1) == h * w + + b, _, c = q.size() + + b_new = b * num_splits * num_splits + + window_size_h = h // num_splits + window_size_w = w // num_splits + + q = q.view(b, h, w, c) # [B, H, W, C] + k = k.view(b, h, w, c) + v = v.view(b, h, w, c) + + scale_factor = c**0.5 + + if with_shift: + assert attn_mask is not None # compute once + shift_size_h = window_size_h // 2 + shift_size_w = window_size_w // 2 + + q = torch.roll(q, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + k = torch.roll(k, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + v = torch.roll(v, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + + q = split_feature(q, num_splits=num_splits, channel_last=True) # [B*K*K, H/K, W/K, C] + k = split_feature(k, num_splits=num_splits, channel_last=True) + v = split_feature(v, num_splits=num_splits, channel_last=True) + + scores = ( + torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1)) / scale_factor + ) # [B*K*K, H/K*W/K, H/K*W/K] + + if with_shift: + scores += attn_mask.repeat(b, 1, 1) + + attn = torch.softmax(scores, dim=-1) + + out = torch.matmul(attn, v.view(b_new, -1, c)) # [B*K*K, H/K*W/K, C] + + out = merge_splits( + out.view(b_new, h // num_splits, w // num_splits, c), num_splits=num_splits, channel_last=True + ) # [B, H, W, C] + + # shift back + if with_shift: + out = torch.roll(out, shifts=(shift_size_h, shift_size_w), dims=(1, 2)) + + out = out.view(b, -1, c) + + return out + + +def single_head_split_window_attention_1d( + q, + k, + v, + relative_position_bias=None, + num_splits=1, + with_shift=False, + h=None, + w=None, + attn_mask=None, +): + # q, k, v: [B, L, C] + + assert h is not None and w is not None + assert q.size(1) == h * w + + b, _, c = q.size() + + b_new = b * num_splits * h + + window_size_w = w // num_splits + + q = q.view(b * h, w, c) # [B*H, W, C] + k = k.view(b * h, w, c) + v = v.view(b * h, w, c) + + scale_factor = c**0.5 + + if with_shift: + assert attn_mask is not None # compute once + shift_size_w = window_size_w // 2 + + q = torch.roll(q, shifts=-shift_size_w, dims=1) + k = torch.roll(k, shifts=-shift_size_w, dims=1) + v = torch.roll(v, shifts=-shift_size_w, dims=1) + + q = split_feature_1d(q, num_splits=num_splits) # [B*H*K, W/K, C] + k = split_feature_1d(k, num_splits=num_splits) + v = split_feature_1d(v, num_splits=num_splits) + + scores = ( + torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1)) / scale_factor + ) # [B*H*K, W/K, W/K] + + if with_shift: + # attn_mask: [K, W/K, W/K] + scores += attn_mask.repeat(b * h, 1, 1) # [B*H*K, W/K, W/K] + + attn = torch.softmax(scores, dim=-1) + + out = torch.matmul(attn, v.view(b_new, -1, c)) # [B*H*K, W/K, C] + + out = merge_splits_1d(out, h, num_splits=num_splits) # [B, H, W, C] + + # shift back + if with_shift: + out = torch.roll(out, shifts=shift_size_w, dims=2) + + out = out.view(b, -1, c) + + return out + + +class SelfAttnPropagation(nn.Module): + """ + flow propagation with self-attention on feature + query: feature0, key: feature0, value: flow + """ + + def __init__( + self, + in_channels, + **kwargs, + ): + super(SelfAttnPropagation, self).__init__() + + self.q_proj = nn.Linear(in_channels, in_channels) + self.k_proj = nn.Linear(in_channels, in_channels) + + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward( + self, + feature0, + flow, + local_window_attn=False, + local_window_radius=1, + **kwargs, + ): + # q, k: feature [B, C, H, W], v: flow [B, 2, H, W] + if local_window_attn: + return self.forward_local_window_attn(feature0, flow, local_window_radius=local_window_radius) + + b, c, h, w = feature0.size() + + query = feature0.view(b, c, h * w).permute(0, 2, 1) # [B, H*W, C] + + # a note: the ``correct'' implementation should be: + # ``query = self.q_proj(query), key = self.k_proj(query)'' + # this problem is observed while cleaning up the code + # however, this doesn't affect the performance since the projection is a linear operation, + # thus the two projection matrices for key can be merged + # so I just leave it as is in order to not re-train all models :) + query = self.q_proj(query) # [B, H*W, C] + key = self.k_proj(query) # [B, H*W, C] + + value = flow.view(b, flow.size(1), h * w).permute(0, 2, 1) # [B, H*W, 2] + + scores = torch.matmul(query, key.permute(0, 2, 1)) / (c**0.5) # [B, H*W, H*W] + prob = torch.softmax(scores, dim=-1) + + out = torch.matmul(prob, value) # [B, H*W, 2] + out = out.view(b, h, w, value.size(-1)).permute(0, 3, 1, 2) # [B, 2, H, W] + + return out + + def forward_local_window_attn( + self, + feature0, + flow, + local_window_radius=1, + ): + assert flow.size(1) == 2 or flow.size(1) == 1 # flow or disparity or depth + assert local_window_radius > 0 + + b, c, h, w = feature0.size() + + value_channel = flow.size(1) + + feature0_reshape = self.q_proj(feature0.view(b, c, -1).permute(0, 2, 1)).reshape( + b * h * w, 1, c + ) # [B*H*W, 1, C] + + kernel_size = 2 * local_window_radius + 1 + + feature0_proj = self.k_proj(feature0.view(b, c, -1).permute(0, 2, 1)).permute(0, 2, 1).reshape(b, c, h, w) + + feature0_window = F.unfold( + feature0_proj, kernel_size=kernel_size, padding=local_window_radius + ) # [B, C*(2R+1)^2), H*W] + + feature0_window = ( + feature0_window.view(b, c, kernel_size**2, h, w) + .permute(0, 3, 4, 1, 2) + .reshape(b * h * w, c, kernel_size**2) + ) # [B*H*W, C, (2R+1)^2] + + flow_window = F.unfold(flow, kernel_size=kernel_size, padding=local_window_radius) # [B, 2*(2R+1)^2), H*W] + + flow_window = ( + flow_window.view(b, value_channel, kernel_size**2, h, w) + .permute(0, 3, 4, 2, 1) + .reshape(b * h * w, kernel_size**2, value_channel) + ) # [B*H*W, (2R+1)^2, 2] + + scores = torch.matmul(feature0_reshape, feature0_window) / (c**0.5) # [B*H*W, 1, (2R+1)^2] + + prob = torch.softmax(scores, dim=-1) + + out = ( + torch.matmul(prob, flow_window).view(b, h, w, value_channel).permute(0, 3, 1, 2).contiguous() + ) # [B, 2, H, W] + + return out diff --git a/src/videogen_hub/pipelines/opensora/tools/scoring/optical_flow/unimatch/backbone.py b/src/videogen_hub/pipelines/opensora/tools/scoring/optical_flow/unimatch/backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..5c2cc19f7dae5013da0c6a22d50e4bfabfed8ee6 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/scoring/optical_flow/unimatch/backbone.py @@ -0,0 +1,128 @@ +import torch.nn as nn + +from .trident_conv import MultiScaleTridentConv + + +class ResidualBlock(nn.Module): + def __init__( + self, + in_planes, + planes, + norm_layer=nn.InstanceNorm2d, + stride=1, + dilation=1, + ): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d( + in_planes, planes, kernel_size=3, dilation=dilation, padding=dilation, stride=stride, bias=False + ) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, dilation=dilation, padding=dilation, bias=False) + self.relu = nn.ReLU(inplace=True) + + self.norm1 = norm_layer(planes) + self.norm2 = norm_layer(planes) + if not stride == 1 or in_planes != planes: + self.norm3 = norm_layer(planes) + + if stride == 1 and in_planes == planes: + self.downsample = None + else: + self.downsample = nn.Sequential(nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class CNNEncoder(nn.Module): + def __init__( + self, + output_dim=128, + norm_layer=nn.InstanceNorm2d, + num_output_scales=1, + **kwargs, + ): + super(CNNEncoder, self).__init__() + self.num_branch = num_output_scales + + feature_dims = [64, 96, 128] + + self.conv1 = nn.Conv2d(3, feature_dims[0], kernel_size=7, stride=2, padding=3, bias=False) # 1/2 + self.norm1 = norm_layer(feature_dims[0]) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = feature_dims[0] + self.layer1 = self._make_layer(feature_dims[0], stride=1, norm_layer=norm_layer) # 1/2 + self.layer2 = self._make_layer(feature_dims[1], stride=2, norm_layer=norm_layer) # 1/4 + + # highest resolution 1/4 or 1/8 + stride = 2 if num_output_scales == 1 else 1 + self.layer3 = self._make_layer( + feature_dims[2], + stride=stride, + norm_layer=norm_layer, + ) # 1/4 or 1/8 + + self.conv2 = nn.Conv2d(feature_dims[2], output_dim, 1, 1, 0) + + if self.num_branch > 1: + if self.num_branch == 4: + strides = (1, 2, 4, 8) + elif self.num_branch == 3: + strides = (1, 2, 4) + elif self.num_branch == 2: + strides = (1, 2) + else: + raise ValueError + + self.trident_conv = MultiScaleTridentConv( + output_dim, + output_dim, + kernel_size=3, + strides=strides, + paddings=1, + num_branch=self.num_branch, + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d): + layer1 = ResidualBlock(self.in_planes, dim, norm_layer=norm_layer, stride=stride, dilation=dilation) + layer2 = ResidualBlock(dim, dim, norm_layer=norm_layer, stride=1, dilation=dilation) + + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) # 1/2 + x = self.layer2(x) # 1/4 + x = self.layer3(x) # 1/8 or 1/4 + + x = self.conv2(x) + + if self.num_branch > 1: + out = self.trident_conv([x] * self.num_branch) # high to low res + else: + out = [x] + + return out diff --git a/src/videogen_hub/pipelines/opensora/tools/scoring/optical_flow/unimatch/geometry.py b/src/videogen_hub/pipelines/opensora/tools/scoring/optical_flow/unimatch/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..df4d8e38d8afabe7f4e8a69724c75427dec9bd2b --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/scoring/optical_flow/unimatch/geometry.py @@ -0,0 +1,200 @@ +import torch +import torch.nn.functional as F + + +def coords_grid(b, h, w, homogeneous=False, device=None): + y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [H, W] + + stacks = [x, y] + + if homogeneous: + ones = torch.ones_like(x) # [H, W] + stacks.append(ones) + + grid = torch.stack(stacks, dim=0).float() # [2, H, W] or [3, H, W] + + grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W] + + if device is not None: + grid = grid.to(device) + + return grid + + +def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None): + assert device is not None + + x, y = torch.meshgrid( + [torch.linspace(w_min, w_max, len_w, device=device), torch.linspace(h_min, h_max, len_h, device=device)], + ) + grid = torch.stack((x, y), -1).transpose(0, 1).float() # [H, W, 2] + + return grid + + +def normalize_coords(coords, h, w): + # coords: [B, H, W, 2] + c = torch.Tensor([(w - 1) / 2.0, (h - 1) / 2.0]).float().to(coords.device) + return (coords - c) / c # [-1, 1] + + +def bilinear_sample(img, sample_coords, mode="bilinear", padding_mode="zeros", return_mask=False): + # img: [B, C, H, W] + # sample_coords: [B, 2, H, W] in image scale + if sample_coords.size(1) != 2: # [B, H, W, 2] + sample_coords = sample_coords.permute(0, 3, 1, 2) + + b, _, h, w = sample_coords.shape + + # Normalize to [-1, 1] + x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1 + y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1 + + grid = torch.stack([x_grid, y_grid], dim=-1) # [B, H, W, 2] + + img = F.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=True) + + if return_mask: + mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1) # [B, H, W] + + return img, mask + + return img + + +def flow_warp(feature, flow, mask=False, padding_mode="zeros"): + b, c, h, w = feature.size() + assert flow.size(1) == 2 + + grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W] + + return bilinear_sample(feature, grid, padding_mode=padding_mode, return_mask=mask) + + +def forward_backward_consistency_check(fwd_flow, bwd_flow, alpha=0.01, beta=0.5): + # fwd_flow, bwd_flow: [B, 2, H, W] + # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837) + assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4 + assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2 + flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W] + + warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W] + warped_fwd_flow = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W] + + diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W] + diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1) + + threshold = alpha * flow_mag + beta + + fwd_occ = (diff_fwd > threshold).float() # [B, H, W] + bwd_occ = (diff_bwd > threshold).float() + + return fwd_occ, bwd_occ + + +def back_project(depth, intrinsics): + # Back project 2D pixel coords to 3D points + # depth: [B, H, W] + # intrinsics: [B, 3, 3] + b, h, w = depth.shape + grid = coords_grid(b, h, w, homogeneous=True, device=depth.device) # [B, 3, H, W] + + intrinsics_inv = torch.inverse(intrinsics) # [B, 3, 3] + + points = intrinsics_inv.bmm(grid.view(b, 3, -1)).view(b, 3, h, w) * depth.unsqueeze(1) # [B, 3, H, W] + + return points + + +def camera_transform(points_ref, extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None): + # Transform 3D points from reference camera to target camera + # points_ref: [B, 3, H, W] + # extrinsics_ref: [B, 4, 4] + # extrinsics_tgt: [B, 4, 4] + # extrinsics_rel: [B, 4, 4], relative pose transform + b, _, h, w = points_ref.shape + + if extrinsics_rel is None: + extrinsics_rel = torch.bmm(extrinsics_tgt, torch.inverse(extrinsics_ref)) # [B, 4, 4] + + points_tgt = ( + torch.bmm(extrinsics_rel[:, :3, :3], points_ref.view(b, 3, -1)) + extrinsics_rel[:, :3, -1:] + ) # [B, 3, H*W] + + points_tgt = points_tgt.view(b, 3, h, w) # [B, 3, H, W] + + return points_tgt + + +def reproject(points_tgt, intrinsics, return_mask=False): + # reproject to target view + # points_tgt: [B, 3, H, W] + # intrinsics: [B, 3, 3] + + b, _, h, w = points_tgt.shape + + proj_points = torch.bmm(intrinsics, points_tgt.view(b, 3, -1)).view(b, 3, h, w) # [B, 3, H, W] + + X = proj_points[:, 0] + Y = proj_points[:, 1] + Z = proj_points[:, 2].clamp(min=1e-3) + + pixel_coords = torch.stack([X / Z, Y / Z], dim=1).view(b, 2, h, w) # [B, 2, H, W] in image scale + + if return_mask: + # valid mask in pixel space + mask = ( + (pixel_coords[:, 0] >= 0) + & (pixel_coords[:, 0] <= (w - 1)) + & (pixel_coords[:, 1] >= 0) + & (pixel_coords[:, 1] <= (h - 1)) + ) # [B, H, W] + + return pixel_coords, mask + + return pixel_coords + + +def reproject_coords( + depth_ref, intrinsics, extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None, return_mask=False +): + # Compute reprojection sample coords + points_ref = back_project(depth_ref, intrinsics) # [B, 3, H, W] + points_tgt = camera_transform(points_ref, extrinsics_ref, extrinsics_tgt, extrinsics_rel=extrinsics_rel) + + if return_mask: + reproj_coords, mask = reproject(points_tgt, intrinsics, return_mask=return_mask) # [B, 2, H, W] in image scale + + return reproj_coords, mask + + reproj_coords = reproject(points_tgt, intrinsics, return_mask=return_mask) # [B, 2, H, W] in image scale + + return reproj_coords + + +def compute_flow_with_depth_pose( + depth_ref, intrinsics, extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None, return_mask=False +): + b, h, w = depth_ref.shape + coords_init = coords_grid(b, h, w, device=depth_ref.device) # [B, 2, H, W] + + if return_mask: + reproj_coords, mask = reproject_coords( + depth_ref, + intrinsics, + extrinsics_ref, + extrinsics_tgt, + extrinsics_rel=extrinsics_rel, + return_mask=return_mask, + ) # [B, 2, H, W] + rigid_flow = reproj_coords - coords_init + + return rigid_flow, mask + + reproj_coords = reproject_coords( + depth_ref, intrinsics, extrinsics_ref, extrinsics_tgt, extrinsics_rel=extrinsics_rel, return_mask=return_mask + ) # [B, 2, H, W] + + rigid_flow = reproj_coords - coords_init + + return rigid_flow diff --git a/src/videogen_hub/pipelines/opensora/tools/scoring/optical_flow/unimatch/matching.py b/src/videogen_hub/pipelines/opensora/tools/scoring/optical_flow/unimatch/matching.py new file mode 100644 index 0000000000000000000000000000000000000000..fe5e103d742b16edff87835a1cd4db45e15775ad --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/scoring/optical_flow/unimatch/matching.py @@ -0,0 +1,307 @@ +import torch +import torch.nn.functional as F + +from .geometry import coords_grid, generate_window_grid, normalize_coords + + +def global_correlation_softmax( + feature0, + feature1, + pred_bidir_flow=False, +): + # global correlation + b, c, h, w = feature0.shape + feature0 = feature0.view(b, c, -1).permute(0, 2, 1) # [B, H*W, C] + feature1 = feature1.view(b, c, -1) # [B, C, H*W] + + correlation = torch.matmul(feature0, feature1).view(b, h, w, h, w) / (c**0.5) # [B, H, W, H, W] + + # flow from softmax + init_grid = coords_grid(b, h, w).to(correlation.device) # [B, 2, H, W] + grid = init_grid.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2] + + correlation = correlation.view(b, h * w, h * w) # [B, H*W, H*W] + + if pred_bidir_flow: + correlation = torch.cat((correlation, correlation.permute(0, 2, 1)), dim=0) # [2*B, H*W, H*W] + init_grid = init_grid.repeat(2, 1, 1, 1) # [2*B, 2, H, W] + grid = grid.repeat(2, 1, 1) # [2*B, H*W, 2] + b = b * 2 + + prob = F.softmax(correlation, dim=-1) # [B, H*W, H*W] + + correspondence = torch.matmul(prob, grid).view(b, h, w, 2).permute(0, 3, 1, 2) # [B, 2, H, W] + + # when predicting bidirectional flow, flow is the concatenation of forward flow and backward flow + flow = correspondence - init_grid + + return flow, prob + + +def local_correlation_softmax( + feature0, + feature1, + local_radius, + padding_mode="zeros", +): + b, c, h, w = feature0.size() + coords_init = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W] + coords = coords_init.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2] + + local_h = 2 * local_radius + 1 + local_w = 2 * local_radius + 1 + + window_grid = generate_window_grid( + -local_radius, local_radius, -local_radius, local_radius, local_h, local_w, device=feature0.device + ) # [2R+1, 2R+1, 2] + window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1) # [B, 1, (2R+1)^2, 2] + sample_coords = coords.unsqueeze(-2) + window_grid # [B, H*W, (2R+1)^2, 2] + + sample_coords_softmax = sample_coords + + # exclude coords that are out of image space + valid_x = (sample_coords[:, :, :, 0] >= 0) & (sample_coords[:, :, :, 0] < w) # [B, H*W, (2R+1)^2] + valid_y = (sample_coords[:, :, :, 1] >= 0) & (sample_coords[:, :, :, 1] < h) # [B, H*W, (2R+1)^2] + + valid = valid_x & valid_y # [B, H*W, (2R+1)^2], used to mask out invalid values when softmax + + # normalize coordinates to [-1, 1] + sample_coords_norm = normalize_coords(sample_coords, h, w) # [-1, 1] + window_feature = F.grid_sample(feature1, sample_coords_norm, padding_mode=padding_mode, align_corners=True).permute( + 0, 2, 1, 3 + ) # [B, H*W, C, (2R+1)^2] + feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c) # [B, H*W, 1, C] + + corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (c**0.5) # [B, H*W, (2R+1)^2] + + # mask invalid locations + corr[~valid] = -1e9 + + prob = F.softmax(corr, -1) # [B, H*W, (2R+1)^2] + + correspondence = ( + torch.matmul(prob.unsqueeze(-2), sample_coords_softmax).squeeze(-2).view(b, h, w, 2).permute(0, 3, 1, 2) + ) # [B, 2, H, W] + + flow = correspondence - coords_init + match_prob = prob + + return flow, match_prob + + +def local_correlation_with_flow( + feature0, + feature1, + flow, + local_radius, + padding_mode="zeros", + dilation=1, +): + b, c, h, w = feature0.size() + coords_init = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W] + coords = coords_init.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2] + + local_h = 2 * local_radius + 1 + local_w = 2 * local_radius + 1 + + window_grid = generate_window_grid( + -local_radius, local_radius, -local_radius, local_radius, local_h, local_w, device=feature0.device + ) # [2R+1, 2R+1, 2] + window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1) # [B, 1, (2R+1)^2, 2] + sample_coords = coords.unsqueeze(-2) + window_grid * dilation # [B, H*W, (2R+1)^2, 2] + + # flow can be zero when using features after transformer + if not isinstance(flow, float): + sample_coords = sample_coords + flow.view(b, 2, -1).permute(0, 2, 1).unsqueeze(-2) # [B, H*W, (2R+1)^2, 2] + else: + assert flow == 0.0 + + # normalize coordinates to [-1, 1] + sample_coords_norm = normalize_coords(sample_coords, h, w) # [-1, 1] + window_feature = F.grid_sample(feature1, sample_coords_norm, padding_mode=padding_mode, align_corners=True).permute( + 0, 2, 1, 3 + ) # [B, H*W, C, (2R+1)^2] + feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c) # [B, H*W, 1, C] + + corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (c**0.5) # [B, H*W, (2R+1)^2] + + corr = corr.view(b, h, w, -1).permute(0, 3, 1, 2).contiguous() # [B, (2R+1)^2, H, W] + + return corr + + +def global_correlation_softmax_stereo( + feature0, + feature1, +): + # global correlation on horizontal direction + b, c, h, w = feature0.shape + + x_grid = torch.linspace(0, w - 1, w, device=feature0.device) # [W] + + feature0 = feature0.permute(0, 2, 3, 1) # [B, H, W, C] + feature1 = feature1.permute(0, 2, 1, 3) # [B, H, C, W] + + correlation = torch.matmul(feature0, feature1) / (c**0.5) # [B, H, W, W] + + # mask subsequent positions to make disparity positive + mask = torch.triu(torch.ones((w, w)), diagonal=1).type_as(feature0) # [W, W] + valid_mask = (mask == 0).unsqueeze(0).unsqueeze(0).repeat(b, h, 1, 1) # [B, H, W, W] + + correlation[~valid_mask] = -1e9 + + prob = F.softmax(correlation, dim=-1) # [B, H, W, W] + + correspondence = (x_grid.view(1, 1, 1, w) * prob).sum(-1) # [B, H, W] + + # NOTE: unlike flow, disparity is typically positive + disparity = x_grid.view(1, 1, w).repeat(b, h, 1) - correspondence # [B, H, W] + + return disparity.unsqueeze(1), prob # feature resolution + + +def local_correlation_softmax_stereo( + feature0, + feature1, + local_radius, +): + b, c, h, w = feature0.size() + coords_init = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W] + coords = coords_init.view(b, 2, -1).permute(0, 2, 1).contiguous() # [B, H*W, 2] + + local_h = 1 + local_w = 2 * local_radius + 1 + + window_grid = generate_window_grid( + 0, 0, -local_radius, local_radius, local_h, local_w, device=feature0.device + ) # [1, 2R+1, 2] + window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1) # [B, 1, (2R+1), 2] + sample_coords = coords.unsqueeze(-2) + window_grid # [B, H*W, (2R+1), 2] + + sample_coords_softmax = sample_coords + + # exclude coords that are out of image space + valid_x = (sample_coords[:, :, :, 0] >= 0) & (sample_coords[:, :, :, 0] < w) # [B, H*W, (2R+1)^2] + valid_y = (sample_coords[:, :, :, 1] >= 0) & (sample_coords[:, :, :, 1] < h) # [B, H*W, (2R+1)^2] + + valid = valid_x & valid_y # [B, H*W, (2R+1)^2], used to mask out invalid values when softmax + + # normalize coordinates to [-1, 1] + sample_coords_norm = normalize_coords(sample_coords, h, w) # [-1, 1] + window_feature = F.grid_sample(feature1, sample_coords_norm, padding_mode="zeros", align_corners=True).permute( + 0, 2, 1, 3 + ) # [B, H*W, C, (2R+1)] + feature0_view = feature0.permute(0, 2, 3, 1).contiguous().view(b, h * w, 1, c) # [B, H*W, 1, C] + + corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (c**0.5) # [B, H*W, (2R+1)] + + # mask invalid locations + corr[~valid] = -1e9 + + prob = F.softmax(corr, -1) # [B, H*W, (2R+1)] + + correspondence = ( + torch.matmul(prob.unsqueeze(-2), sample_coords_softmax) + .squeeze(-2) + .view(b, h, w, 2) + .permute(0, 3, 1, 2) + .contiguous() + ) # [B, 2, H, W] + + flow = correspondence - coords_init # flow at feature resolution + match_prob = prob + + flow_x = -flow[:, :1] # [B, 1, H, W] + + return flow_x, match_prob + + +def correlation_softmax_depth( + feature0, + feature1, + intrinsics, + pose, + depth_candidates, + depth_from_argmax=False, + pred_bidir_depth=False, +): + b, c, h, w = feature0.size() + assert depth_candidates.dim() == 4 # [B, D, H, W] + scale_factor = c**0.5 + + if pred_bidir_depth: + feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat((feature1, feature0), dim=0) + intrinsics = intrinsics.repeat(2, 1, 1) + pose = torch.cat((pose, torch.inverse(pose)), dim=0) + depth_candidates = depth_candidates.repeat(2, 1, 1, 1) + + # depth candidates are actually inverse depth + warped_feature1 = warp_with_pose_depth_candidates( + feature1, + intrinsics, + pose, + 1.0 / depth_candidates, + ) # [B, C, D, H, W] + + correlation = (feature0.unsqueeze(2) * warped_feature1).sum(1) / scale_factor # [B, D, H, W] + + match_prob = F.softmax(correlation, dim=1) # [B, D, H, W] + + # for cross-task transfer (flow -> depth), extract depth with argmax at test time + if depth_from_argmax: + index = torch.argmax(match_prob, dim=1, keepdim=True) + depth = torch.gather(depth_candidates, dim=1, index=index) + else: + depth = (match_prob * depth_candidates).sum(dim=1, keepdim=True) # [B, 1, H, W] + + return depth, match_prob + + +def warp_with_pose_depth_candidates( + feature1, + intrinsics, + pose, + depth, + clamp_min_depth=1e-3, +): + """ + feature1: [B, C, H, W] + intrinsics: [B, 3, 3] + pose: [B, 4, 4] + depth: [B, D, H, W] + """ + + assert intrinsics.size(1) == intrinsics.size(2) == 3 + assert pose.size(1) == pose.size(2) == 4 + assert depth.dim() == 4 + + b, d, h, w = depth.size() + c = feature1.size(1) + + with torch.no_grad(): + # pixel coordinates + grid = coords_grid(b, h, w, homogeneous=True, device=depth.device) # [B, 3, H, W] + # back project to 3D and transform viewpoint + points = torch.inverse(intrinsics).bmm(grid.view(b, 3, -1)) # [B, 3, H*W] + points = torch.bmm(pose[:, :3, :3], points).unsqueeze(2).repeat(1, 1, d, 1) * depth.view( + b, 1, d, h * w + ) # [B, 3, D, H*W] + points = points + pose[:, :3, -1:].unsqueeze(-1) # [B, 3, D, H*W] + # reproject to 2D image plane + points = torch.bmm(intrinsics, points.view(b, 3, -1)).view(b, 3, d, h * w) # [B, 3, D, H*W] + pixel_coords = points[:, :2] / points[:, -1:].clamp(min=clamp_min_depth) # [B, 2, D, H*W] + + # normalize to [-1, 1] + x_grid = 2 * pixel_coords[:, 0] / (w - 1) - 1 + y_grid = 2 * pixel_coords[:, 1] / (h - 1) - 1 + + grid = torch.stack([x_grid, y_grid], dim=-1) # [B, D, H*W, 2] + + # sample features + warped_feature = F.grid_sample( + feature1, grid.view(b, d * h, w, 2), mode="bilinear", padding_mode="zeros", align_corners=True + ).view( + b, c, d, h, w + ) # [B, C, D, H, W] + + return warped_feature diff --git a/src/videogen_hub/pipelines/opensora/tools/scoring/optical_flow/unimatch/position.py b/src/videogen_hub/pipelines/opensora/tools/scoring/optical_flow/unimatch/position.py new file mode 100644 index 0000000000000000000000000000000000000000..619f3568d4c81f41316010be6a866a0e115cfc80 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/scoring/optical_flow/unimatch/position.py @@ -0,0 +1,47 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# https://github.com/facebookresearch/detr/blob/main/models/position_encoding.py + +import math + +import torch +import torch.nn as nn + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + + def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, x): + # x = tensor_list.tensors # [B, C, H, W] + # mask = tensor_list.mask # [B, H, W], input with padding, valid as 0 + b, c, h, w = x.size() + mask = torch.ones((b, h, w), device=x.device) # [B, H, W] + y_embed = mask.cumsum(1, dtype=torch.float32) + x_embed = mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos diff --git a/src/videogen_hub/pipelines/opensora/tools/scoring/optical_flow/unimatch/reg_refine.py b/src/videogen_hub/pipelines/opensora/tools/scoring/optical_flow/unimatch/reg_refine.py new file mode 100644 index 0000000000000000000000000000000000000000..965f4cac62a8db3b42187b9cdbc2f679a70e6ac3 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/scoring/optical_flow/unimatch/reg_refine.py @@ -0,0 +1,133 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class FlowHead(nn.Module): + def __init__( + self, + input_dim=128, + hidden_dim=256, + out_dim=2, + ): + super(FlowHead, self).__init__() + + self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) + self.conv2 = nn.Conv2d(hidden_dim, out_dim, 3, padding=1) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + out = self.conv2(self.relu(self.conv1(x))) + + return out + + +class SepConvGRU(nn.Module): + def __init__( + self, + hidden_dim=128, + input_dim=192 + 128, + kernel_size=5, + ): + padding = (kernel_size - 1) // 2 + + super(SepConvGRU, self).__init__() + self.convz1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding)) + self.convr1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding)) + self.convq1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding)) + + self.convz2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0)) + self.convr2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0)) + self.convq2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0)) + + def forward(self, h, x): + # horizontal + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz1(hx)) + r = torch.sigmoid(self.convr1(hx)) + q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1))) + h = (1 - z) * h + z * q + + # vertical + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz2(hx)) + r = torch.sigmoid(self.convr2(hx)) + q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1))) + h = (1 - z) * h + z * q + + return h + + +class BasicMotionEncoder(nn.Module): + def __init__( + self, + corr_channels=324, + flow_channels=2, + ): + super(BasicMotionEncoder, self).__init__() + + self.convc1 = nn.Conv2d(corr_channels, 256, 1, padding=0) + self.convc2 = nn.Conv2d(256, 192, 3, padding=1) + self.convf1 = nn.Conv2d(flow_channels, 128, 7, padding=3) + self.convf2 = nn.Conv2d(128, 64, 3, padding=1) + self.conv = nn.Conv2d(64 + 192, 128 - flow_channels, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + cor = F.relu(self.convc2(cor)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + + +class BasicUpdateBlock(nn.Module): + def __init__( + self, + corr_channels=324, + hidden_dim=128, + context_dim=128, + downsample_factor=8, + flow_dim=2, + bilinear_up=False, + ): + super(BasicUpdateBlock, self).__init__() + + self.encoder = BasicMotionEncoder( + corr_channels=corr_channels, + flow_channels=flow_dim, + ) + + self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=context_dim + hidden_dim) + + self.flow_head = FlowHead( + hidden_dim, + hidden_dim=256, + out_dim=flow_dim, + ) + + if bilinear_up: + self.mask = None + else: + self.mask = nn.Sequential( + nn.Conv2d(hidden_dim, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, downsample_factor**2 * 9, 1, padding=0), + ) + + def forward(self, net, inp, corr, flow): + motion_features = self.encoder(flow, corr) + + inp = torch.cat([inp, motion_features], dim=1) + + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + if self.mask is not None: + mask = self.mask(net) + else: + mask = None + + return net, mask, delta_flow diff --git a/src/videogen_hub/pipelines/opensora/tools/scoring/optical_flow/unimatch/transformer.py b/src/videogen_hub/pipelines/opensora/tools/scoring/optical_flow/unimatch/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..7fdffd17feb0328260f1a93b778801337d14a2c3 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/scoring/optical_flow/unimatch/transformer.py @@ -0,0 +1,339 @@ +import torch +import torch.nn as nn + +from .attention import ( + single_head_full_attention, + single_head_full_attention_1d, + single_head_split_window_attention, + single_head_split_window_attention_1d, +) +from .utils import generate_shift_window_attn_mask, generate_shift_window_attn_mask_1d + + +class TransformerLayer(nn.Module): + def __init__( + self, + d_model=128, + nhead=1, + no_ffn=False, + ffn_dim_expansion=4, + ): + super(TransformerLayer, self).__init__() + + self.dim = d_model + self.nhead = nhead + self.no_ffn = no_ffn + + # multi-head attention + self.q_proj = nn.Linear(d_model, d_model, bias=False) + self.k_proj = nn.Linear(d_model, d_model, bias=False) + self.v_proj = nn.Linear(d_model, d_model, bias=False) + + self.merge = nn.Linear(d_model, d_model, bias=False) + + self.norm1 = nn.LayerNorm(d_model) + + # no ffn after self-attn, with ffn after cross-attn + if not self.no_ffn: + in_channels = d_model * 2 + self.mlp = nn.Sequential( + nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False), + nn.GELU(), + nn.Linear(in_channels * ffn_dim_expansion, d_model, bias=False), + ) + + self.norm2 = nn.LayerNorm(d_model) + + def forward( + self, + source, + target, + height=None, + width=None, + shifted_window_attn_mask=None, + shifted_window_attn_mask_1d=None, + attn_type="swin", + with_shift=False, + attn_num_splits=None, + ): + # source, target: [B, L, C] + query, key, value = source, target, target + + # for stereo: 2d attn in self-attn, 1d attn in cross-attn + is_self_attn = (query - key).abs().max() < 1e-6 + + # single-head attention + query = self.q_proj(query) # [B, L, C] + key = self.k_proj(key) # [B, L, C] + value = self.v_proj(value) # [B, L, C] + + if attn_type == "swin" and attn_num_splits > 1: # self, cross-attn: both swin 2d + if self.nhead > 1: + # we observe that multihead attention slows down the speed and increases the memory consumption + # without bringing obvious performance gains and thus the implementation is removed + raise NotImplementedError + else: + message = single_head_split_window_attention( + query, + key, + value, + num_splits=attn_num_splits, + with_shift=with_shift, + h=height, + w=width, + attn_mask=shifted_window_attn_mask, + ) + + elif attn_type == "self_swin2d_cross_1d": # self-attn: swin 2d, cross-attn: full 1d + if self.nhead > 1: + raise NotImplementedError + else: + if is_self_attn: + if attn_num_splits > 1: + message = single_head_split_window_attention( + query, + key, + value, + num_splits=attn_num_splits, + with_shift=with_shift, + h=height, + w=width, + attn_mask=shifted_window_attn_mask, + ) + else: + # full 2d attn + message = single_head_full_attention(query, key, value) # [N, L, C] + + else: + # cross attn 1d + message = single_head_full_attention_1d( + query, + key, + value, + h=height, + w=width, + ) + + elif attn_type == "self_swin2d_cross_swin1d": # self-attn: swin 2d, cross-attn: swin 1d + if self.nhead > 1: + raise NotImplementedError + else: + if is_self_attn: + if attn_num_splits > 1: + # self attn shift window + message = single_head_split_window_attention( + query, + key, + value, + num_splits=attn_num_splits, + with_shift=with_shift, + h=height, + w=width, + attn_mask=shifted_window_attn_mask, + ) + else: + # full 2d attn + message = single_head_full_attention(query, key, value) # [N, L, C] + else: + if attn_num_splits > 1: + assert shifted_window_attn_mask_1d is not None + # cross attn 1d shift + message = single_head_split_window_attention_1d( + query, + key, + value, + num_splits=attn_num_splits, + with_shift=with_shift, + h=height, + w=width, + attn_mask=shifted_window_attn_mask_1d, + ) + else: + message = single_head_full_attention_1d( + query, + key, + value, + h=height, + w=width, + ) + + else: + message = single_head_full_attention(query, key, value) # [B, L, C] + + message = self.merge(message) # [B, L, C] + message = self.norm1(message) + + if not self.no_ffn: + message = self.mlp(torch.cat([source, message], dim=-1)) + message = self.norm2(message) + + return source + message + + +class TransformerBlock(nn.Module): + """self attention + cross attention + FFN""" + + def __init__( + self, + d_model=128, + nhead=1, + ffn_dim_expansion=4, + ): + super(TransformerBlock, self).__init__() + + self.self_attn = TransformerLayer( + d_model=d_model, + nhead=nhead, + no_ffn=True, + ffn_dim_expansion=ffn_dim_expansion, + ) + + self.cross_attn_ffn = TransformerLayer( + d_model=d_model, + nhead=nhead, + ffn_dim_expansion=ffn_dim_expansion, + ) + + def forward( + self, + source, + target, + height=None, + width=None, + shifted_window_attn_mask=None, + shifted_window_attn_mask_1d=None, + attn_type="swin", + with_shift=False, + attn_num_splits=None, + ): + # source, target: [B, L, C] + + # self attention + source = self.self_attn( + source, + source, + height=height, + width=width, + shifted_window_attn_mask=shifted_window_attn_mask, + attn_type=attn_type, + with_shift=with_shift, + attn_num_splits=attn_num_splits, + ) + + # cross attention and ffn + source = self.cross_attn_ffn( + source, + target, + height=height, + width=width, + shifted_window_attn_mask=shifted_window_attn_mask, + shifted_window_attn_mask_1d=shifted_window_attn_mask_1d, + attn_type=attn_type, + with_shift=with_shift, + attn_num_splits=attn_num_splits, + ) + + return source + + +class FeatureTransformer(nn.Module): + def __init__( + self, + num_layers=6, + d_model=128, + nhead=1, + ffn_dim_expansion=4, + ): + super(FeatureTransformer, self).__init__() + + self.d_model = d_model + self.nhead = nhead + + self.layers = nn.ModuleList( + [ + TransformerBlock( + d_model=d_model, + nhead=nhead, + ffn_dim_expansion=ffn_dim_expansion, + ) + for i in range(num_layers) + ] + ) + + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward( + self, + feature0, + feature1, + attn_type="swin", + attn_num_splits=None, + **kwargs, + ): + b, c, h, w = feature0.shape + assert self.d_model == c + + feature0 = feature0.flatten(-2).permute(0, 2, 1) # [B, H*W, C] + feature1 = feature1.flatten(-2).permute(0, 2, 1) # [B, H*W, C] + + # 2d attention + if "swin" in attn_type and attn_num_splits > 1: + # global and refine use different number of splits + window_size_h = h // attn_num_splits + window_size_w = w // attn_num_splits + + # compute attn mask once + shifted_window_attn_mask = generate_shift_window_attn_mask( + input_resolution=(h, w), + window_size_h=window_size_h, + window_size_w=window_size_w, + shift_size_h=window_size_h // 2, + shift_size_w=window_size_w // 2, + device=feature0.device, + ) # [K*K, H/K*W/K, H/K*W/K] + else: + shifted_window_attn_mask = None + + # 1d attention + if "swin1d" in attn_type and attn_num_splits > 1: + window_size_w = w // attn_num_splits + + # compute attn mask once + shifted_window_attn_mask_1d = generate_shift_window_attn_mask_1d( + input_w=w, + window_size_w=window_size_w, + shift_size_w=window_size_w // 2, + device=feature0.device, + ) # [K, W/K, W/K] + else: + shifted_window_attn_mask_1d = None + + # concat feature0 and feature1 in batch dimension to compute in parallel + concat0 = torch.cat((feature0, feature1), dim=0) # [2B, H*W, C] + concat1 = torch.cat((feature1, feature0), dim=0) # [2B, H*W, C] + + for i, layer in enumerate(self.layers): + concat0 = layer( + concat0, + concat1, + height=h, + width=w, + attn_type=attn_type, + with_shift="swin" in attn_type and attn_num_splits > 1 and i % 2 == 1, + attn_num_splits=attn_num_splits, + shifted_window_attn_mask=shifted_window_attn_mask, + shifted_window_attn_mask_1d=shifted_window_attn_mask_1d, + ) + + # update feature1 + concat1 = torch.cat(concat0.chunk(chunks=2, dim=0)[::-1], dim=0) + + feature0, feature1 = concat0.chunk(chunks=2, dim=0) # [B, H*W, C] + + # reshape back + feature0 = feature0.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() # [B, C, H, W] + feature1 = feature1.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() # [B, C, H, W] + + return feature0, feature1 diff --git a/src/videogen_hub/pipelines/opensora/tools/scoring/optical_flow/unimatch/trident_conv.py b/src/videogen_hub/pipelines/opensora/tools/scoring/optical_flow/unimatch/trident_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..d87579b95dfb5e40d7933264fcf917dbc508bb98 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/scoring/optical_flow/unimatch/trident_conv.py @@ -0,0 +1,88 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# https://github.com/facebookresearch/detectron2/blob/main/projects/TridentNet/tridentnet/trident_conv.py + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.modules.utils import _pair + + +class MultiScaleTridentConv(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + strides=1, + paddings=0, + dilations=1, + dilation=1, + groups=1, + num_branch=1, + test_branch_idx=-1, + bias=False, + norm=None, + activation=None, + ): + super(MultiScaleTridentConv, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.num_branch = num_branch + self.stride = _pair(stride) + self.groups = groups + self.with_bias = bias + self.dilation = dilation + if isinstance(paddings, int): + paddings = [paddings] * self.num_branch + if isinstance(dilations, int): + dilations = [dilations] * self.num_branch + if isinstance(strides, int): + strides = [strides] * self.num_branch + self.paddings = [_pair(padding) for padding in paddings] + self.dilations = [_pair(dilation) for dilation in dilations] + self.strides = [_pair(stride) for stride in strides] + self.test_branch_idx = test_branch_idx + self.norm = norm + self.activation = activation + + assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1 + + self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.bias = None + + nn.init.kaiming_uniform_(self.weight, nonlinearity="relu") + if self.bias is not None: + nn.init.constant_(self.bias, 0) + + def forward(self, inputs): + num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1 + assert len(inputs) == num_branch + + if self.training or self.test_branch_idx == -1: + outputs = [ + F.conv2d(input, self.weight, self.bias, stride, padding, self.dilation, self.groups) + for input, stride, padding in zip(inputs, self.strides, self.paddings) + ] + else: + outputs = [ + F.conv2d( + inputs[0], + self.weight, + self.bias, + self.strides[self.test_branch_idx] if self.test_branch_idx == -1 else self.strides[-1], + self.paddings[self.test_branch_idx] if self.test_branch_idx == -1 else self.paddings[-1], + self.dilation, + self.groups, + ) + ] + + if self.norm is not None: + outputs = [self.norm(x) for x in outputs] + if self.activation is not None: + outputs = [self.activation(x) for x in outputs] + return outputs diff --git a/src/videogen_hub/pipelines/opensora/tools/scoring/optical_flow/unimatch/unimatch.py b/src/videogen_hub/pipelines/opensora/tools/scoring/optical_flow/unimatch/unimatch.py new file mode 100644 index 0000000000000000000000000000000000000000..c625b991627d7cb378a29ba0b1091e80c32eae65 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/scoring/optical_flow/unimatch/unimatch.py @@ -0,0 +1,393 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .attention import SelfAttnPropagation +from .backbone import CNNEncoder +from .geometry import compute_flow_with_depth_pose, flow_warp +from .matching import ( + correlation_softmax_depth, + global_correlation_softmax, + global_correlation_softmax_stereo, + local_correlation_softmax, + local_correlation_softmax_stereo, + local_correlation_with_flow, +) +from .reg_refine import BasicUpdateBlock +from .transformer import FeatureTransformer +from .utils import feature_add_position, normalize_img, upsample_flow_with_mask + + +class UniMatch(nn.Module): + def __init__( + self, + num_scales=1, + feature_channels=128, + upsample_factor=8, + num_head=1, + ffn_dim_expansion=4, + num_transformer_layers=6, + reg_refine=False, # optional local regression refinement + task="flow", + ): + super(UniMatch, self).__init__() + + self.feature_channels = feature_channels + self.num_scales = num_scales + self.upsample_factor = upsample_factor + self.reg_refine = reg_refine + + # CNN + self.backbone = CNNEncoder(output_dim=feature_channels, num_output_scales=num_scales) + + # Transformer + self.transformer = FeatureTransformer( + num_layers=num_transformer_layers, + d_model=feature_channels, + nhead=num_head, + ffn_dim_expansion=ffn_dim_expansion, + ) + + # propagation with self-attn + self.feature_flow_attn = SelfAttnPropagation(in_channels=feature_channels) + + if not self.reg_refine or task == "depth": + # convex upsampling simiar to RAFT + # concat feature0 and low res flow as input + self.upsampler = nn.Sequential( + nn.Conv2d(2 + feature_channels, 256, 3, 1, 1), + nn.ReLU(inplace=True), + nn.Conv2d(256, upsample_factor**2 * 9, 1, 1, 0), + ) + # thus far, all the learnable parameters are task-agnostic + + if reg_refine: + # optional task-specific local regression refinement + self.refine_proj = nn.Conv2d(128, 256, 1) + self.refine = BasicUpdateBlock( + corr_channels=(2 * 4 + 1) ** 2, + downsample_factor=upsample_factor, + flow_dim=2 if task == "flow" else 1, + bilinear_up=task == "depth", + ) + + def extract_feature(self, img0, img1): + concat = torch.cat((img0, img1), dim=0) # [2B, C, H, W] + features = self.backbone(concat) # list of [2B, C, H, W], resolution from high to low + + # reverse: resolution from low to high + features = features[::-1] + + feature0, feature1 = [], [] + + for i in range(len(features)): + feature = features[i] + chunks = torch.chunk(feature, 2, 0) # tuple + feature0.append(chunks[0]) + feature1.append(chunks[1]) + + return feature0, feature1 + + def upsample_flow(self, flow, feature, bilinear=False, upsample_factor=8, is_depth=False): + if bilinear: + multiplier = 1 if is_depth else upsample_factor + up_flow = ( + F.interpolate(flow, scale_factor=upsample_factor, mode="bilinear", align_corners=True) * multiplier + ) + else: + concat = torch.cat((flow, feature), dim=1) + mask = self.upsampler(concat) + up_flow = upsample_flow_with_mask(flow, mask, upsample_factor=self.upsample_factor, is_depth=is_depth) + + return up_flow + + def forward( + self, + img0, + img1, + attn_type=None, + attn_splits_list=None, + corr_radius_list=None, + prop_radius_list=None, + num_reg_refine=1, + pred_bidir_flow=False, + task="flow", + intrinsics=None, + pose=None, # relative pose transform + min_depth=1.0 / 0.5, # inverse depth range + max_depth=1.0 / 10, + num_depth_candidates=64, + depth_from_argmax=False, + pred_bidir_depth=False, + **kwargs, + ): + if pred_bidir_flow: + assert task == "flow" + + if task == "depth": + assert self.num_scales == 1 # multi-scale depth model is not supported yet + + results_dict = {} + flow_preds = [] + + if task == "flow": + # stereo and depth tasks have normalized img in dataloader + img0, img1 = normalize_img(img0, img1) # [B, 3, H, W] + + # list of features, resolution low to high + feature0_list, feature1_list = self.extract_feature(img0, img1) # list of features + + flow = None + + if task != "depth": + assert len(attn_splits_list) == len(corr_radius_list) == len(prop_radius_list) == self.num_scales + else: + assert len(attn_splits_list) == len(prop_radius_list) == self.num_scales == 1 + + for scale_idx in range(self.num_scales): + feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx] + + if pred_bidir_flow and scale_idx > 0: + # predicting bidirectional flow with refinement + feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat((feature1, feature0), dim=0) + + feature0_ori, feature1_ori = feature0, feature1 + + upsample_factor = self.upsample_factor * (2 ** (self.num_scales - 1 - scale_idx)) + + if task == "depth": + # scale intrinsics + intrinsics_curr = intrinsics.clone() + intrinsics_curr[:, :2] = intrinsics_curr[:, :2] / upsample_factor + + if scale_idx > 0: + assert task != "depth" # not supported for multi-scale depth model + flow = F.interpolate(flow, scale_factor=2, mode="bilinear", align_corners=True) * 2 + + if flow is not None: + assert task != "depth" + flow = flow.detach() + + if task == "stereo": + # construct flow vector for disparity + # flow here is actually disparity + zeros = torch.zeros_like(flow) # [B, 1, H, W] + # NOTE: reverse disp, disparity is positive + displace = torch.cat((-flow, zeros), dim=1) # [B, 2, H, W] + feature1 = flow_warp(feature1, displace) # [B, C, H, W] + elif task == "flow": + feature1 = flow_warp(feature1, flow) # [B, C, H, W] + else: + raise NotImplementedError + + attn_splits = attn_splits_list[scale_idx] + if task != "depth": + corr_radius = corr_radius_list[scale_idx] + prop_radius = prop_radius_list[scale_idx] + + # add position to features + feature0, feature1 = feature_add_position(feature0, feature1, attn_splits, self.feature_channels) + + # Transformer + feature0, feature1 = self.transformer( + feature0, + feature1, + attn_type=attn_type, + attn_num_splits=attn_splits, + ) + + # correlation and softmax + if task == "depth": + # first generate depth candidates + b, _, h, w = feature0.size() + depth_candidates = torch.linspace(min_depth, max_depth, num_depth_candidates).type_as(feature0) + depth_candidates = depth_candidates.view(1, num_depth_candidates, 1, 1).repeat( + b, 1, h, w + ) # [B, D, H, W] + + flow_pred = correlation_softmax_depth( + feature0, + feature1, + intrinsics_curr, + pose, + depth_candidates=depth_candidates, + depth_from_argmax=depth_from_argmax, + pred_bidir_depth=pred_bidir_depth, + )[0] + + else: + if corr_radius == -1: # global matching + if task == "flow": + flow_pred = global_correlation_softmax(feature0, feature1, pred_bidir_flow)[0] + elif task == "stereo": + flow_pred = global_correlation_softmax_stereo(feature0, feature1)[0] + else: + raise NotImplementedError + else: # local matching + if task == "flow": + flow_pred = local_correlation_softmax(feature0, feature1, corr_radius)[0] + elif task == "stereo": + flow_pred = local_correlation_softmax_stereo(feature0, feature1, corr_radius)[0] + else: + raise NotImplementedError + + # flow or residual flow + flow = flow + flow_pred if flow is not None else flow_pred + + if task == "stereo": + flow = flow.clamp(min=0) # positive disparity + + # upsample to the original resolution for supervison at training time only + if self.training: + flow_bilinear = self.upsample_flow( + flow, None, bilinear=True, upsample_factor=upsample_factor, is_depth=task == "depth" + ) + flow_preds.append(flow_bilinear) + + # flow propagation with self-attn + if (pred_bidir_flow or pred_bidir_depth) and scale_idx == 0: + feature0 = torch.cat((feature0, feature1), dim=0) # [2*B, C, H, W] for propagation + + flow = self.feature_flow_attn( + feature0, + flow.detach(), + local_window_attn=prop_radius > 0, + local_window_radius=prop_radius, + ) + + # bilinear exclude the last one + if self.training and scale_idx < self.num_scales - 1: + flow_up = self.upsample_flow( + flow, feature0, bilinear=True, upsample_factor=upsample_factor, is_depth=task == "depth" + ) + flow_preds.append(flow_up) + + if scale_idx == self.num_scales - 1: + if not self.reg_refine: + # upsample to the original image resolution + + if task == "stereo": + flow_pad = torch.cat((-flow, torch.zeros_like(flow)), dim=1) # [B, 2, H, W] + flow_up_pad = self.upsample_flow(flow_pad, feature0) + flow_up = -flow_up_pad[:, :1] # [B, 1, H, W] + elif task == "depth": + depth_pad = torch.cat((flow, torch.zeros_like(flow)), dim=1) # [B, 2, H, W] + depth_up_pad = self.upsample_flow(depth_pad, feature0, is_depth=True).clamp( + min=min_depth, max=max_depth + ) + flow_up = depth_up_pad[:, :1] # [B, 1, H, W] + else: + flow_up = self.upsample_flow(flow, feature0) + + flow_preds.append(flow_up) + else: + # task-specific local regression refinement + # supervise current flow + if self.training: + flow_up = self.upsample_flow( + flow, feature0, bilinear=True, upsample_factor=upsample_factor, is_depth=task == "depth" + ) + flow_preds.append(flow_up) + + assert num_reg_refine > 0 + for refine_iter_idx in range(num_reg_refine): + flow = flow.detach() + + if task == "stereo": + zeros = torch.zeros_like(flow) # [B, 1, H, W] + # NOTE: reverse disp, disparity is positive + displace = torch.cat((-flow, zeros), dim=1) # [B, 2, H, W] + correlation = local_correlation_with_flow( + feature0_ori, + feature1_ori, + flow=displace, + local_radius=4, + ) # [B, (2R+1)^2, H, W] + elif task == "depth": + if pred_bidir_depth and refine_iter_idx == 0: + intrinsics_curr = intrinsics_curr.repeat(2, 1, 1) + pose = torch.cat((pose, torch.inverse(pose)), dim=0) + + feature0_ori, feature1_ori = torch.cat((feature0_ori, feature1_ori), dim=0), torch.cat( + (feature1_ori, feature0_ori), dim=0 + ) + + flow_from_depth = compute_flow_with_depth_pose( + 1.0 / flow.squeeze(1), + intrinsics_curr, + extrinsics_rel=pose, + ) + + correlation = local_correlation_with_flow( + feature0_ori, + feature1_ori, + flow=flow_from_depth, + local_radius=4, + ) # [B, (2R+1)^2, H, W] + + else: + correlation = local_correlation_with_flow( + feature0_ori, + feature1_ori, + flow=flow, + local_radius=4, + ) # [B, (2R+1)^2, H, W] + + proj = self.refine_proj(feature0) + + net, inp = torch.chunk(proj, chunks=2, dim=1) + + net = torch.tanh(net) + inp = torch.relu(inp) + + net, up_mask, residual_flow = self.refine( + net, + inp, + correlation, + flow.clone(), + ) + + if task == "depth": + flow = (flow - residual_flow).clamp(min=min_depth, max=max_depth) + else: + flow = flow + residual_flow + + if task == "stereo": + flow = flow.clamp(min=0) # positive + + if self.training or refine_iter_idx == num_reg_refine - 1: + if task == "depth": + if refine_iter_idx < num_reg_refine - 1: + # bilinear upsampling + flow_up = self.upsample_flow( + flow, feature0, bilinear=True, upsample_factor=upsample_factor, is_depth=True + ) + else: + # last one convex upsampling + # NOTE: clamp depth due to the zero padding in the unfold in the convex upsampling + # pad depth to 2 channels as flow + depth_pad = torch.cat((flow, torch.zeros_like(flow)), dim=1) # [B, 2, H, W] + depth_up_pad = self.upsample_flow(depth_pad, feature0, is_depth=True).clamp( + min=min_depth, max=max_depth + ) + flow_up = depth_up_pad[:, :1] # [B, 1, H, W] + + else: + flow_up = upsample_flow_with_mask( + flow, up_mask, upsample_factor=self.upsample_factor, is_depth=task == "depth" + ) + + flow_preds.append(flow_up) + + if task == "stereo": + for i in range(len(flow_preds)): + flow_preds[i] = flow_preds[i].squeeze(1) # [B, H, W] + + # convert inverse depth to depth + if task == "depth": + for i in range(len(flow_preds)): + flow_preds[i] = 1.0 / flow_preds[i].squeeze(1) # [B, H, W] + + results_dict.update({"flow_preds": flow_preds}) + + return results_dict diff --git a/src/videogen_hub/pipelines/opensora/tools/scoring/optical_flow/unimatch/utils.py b/src/videogen_hub/pipelines/opensora/tools/scoring/optical_flow/unimatch/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..60f40bea290ddd9a3f36adc7b4defb6e26588d1b --- /dev/null +++ b/src/videogen_hub/pipelines/opensora/tools/scoring/optical_flow/unimatch/utils.py @@ -0,0 +1,219 @@ +import torch +import torch.nn.functional as F + +from .position import PositionEmbeddingSine + + +def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None): + assert device is not None + + x, y = torch.meshgrid( + [torch.linspace(w_min, w_max, len_w, device=device), torch.linspace(h_min, h_max, len_h, device=device)], + ) + grid = torch.stack((x, y), -1).transpose(0, 1).float() # [H, W, 2] + + return grid + + +def normalize_coords(coords, h, w): + # coords: [B, H, W, 2] + c = torch.Tensor([(w - 1) / 2.0, (h - 1) / 2.0]).float().to(coords.device) + return (coords - c) / c # [-1, 1] + + +def normalize_img(img0, img1): + # loaded images are in [0, 255] + # normalize by ImageNet mean and std + mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(img1.device) + std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(img1.device) + img0 = (img0 / 255.0 - mean) / std + img1 = (img1 / 255.0 - mean) / std + + return img0, img1 + + +def split_feature( + feature, + num_splits=2, + channel_last=False, +): + if channel_last: # [B, H, W, C] + b, h, w, c = feature.size() + assert h % num_splits == 0 and w % num_splits == 0 + + b_new = b * num_splits * num_splits + h_new = h // num_splits + w_new = w // num_splits + + feature = ( + feature.view(b, num_splits, h // num_splits, num_splits, w // num_splits, c) + .permute(0, 1, 3, 2, 4, 5) + .reshape(b_new, h_new, w_new, c) + ) # [B*K*K, H/K, W/K, C] + else: # [B, C, H, W] + b, c, h, w = feature.size() + assert h % num_splits == 0 and w % num_splits == 0 + + b_new = b * num_splits * num_splits + h_new = h // num_splits + w_new = w // num_splits + + feature = ( + feature.view(b, c, num_splits, h // num_splits, num_splits, w // num_splits) + .permute(0, 2, 4, 1, 3, 5) + .reshape(b_new, c, h_new, w_new) + ) # [B*K*K, C, H/K, W/K] + + return feature + + +def merge_splits( + splits, + num_splits=2, + channel_last=False, +): + if channel_last: # [B*K*K, H/K, W/K, C] + b, h, w, c = splits.size() + new_b = b // num_splits // num_splits + + splits = splits.view(new_b, num_splits, num_splits, h, w, c) + merge = ( + splits.permute(0, 1, 3, 2, 4, 5).contiguous().view(new_b, num_splits * h, num_splits * w, c) + ) # [B, H, W, C] + else: # [B*K*K, C, H/K, W/K] + b, c, h, w = splits.size() + new_b = b // num_splits // num_splits + + splits = splits.view(new_b, num_splits, num_splits, c, h, w) + merge = ( + splits.permute(0, 3, 1, 4, 2, 5).contiguous().view(new_b, c, num_splits * h, num_splits * w) + ) # [B, C, H, W] + + return merge + + +def generate_shift_window_attn_mask( + input_resolution, window_size_h, window_size_w, shift_size_h, shift_size_w, device=torch.device("cuda") +): + # ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py + # calculate attention mask for SW-MSA + h, w = input_resolution + img_mask = torch.zeros((1, h, w, 1)).to(device) # 1 H W 1 + h_slices = (slice(0, -window_size_h), slice(-window_size_h, -shift_size_h), slice(-shift_size_h, None)) + w_slices = (slice(0, -window_size_w), slice(-window_size_w, -shift_size_w), slice(-shift_size_w, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = split_feature(img_mask, num_splits=input_resolution[-1] // window_size_w, channel_last=True) + + mask_windows = mask_windows.view(-1, window_size_h * window_size_w) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + +def feature_add_position(feature0, feature1, attn_splits, feature_channels): + pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2) + + if attn_splits > 1: # add position in splited window + feature0_splits = split_feature(feature0, num_splits=attn_splits) + feature1_splits = split_feature(feature1, num_splits=attn_splits) + + position = pos_enc(feature0_splits) + + feature0_splits = feature0_splits + position + feature1_splits = feature1_splits + position + + feature0 = merge_splits(feature0_splits, num_splits=attn_splits) + feature1 = merge_splits(feature1_splits, num_splits=attn_splits) + else: + position = pos_enc(feature0) + + feature0 = feature0 + position + feature1 = feature1 + position + + return feature0, feature1 + + +def upsample_flow_with_mask(flow, up_mask, upsample_factor, is_depth=False): + # convex upsampling following raft + + mask = up_mask + b, flow_channel, h, w = flow.shape + mask = mask.view(b, 1, 9, upsample_factor, upsample_factor, h, w) # [B, 1, 9, K, K, H, W] + mask = torch.softmax(mask, dim=2) + + multiplier = 1 if is_depth else upsample_factor + up_flow = F.unfold(multiplier * flow, [3, 3], padding=1) + up_flow = up_flow.view(b, flow_channel, 9, 1, 1, h, w) # [B, 2, 9, 1, 1, H, W] + + up_flow = torch.sum(mask * up_flow, dim=2) # [B, 2, K, K, H, W] + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) # [B, 2, K, H, K, W] + up_flow = up_flow.reshape(b, flow_channel, upsample_factor * h, upsample_factor * w) # [B, 2, K*H, K*W] + + return up_flow + + +def split_feature_1d( + feature, + num_splits=2, +): + # feature: [B, W, C] + b, w, c = feature.size() + assert w % num_splits == 0 + + b_new = b * num_splits + w_new = w // num_splits + + feature = feature.view(b, num_splits, w // num_splits, c).view(b_new, w_new, c) # [B*K, W/K, C] + + return feature + + +def merge_splits_1d( + splits, + h, + num_splits=2, +): + b, w, c = splits.size() + new_b = b // num_splits // h + + splits = splits.view(new_b, h, num_splits, w, c) + merge = splits.view(new_b, h, num_splits * w, c) # [B, H, W, C] + + return merge + + +def window_partition_1d(x, window_size_w): + """ + Args: + x: (B, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, C) + """ + B, W, C = x.shape + x = x.view(B, W // window_size_w, window_size_w, C).view(-1, window_size_w, C) + return x + + +def generate_shift_window_attn_mask_1d(input_w, window_size_w, shift_size_w, device=torch.device("cuda")): + # calculate attention mask for SW-MSA + img_mask = torch.zeros((1, input_w, 1)).to(device) # 1 W 1 + w_slices = (slice(0, -window_size_w), slice(-window_size_w, -shift_size_w), slice(-shift_size_w, None)) + cnt = 0 + for w in w_slices: + img_mask[:, w, :] = cnt + cnt += 1 + + mask_windows = window_partition_1d(img_mask, window_size_w) # nW, window_size, 1 + mask_windows = mask_windows.view(-1, window_size_w) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # nW, window_size, window_size + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask diff --git a/src/videogen_hub/pipelines/opensora_plan/__init__.py b/src/videogen_hub/pipelines/opensora_plan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e596852e048bae4c9620205e89963bab7ab3010a --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/__init__.py @@ -0,0 +1,3 @@ +import sys +sys.path.insert(0, './src/videogen_hub/pipelines/') +sys.path.insert(0, './src/videogen_hub/pipelines/opensora_plan/') \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/__init__.py b/src/videogen_hub/pipelines/opensora_plan/opensora/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b695bd83d3a1ca02b8c4d83caf3bf7afedb91004 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/__init__.py @@ -0,0 +1,3 @@ +import sys + +sys.path.insert(0, './src/videogen_hub/pipelines/opensora_plan/opensora') diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/__init__.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/__init__.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..43a6ef0d4c13a48c6a4c303ee955333c5e8eac61 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/__init__.py @@ -0,0 +1,30 @@ +from .imagebase import imagebase_ae, imagebase_ae_stride, imagebase_ae_channel +from .videobase import videobase_ae, videobase_ae_stride, videobase_ae_channel +from .videobase import ( + VQVAEConfiguration, + VQVAEModel, + VQVAETrainer, + CausalVQVAEModel, + CausalVQVAEConfiguration, + CausalVQVAETrainer +) + +ae_stride_config = {} +ae_stride_config.update(imagebase_ae_stride) +ae_stride_config.update(videobase_ae_stride) + +ae_channel_config = {} +ae_channel_config.update(imagebase_ae_channel) +ae_channel_config.update(videobase_ae_channel) + +def getae(args): + """deprecation""" + ae = imagebase_ae.get(args.ae, None) or videobase_ae.get(args.ae, None) + assert ae is not None + return ae(args.ae) + +def getae_wrapper(ae): + """deprecation""" + ae = imagebase_ae.get(ae, None) or videobase_ae.get(ae, None) + assert ae is not None + return ae \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/imagebase/__init__.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/imagebase/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..12eeb327ffdec68d152adfcbc0da8191d76ca0f5 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/imagebase/__init__.py @@ -0,0 +1,30 @@ +from .vae.vae import HFVAEWrapper +from .vae.vae import SDVAEWrapper +from .vqvae.vqvae import SDVQVAEWrapper + +vae = ['stabilityai/sd-vae-ft-mse', 'stabilityai/sd-vae-ft-ema'] +vqvae = ['vqgan_imagenet_f16_1024', 'vqgan_imagenet_f16_16384', 'vqgan_gumbel_f8'] + +imagebase_ae_stride = { + 'stabilityai/sd-vae-ft-mse': [1, 8, 8], + 'stabilityai/sd-vae-ft-ema': [1, 8, 8], + 'vqgan_imagenet_f16_1024': [1, 16, 16], + 'vqgan_imagenet_f16_16384': [1, 16, 16], + 'vqgan_gumbel_f8': [1, 8, 8], +} + +imagebase_ae_channel = { + 'stabilityai/sd-vae-ft-mse': 4, + 'stabilityai/sd-vae-ft-ema': 4, + 'vqgan_imagenet_f16_1024': -1, + 'vqgan_imagenet_f16_16384': -1, + 'vqgan_gumbel_f8': -1, +} + +imagebase_ae = { + 'stabilityai/sd-vae-ft-mse': HFVAEWrapper, + 'stabilityai/sd-vae-ft-ema': HFVAEWrapper, + 'vqgan_imagenet_f16_1024': SDVQVAEWrapper, + 'vqgan_imagenet_f16_16384': SDVQVAEWrapper, + 'vqgan_gumbel_f8': SDVQVAEWrapper, +} \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/imagebase/vae/__init__.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/imagebase/vae/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/imagebase/vae/vae.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/imagebase/vae/vae.py new file mode 100644 index 0000000000000000000000000000000000000000..4f197ae12e34ba8d4bbf44370adfe49c28f3bef0 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/imagebase/vae/vae.py @@ -0,0 +1,38 @@ +from einops import rearrange +from torch import nn +from diffusers.models import AutoencoderKL + + +class HFVAEWrapper(nn.Module): + def __init__(self, hfvae='mse'): + super(HFVAEWrapper, self).__init__() + self.vae = AutoencoderKL.from_pretrained(hfvae, cache_dir='cache_dir') + def encode(self, x): # b c h w + t = 0 + if x.ndim == 5: + b, c, t, h, w = x.shape + x = rearrange(x, 'b c t h w -> (b t) c h w').contiguous() + x = self.vae.encode(x).latent_dist.sample().mul_(0.18215) + if t != 0: + x = rearrange(x, '(b t) c h w -> b c t h w', t=t).contiguous() + return x + def decode(self, x): + t = 0 + if x.ndim == 5: + b, c, t, h, w = x.shape + x = rearrange(x, 'b c t h w -> (b t) c h w').contiguous() + x = self.vae.decode(x / 0.18215).sample + if t != 0: + x = rearrange(x, '(b t) c h w -> b t c h w', t=t).contiguous() + return x + +class SDVAEWrapper(nn.Module): + def __init__(self): + super(SDVAEWrapper, self).__init__() + raise NotImplementedError + + def encode(self, x): # b c h w + raise NotImplementedError + + def decode(self, x): + raise NotImplementedError \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/imagebase/vqvae/__init__.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/imagebase/vqvae/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/imagebase/vqvae/model.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/imagebase/vqvae/model.py new file mode 100644 index 0000000000000000000000000000000000000000..7c9a757f742019649f0173235c2f4e04fe042929 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/imagebase/vqvae/model.py @@ -0,0 +1,775 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + + +class Model(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, use_timestep=True): + super().__init__() + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + + def forward(self, x, t=None): + #assert x.shape[2] == x.shape[3] == self.resolution + + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, **ignore_kwargs): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + + def forward(self, x): + #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution) + + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, **ignorekwargs): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + print("Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class VUNet(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, + in_channels, c_channels, + resolution, z_channels, use_timestep=False, **ignore_kwargs): + super().__init__() + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(c_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + self.z_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=1, + stride=1, + padding=0) + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=2*block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + + def forward(self, x, z): + #assert x.shape[2] == x.shape[3] == self.resolution + + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + z = self.z_in(z) + h = torch.cat((h,z),dim=1) + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + nn.Conv2d(2*in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True)]) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1,2,3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, + ch_mult=(2,2), dropout=0.0): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/imagebase/vqvae/quantize.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/imagebase/vqvae/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..148062e9ea479374a28f1de6aa4c4c3a8c91c826 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/imagebase/vqvae/quantize.py @@ -0,0 +1,447 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from torch import einsum +from einops import rearrange + + +class VectorQuantizer(nn.Module): + """ + see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py + ____________________________________________ + Discretization bottleneck part of the VQ-VAE. + Inputs: + - n_e : number of embeddings + - e_dim : dimension of embedding + - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 + _____________________________________________ + """ + + # NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for + # a fix and use legacy=False to apply that fix. VectorQuantizer2 can be + # used wherever VectorQuantizer has been used before and is additionally + # more efficient. + def __init__(self, n_e, e_dim, beta): + super(VectorQuantizer, self).__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + def forward(self, z): + """ + Inputs the output of the encoder network z and maps it to a discrete + one-hot vector that is the index of the closest embedding vector e_j + z (continuous) -> z_q (discrete) + z.shape = (batch, channel, height, width) + quantization pipeline: + 1. get encoder input (B,C,H,W) + 2. flatten input to (B*H*W,C) + """ + # reshape z -> (batch, height, width, channel) and flatten + z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \ + torch.matmul(z_flattened, self.embedding.weight.t()) + + ## could possible replace this here + # #\start... + # find closest encodings + min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) + + min_encodings = torch.zeros( + min_encoding_indices.shape[0], self.n_e).to(z) + min_encodings.scatter_(1, min_encoding_indices, 1) + + # dtype min encodings: torch.float32 + # min_encodings shape: torch.Size([2048, 512]) + # min_encoding_indices.shape: torch.Size([2048, 1]) + + # get quantized latent vectors + z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape) + # .........\end + + # with: + # .........\start + # min_encoding_indices = torch.argmin(d, dim=1) + # z_q = self.embedding(min_encoding_indices) + # ......\end......... (TODO) + + # compute loss for embedding + loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * \ + torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # perplexity + e_mean = torch.mean(min_encodings, dim=0) + perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + # TODO: check for more easy handling with nn.Embedding + min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices) + min_encodings.scatter_(1, indices[:, None], 1) + + # get quantized latent vectors + z_q = torch.matmul(min_encodings.float(), self.embedding.weight) + + if shape is not None: + z_q = z_q.view(shape) + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + + +class GumbelQuantize(nn.Module): + """ + credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!) + Gumbel Softmax trick quantizer + Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016 + https://arxiv.org/abs/1611.01144 + """ + + def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True, + kl_weight=5e-4, temp_init=1.0, use_vqinterface=True, + remap=None, unknown_index="random"): + super().__init__() + + self.embedding_dim = embedding_dim + self.n_embed = n_embed + + self.straight_through = straight_through + self.temperature = temp_init + self.kl_weight = kl_weight + + self.proj = nn.Conv2d(num_hiddens, n_embed, 1) + self.embed = nn.Embedding(n_embed, embedding_dim) + + self.use_vqinterface = use_vqinterface + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices.") + else: + self.re_embed = n_embed + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + match = (inds[:, :, None] == used[None, None, ...]).long() + new = match.argmax(-1) + unknown = match.sum(2) < 1 + if self.unknown_index == "random": + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) + return back.reshape(ishape) + + def forward(self, z, temp=None, return_logits=False): + # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work + hard = self.straight_through if self.training else True + temp = self.temperature if temp is None else temp + + logits = self.proj(z) + if self.remap is not None: + # continue only with used logits + full_zeros = torch.zeros_like(logits) + logits = logits[:, self.used, ...] + + soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard) + if self.remap is not None: + # go back to all entries but unused set to zero + full_zeros[:, self.used, ...] = soft_one_hot + soft_one_hot = full_zeros + z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight) + + # + kl divergence to the prior loss + qy = F.softmax(logits, dim=1) + diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean() + + ind = soft_one_hot.argmax(dim=1) + if self.remap is not None: + ind = self.remap_to_used(ind) + if self.use_vqinterface: + if return_logits: + return z_q, diff, (None, None, ind), logits + return z_q, diff, (None, None, ind) + return z_q, diff, ind + + def get_codebook_entry(self, indices, shape): + b, h, w, c = shape + assert b * h * w == indices.shape[0] + indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w) + if self.remap is not None: + indices = self.unmap_to_all(indices) + one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float() + z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight) + return z_q + + +class VectorQuantizer2(nn.Module): + """ + Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly + avoids costly matrix multiplications and allows for post-hoc remapping of indices. + """ + + # NOTE: due to a bug the beta term was applied to the wrong term. for + # backwards compatibility we use the buggy version by default, but you can + # specify legacy=False to fix it. + def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", + sane_index_shape=False, legacy=True): + super().__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + self.legacy = legacy + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + print(f"Remapping {self.n_e} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices.") + else: + self.re_embed = n_e + + self.sane_index_shape = sane_index_shape + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + match = (inds[:, :, None] == used[None, None, ...]).long() + new = match.argmax(-1) + unknown = match.sum(2) < 1 + if self.unknown_index == "random": + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) + return back.reshape(ishape) + + def forward(self, z, temp=None, rescale_logits=False, return_logits=False): + assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel" + assert rescale_logits == False, "Only for interface compatible with Gumbel" + assert return_logits == False, "Only for interface compatible with Gumbel" + # reshape z -> (batch, height, width, channel) and flatten + z = rearrange(z, 'b c h w -> b h w c').contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \ + torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n')) + + min_encoding_indices = torch.argmin(d, dim=1) + z_q = self.embedding(min_encoding_indices).view(z.shape) + perplexity = None + min_encodings = None + + # compute loss for embedding + if not self.legacy: + loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + \ + torch.mean((z_q - z.detach()) ** 2) + else: + loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * \ + torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous() + + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten + + if self.sane_index_shape: + min_encoding_indices = min_encoding_indices.reshape( + z_q.shape[0], z_q.shape[2], z_q.shape[3]) + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + if self.remap is not None: + indices = indices.reshape(shape[0], -1) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) # flatten again + + # get quantized latent vectors + z_q = self.embedding(indices) + + if shape is not None: + z_q = z_q.view(shape) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + + +class EmbeddingEMA(nn.Module): + def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5): + super().__init__() + self.decay = decay + self.eps = eps + weight = torch.randn(num_tokens, codebook_dim) + self.weight = nn.Parameter(weight, requires_grad=False) + self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False) + self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False) + self.update = True + + def forward(self, embed_id): + return F.embedding(embed_id, self.weight) + + def cluster_size_ema_update(self, new_cluster_size): + self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay) + + def embed_avg_ema_update(self, new_embed_avg): + self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay) + + def weight_update(self, num_tokens): + n = self.cluster_size.sum() + smoothed_cluster_size = ( + (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n + ) + # normalize embedding average with smoothed cluster size + embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1) + self.weight.data.copy_(embed_normalized) + + +class EMAVectorQuantizer(nn.Module): + def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5, + remap=None, unknown_index="random"): + super().__init__() + self.codebook_dim = embedding_dim + self.num_tokens = n_embed + self.beta = beta + self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices.") + else: + self.re_embed = n_embed + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + match = (inds[:, :, None] == used[None, None, ...]).long() + new = match.argmax(-1) + unknown = match.sum(2) < 1 + if self.unknown_index == "random": + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) + return back.reshape(ishape) + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + # z, 'b c h w -> b h w c' + z = rearrange(z, 'b c h w -> b h w c') + z_flattened = z.reshape(-1, self.codebook_dim) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \ + self.embedding.weight.pow(2).sum(dim=1) - 2 * \ + torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n' + + encoding_indices = torch.argmin(d, dim=1) + + z_q = self.embedding(encoding_indices).view(z.shape) + encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) + avg_probs = torch.mean(encodings, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + if self.training and self.embedding.update: + # EMA cluster size + encodings_sum = encodings.sum(0) + self.embedding.cluster_size_ema_update(encodings_sum) + # EMA embedding average + embed_sum = encodings.transpose(0, 1) @ z_flattened + self.embedding.embed_avg_ema_update(embed_sum) + # normalize embed_avg and update weight + self.embedding.weight_update(self.num_tokens) + + # compute loss for embedding + loss = self.beta * F.mse_loss(z_q.detach(), z) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + # z_q, 'b h w c -> b c h w' + z_q = rearrange(z_q, 'b h w c -> b c h w') + return z_q, loss, (perplexity, encodings, encoding_indices) \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/imagebase/vqvae/vqgan.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/imagebase/vqvae/vqgan.py new file mode 100644 index 0000000000000000000000000000000000000000..9e9125be141193d2a0f988d4fcecddfa7cfd4a39 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/imagebase/vqvae/vqgan.py @@ -0,0 +1,419 @@ +import torch +import torch.nn.functional as F +import pytorch_lightning as pl +import argparse, os, sys, datetime, glob, importlib + +from .model import Encoder, Decoder +from .quantize import VectorQuantizer2 as VectorQuantizer +from .quantize import GumbelQuantize +from .quantize import EMAVectorQuantizer + + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def instantiate_from_config(config): + if not "target" in config: + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +class VQModel(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + ): + super().__init__() + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, + remap=remap, sane_index_shape=sane_index_shape) + self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.image_key = image_key + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + quant, emb_loss, info = self.quantize(h) + return quant, emb_loss, info + + def decode(self, quant): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + def decode_code(self, code_b): + quant_b = self.quantize.embed_code(code_b) + dec = self.decode(quant_b) + return dec + + def forward(self, input): + quant, diff, _ = self.encode(input) + dec = self.decode(quant) + return dec, diff + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format) + return x.float() + + def training_step(self, batch, batch_idx, optimizer_idx): + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x) + + if optimizer_idx == 0: + # autoencode + aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return aeloss + + if optimizer_idx == 1: + # discriminator + discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log("train/discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return discloss + + def validation_step(self, batch, batch_idx): + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step, + last_layer=self.get_last_layer(), split="val") + + discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step, + last_layer=self.get_last_layer(), split="val") + rec_loss = log_dict_ae["val/rec_loss"] + self.log("val/rec_loss", rec_loss, + prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) + self.log("val/aeloss", aeloss, + prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quantize.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + def log_images(self, batch, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + xrec, _ = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["inputs"] = x + log["reconstructions"] = xrec + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + + +class VQSegmentationModel(VQModel): + def __init__(self, n_labels, *args, **kwargs): + super().__init__(*args, **kwargs) + self.register_buffer("colorize", torch.randn(3, n_labels, 1, 1)) + + def configure_optimizers(self): + lr = self.learning_rate + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quantize.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr, betas=(0.5, 0.9)) + return opt_ae + + def training_step(self, batch, batch_idx): + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train") + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return aeloss + + def validation_step(self, batch, batch_idx): + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="val") + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + total_loss = log_dict_ae["val/total_loss"] + self.log("val/total_loss", total_loss, + prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) + return aeloss + + @torch.no_grad() + def log_images(self, batch, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + xrec, _ = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + # convert logits to indices + xrec = torch.argmax(xrec, dim=1, keepdim=True) + xrec = F.one_hot(xrec, num_classes=x.shape[1]) + xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float() + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["inputs"] = x + log["reconstructions"] = xrec + return log + + +class VQNoDiscModel(VQModel): + def __init__(self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None + ): + super().__init__(ddconfig=ddconfig, lossconfig=lossconfig, n_embed=n_embed, embed_dim=embed_dim, + ckpt_path=ckpt_path, ignore_keys=ignore_keys, image_key=image_key, + colorize_nlabels=colorize_nlabels) + + def training_step(self, batch, batch_idx): + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x) + # autoencode + aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="train") + output = pl.TrainResult(minimize=aeloss) + output.log("train/aeloss", aeloss, + prog_bar=True, logger=True, on_step=True, on_epoch=True) + output.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return output + + def validation_step(self, batch, batch_idx): + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="val") + rec_loss = log_dict_ae["val/rec_loss"] + output = pl.EvalResult(checkpoint_on=rec_loss) + output.log("val/rec_loss", rec_loss, + prog_bar=True, logger=True, on_step=True, on_epoch=True) + output.log("val/aeloss", aeloss, + prog_bar=True, logger=True, on_step=True, on_epoch=True) + output.log_dict(log_dict_ae) + + return output + + def configure_optimizers(self): + optimizer = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quantize.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=self.learning_rate, betas=(0.5, 0.9)) + return optimizer + + +class GumbelVQ(VQModel): + def __init__(self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + temperature_scheduler_config, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + kl_weight=1e-8, + remap=None, + ): + + z_channels = ddconfig["z_channels"] + super().__init__(ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=ignore_keys, + image_key=image_key, + colorize_nlabels=colorize_nlabels, + monitor=monitor, + ) + + self.loss.n_classes = n_embed + self.vocab_size = n_embed + + self.quantize = GumbelQuantize(z_channels, embed_dim, + n_embed=n_embed, + kl_weight=kl_weight, temp_init=1.0, + remap=remap) + + self.temperature_scheduler = instantiate_from_config(temperature_scheduler_config) # annealing of temp + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def temperature_scheduling(self): + self.quantize.temperature = self.temperature_scheduler(self.global_step) + + def encode_to_prequant(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode_code(self, code_b): + raise NotImplementedError + + def training_step(self, batch, batch_idx, optimizer_idx): + self.temperature_scheduling() + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x) + + if optimizer_idx == 0: + # autoencode + aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + self.log("temperature", self.quantize.temperature, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return aeloss + + if optimizer_idx == 1: + # discriminator + discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return discloss + + def validation_step(self, batch, batch_idx): + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x, return_pred_indices=True) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step, + last_layer=self.get_last_layer(), split="val") + + discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step, + last_layer=self.get_last_layer(), split="val") + rec_loss = log_dict_ae["val/rec_loss"] + self.log("val/rec_loss", rec_loss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + self.log("val/aeloss", aeloss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def log_images(self, batch, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + # encode + h = self.encoder(x) + h = self.quant_conv(h) + quant, _, _ = self.quantize(h) + # decode + x_rec = self.decode(quant) + log["inputs"] = x + log["reconstructions"] = x_rec + return log + + +class EMAVQ(VQModel): + def __init__(self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + ): + super().__init__(ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=ignore_keys, + image_key=image_key, + colorize_nlabels=colorize_nlabels, + monitor=monitor, + ) + self.quantize = EMAVectorQuantizer(n_embed=n_embed, + embedding_dim=embed_dim, + beta=0.25, + remap=remap) + def configure_optimizers(self): + lr = self.learning_rate + #Remove self.quantize from parameter list since it is updated via EMA + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/imagebase/vqvae/vqvae.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/imagebase/vqvae/vqvae.py new file mode 100644 index 0000000000000000000000000000000000000000..c43f8c42b5d883da6069e3d92cc50e7a7d7264c3 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/imagebase/vqvae/vqvae.py @@ -0,0 +1,34 @@ +from torch import nn +import yaml +import torch +from omegaconf import OmegaConf +from .vqgan import VQModel, GumbelVQ + +def load_config(config_path, display=False): + config = OmegaConf.load(config_path) + if display: + print(yaml.dump(OmegaConf.to_container(config))) + return config + + +def load_vqgan(config, ckpt_path=None, is_gumbel=False): + if is_gumbel: + model = GumbelVQ(**config.model.params) + else: + model = VQModel(**config.model.params) + if ckpt_path is not None: + sd = torch.load(ckpt_path, map_location="cpu")["state_dict"] + missing, unexpected = model.load_state_dict(sd, strict=False) + return model.eval() + + +class SDVQVAEWrapper(nn.Module): + def __init__(self, name): + super(SDVQVAEWrapper, self).__init__() + raise NotImplementedError + + def encode(self, x): # b c h w + raise NotImplementedError + + def decode(self, x): + raise NotImplementedError diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/__init__.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8c70a565100c744d8c0bfc649ce94a969bd3410a --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/__init__.py @@ -0,0 +1,53 @@ +from .vqvae import ( + VQVAEConfiguration, + VQVAEModel, + VQVAETrainer, + VQVAEModelWrapper +) +from .causal_vqvae import ( + CausalVQVAEConfiguration, + CausalVQVAETrainer, + CausalVQVAEModel, CausalVQVAEModelWrapper +) +from .causal_vae import ( + CausalVAEModel, CausalVAEModelWrapper +) + + +videobase_ae_stride = { + 'CausalVAEModel_4x8x8': [4, 8, 8], + 'CausalVQVAEModel_4x4x4': [4, 4, 4], + 'CausalVQVAEModel_4x8x8': [4, 8, 8], + 'VQVAEModel_4x4x4': [4, 4, 4], + 'OpenVQVAEModel_4x4x4': [4, 4, 4], + 'VQVAEModel_4x8x8': [4, 8, 8], + 'bair_stride4x2x2': [4, 2, 2], + 'ucf101_stride4x4x4': [4, 4, 4], + 'kinetics_stride4x4x4': [4, 4, 4], + 'kinetics_stride2x4x4': [2, 4, 4], +} + +videobase_ae_channel = { + 'CausalVAEModel_4x8x8': 4, + 'CausalVQVAEModel_4x4x4': 4, + 'CausalVQVAEModel_4x8x8': 4, + 'VQVAEModel_4x4x4': 4, + 'OpenVQVAEModel_4x4x4': 4, + 'VQVAEModel_4x8x8': 4, + 'bair_stride4x2x2': 256, + 'ucf101_stride4x4x4': 256, + 'kinetics_stride4x4x4': 256, + 'kinetics_stride2x4x4': 256, +} + +videobase_ae = { + 'CausalVAEModel_4x8x8': CausalVAEModelWrapper, + 'CausalVQVAEModel_4x4x4': CausalVQVAEModelWrapper, + 'CausalVQVAEModel_4x8x8': CausalVQVAEModelWrapper, + 'VQVAEModel_4x4x4': VQVAEModelWrapper, + 'VQVAEModel_4x8x8': VQVAEModelWrapper, + "bair_stride4x2x2": VQVAEModelWrapper, + "ucf101_stride4x4x4": VQVAEModelWrapper, + "kinetics_stride4x4x4": VQVAEModelWrapper, + "kinetics_stride2x4x4": VQVAEModelWrapper, +} diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/causal_vae/__init__.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/causal_vae/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d2c2d89587179eb5874365efbce75c57280e8aa4 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/causal_vae/__init__.py @@ -0,0 +1,26 @@ +from .modeling_causalvae import CausalVAEModel + +from einops import rearrange +from torch import nn + +class CausalVAEModelWrapper(nn.Module): + def __init__(self, model_path, subfolder=None, cache_dir=None, **kwargs): + super(CausalVAEModelWrapper, self).__init__() + # if os.path.exists(ckpt): + # self.vae = CausalVAEModel.load_from_checkpoint(ckpt) + self.vae = CausalVAEModel.from_pretrained(model_path, subfolder=subfolder, cache_dir=cache_dir, **kwargs) + def encode(self, x): # b c t h w + # x = self.vae.encode(x).sample() + x = self.vae.encode(x).sample().mul_(0.18215) + return x + def decode(self, x): + # x = self.vae.decode(x) + x = self.vae.decode(x / 0.18215) + x = rearrange(x, 'b c t h w -> b t c h w').contiguous() + return x + + def dtype(self): + return self.vae.dtype + # + # def device(self): + # return self.vae.device \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/causal_vae/modeling_causalvae.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/causal_vae/modeling_causalvae.py new file mode 100644 index 0000000000000000000000000000000000000000..264e92f0fa2d4046c46bd10c46e2a40380382c4f --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/causal_vae/modeling_causalvae.py @@ -0,0 +1,712 @@ +from ..modeling_videobase import VideoBaseAE_PL +from ..modules import Normalize +from ..modules.ops import nonlinearity +from typing import List, Tuple +import torch.nn as nn +from ..utils.module_utils import resolve_str_to_obj, Module +from ..utils.distrib_utils import DiagonalGaussianDistribution +from ..utils.scheduler_utils import cosine_scheduler +import torch +from diffusers.configuration_utils import register_to_config + + +class Encoder(nn.Module): + def __init__( + self, + z_channels: int, + hidden_size: int, + hidden_size_mult: Tuple[int] = (1, 2, 4, 4), + attn_resolutions: Tuple[int] = (16,), + conv_in: Module = "Conv2d", + conv_out: Module = "CasualConv3d", + attention: Module = "AttnBlock", + resnet_blocks: Tuple[Module] = ( + "ResnetBlock2D", + "ResnetBlock2D", + "ResnetBlock2D", + "ResnetBlock3D", + ), + spatial_downsample: Tuple[Module] = ( + "Downsample", + "Downsample", + "Downsample", + "", + ), + temporal_downsample: Tuple[Module] = ("", "", "TimeDownsampleRes2x", ""), + mid_resnet: Module = "ResnetBlock3D", + dropout: float = 0.0, + resolution: int = 256, + num_res_blocks: int = 2, + double_z: bool = True, + ) -> None: + super().__init__() + assert len(resnet_blocks) == len(hidden_size_mult), print( + hidden_size_mult, resnet_blocks + ) + # ---- Config ---- + self.num_resolutions = len(hidden_size_mult) + self.resolution = resolution + self.num_res_blocks = num_res_blocks + + # ---- In ---- + self.conv_in = resolve_str_to_obj(conv_in)( + 3, hidden_size, kernel_size=3, stride=1, padding=1 + ) + + # ---- Downsample ---- + curr_res = resolution + in_ch_mult = (1,) + tuple(hidden_size_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = hidden_size * in_ch_mult[i_level] + block_out = hidden_size * hidden_size_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + resolve_str_to_obj(resnet_blocks[i_level])( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(resolve_str_to_obj(attention)(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if spatial_downsample[i_level]: + down.downsample = resolve_str_to_obj(spatial_downsample[i_level])( + block_in, block_in + ) + curr_res = curr_res // 2 + if temporal_downsample[i_level]: + down.time_downsample = resolve_str_to_obj(temporal_downsample[i_level])( + block_in, block_in + ) + self.down.append(down) + + # ---- Mid ---- + self.mid = nn.Module() + self.mid.block_1 = resolve_str_to_obj(mid_resnet)( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + ) + self.mid.attn_1 = resolve_str_to_obj(attention)(block_in) + self.mid.block_2 = resolve_str_to_obj(mid_resnet)( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + ) + # ---- Out ---- + self.norm_out = Normalize(block_in) + self.conv_out = resolve_str_to_obj(conv_out)( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, x): + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if hasattr(self.down[i_level], "downsample"): + hs.append(self.down[i_level].downsample(hs[-1])) + if hasattr(self.down[i_level], "time_downsample"): + hs_down = self.down[i_level].time_downsample(hs[-1]) + hs.append(hs_down) + + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + z_channels: int, + hidden_size: int, + hidden_size_mult: Tuple[int] = (1, 2, 4, 4), + attn_resolutions: Tuple[int] = (16,), + conv_in: Module = "Conv2d", + conv_out: Module = "CasualConv3d", + attention: Module = "AttnBlock", + resnet_blocks: Tuple[Module] = ( + "ResnetBlock3D", + "ResnetBlock3D", + "ResnetBlock3D", + "ResnetBlock3D", + ), + spatial_upsample: Tuple[Module] = ( + "", + "SpatialUpsample2x", + "SpatialUpsample2x", + "SpatialUpsample2x", + ), + temporal_upsample: Tuple[Module] = ("", "", "", "TimeUpsampleRes2x"), + mid_resnet: Module = "ResnetBlock3D", + dropout: float = 0.0, + resolution: int = 256, + num_res_blocks: int = 2, + ): + super().__init__() + # ---- Config ---- + self.num_resolutions = len(hidden_size_mult) + self.resolution = resolution + self.num_res_blocks = num_res_blocks + + # ---- In ---- + block_in = hidden_size * hidden_size_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.conv_in = resolve_str_to_obj(conv_in)( + z_channels, block_in, kernel_size=3, padding=1 + ) + + # ---- Mid ---- + self.mid = nn.Module() + self.mid.block_1 = resolve_str_to_obj(mid_resnet)( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + ) + self.mid.attn_1 = resolve_str_to_obj(attention)(block_in) + self.mid.block_2 = resolve_str_to_obj(mid_resnet)( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + ) + + # ---- Upsample ---- + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = hidden_size * hidden_size_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + resolve_str_to_obj(resnet_blocks[i_level])( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(resolve_str_to_obj(attention)(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if spatial_upsample[i_level]: + up.upsample = resolve_str_to_obj(spatial_upsample[i_level])( + block_in, block_in + ) + curr_res = curr_res * 2 + if temporal_upsample[i_level]: + up.time_upsample = resolve_str_to_obj(temporal_upsample[i_level])( + block_in, block_in + ) + self.up.insert(0, up) + + # ---- Out ---- + self.norm_out = Normalize(block_in) + self.conv_out = resolve_str_to_obj(conv_out)( + block_in, 3, kernel_size=3, padding=1 + ) + + def forward(self, z): + h = self.conv_in(z) + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if hasattr(self.up[i_level], "upsample"): + h = self.up[i_level].upsample(h) + if hasattr(self.up[i_level], "time_upsample"): + h = self.up[i_level].time_upsample(h) + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class CausalVAEModel(VideoBaseAE_PL): + + @register_to_config + def __init__( + self, + lr: float = 1e-5, + hidden_size: int = 128, + z_channels: int = 4, + hidden_size_mult: Tuple[int] = (1, 2, 4, 4), + attn_resolutions: Tuple[int] = [], + dropout: float = 0.0, + resolution: int = 256, + double_z: bool = True, + embed_dim: int = 4, + num_res_blocks: int = 2, + loss_type: str = "videogen_hub.pipelines.opensora_plan.opensora.models.ae.videobase.losses.LPIPSWithDiscriminator", + loss_params: dict = { + "kl_weight": 0.000001, + "logvar_init": 0.0, + "disc_start": 2001, + "disc_weight": 0.5, + }, + q_conv: str = "CausalConv3d", + encoder_conv_in: Module = "CausalConv3d", + encoder_conv_out: Module = "CausalConv3d", + encoder_attention: Module = "AttnBlock3D", + encoder_resnet_blocks: Tuple[Module] = ( + "ResnetBlock3D", + "ResnetBlock3D", + "ResnetBlock3D", + "ResnetBlock3D", + ), + encoder_spatial_downsample: Tuple[Module] = ( + "SpatialDownsample2x", + "SpatialDownsample2x", + "SpatialDownsample2x", + "", + ), + encoder_temporal_downsample: Tuple[Module] = ( + "", + "TimeDownsample2x", + "TimeDownsample2x", + "", + ), + encoder_mid_resnet: Module = "ResnetBlock3D", + decoder_conv_in: Module = "CausalConv3d", + decoder_conv_out: Module = "CausalConv3d", + decoder_attention: Module = "AttnBlock3D", + decoder_resnet_blocks: Tuple[Module] = ( + "ResnetBlock3D", + "ResnetBlock3D", + "ResnetBlock3D", + "ResnetBlock3D", + ), + decoder_spatial_upsample: Tuple[Module] = ( + "", + "SpatialUpsample2x", + "SpatialUpsample2x", + "SpatialUpsample2x", + ), + decoder_temporal_upsample: Tuple[Module] = ("", "", "TimeUpsample2x", "TimeUpsample2x"), + decoder_mid_resnet: Module = "ResnetBlock3D", + ) -> None: + super().__init__() + self.tile_sample_min_size = 256 + self.tile_sample_min_size_t = 65 + self.tile_latent_min_size = int(self.tile_sample_min_size / (2 ** (len(hidden_size_mult) - 1))) + t_down_ratio = [i for i in encoder_temporal_downsample if len(i) > 0] + self.tile_latent_min_size_t = int((self.tile_sample_min_size_t-1) / (2 ** len(t_down_ratio))) + 1 + self.tile_overlap_factor = 0.25 + self.use_tiling = False + + self.learning_rate = lr + self.lr_g_factor = 1.0 + + self.loss = resolve_str_to_obj(loss_type, append=False)( + **loss_params + ) + + self.encoder = Encoder( + z_channels=z_channels, + hidden_size=hidden_size, + hidden_size_mult=hidden_size_mult, + attn_resolutions=attn_resolutions, + conv_in=encoder_conv_in, + conv_out=encoder_conv_out, + attention=encoder_attention, + resnet_blocks=encoder_resnet_blocks, + spatial_downsample=encoder_spatial_downsample, + temporal_downsample=encoder_temporal_downsample, + mid_resnet=encoder_mid_resnet, + dropout=dropout, + resolution=resolution, + num_res_blocks=num_res_blocks, + double_z=double_z, + ) + + self.decoder = Decoder( + z_channels=z_channels, + hidden_size=hidden_size, + hidden_size_mult=hidden_size_mult, + attn_resolutions=attn_resolutions, + conv_in=decoder_conv_in, + conv_out=decoder_conv_out, + attention=decoder_attention, + resnet_blocks=decoder_resnet_blocks, + spatial_upsample=decoder_spatial_upsample, + temporal_upsample=decoder_temporal_upsample, + mid_resnet=decoder_mid_resnet, + dropout=dropout, + resolution=resolution, + num_res_blocks=num_res_blocks, + ) + + quant_conv_cls = resolve_str_to_obj(q_conv) + self.quant_conv = quant_conv_cls(2 * z_channels, 2 * embed_dim, 1) + self.post_quant_conv = quant_conv_cls(embed_dim, z_channels, 1) + if hasattr(self.loss, "discriminator"): + self.automatic_optimization = False + + def encode(self, x): + if self.use_tiling and ( + x.shape[-1] > self.tile_sample_min_size + or x.shape[-2] > self.tile_sample_min_size + or x.shape[-3] > self.tile_sample_min_size_t + ): + return self.tiled_encode(x) + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + if self.use_tiling and ( + z.shape[-1] > self.tile_latent_min_size + or z.shape[-2] > self.tile_latent_min_size + or z.shape[-3] > self.tile_latent_min_size_t + ): + return self.tiled_decode(z) + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.to(memory_format=torch.contiguous_format).float() + return x + + def training_step(self, batch, batch_idx): + if hasattr(self.loss, "discriminator"): + return self._training_step_gan(batch, batch_idx=batch_idx) + else: + return self._training_step(batch, batch_idx=batch_idx) + + def _training_step(self, batch, batch_idx): + inputs = self.get_input(batch, "video") + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss( + inputs, + reconstructions, + posterior, + split="train", + ) + self.log( + "aeloss", + aeloss, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True, + ) + self.log_dict( + log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False + ) + return aeloss + + def _training_step_gan(self, batch, batch_idx): + inputs = self.get_input(batch, "video") + reconstructions, posterior = self(inputs) + opt1, opt2 = self.optimizers() + + # ---- AE Loss ---- + aeloss, log_dict_ae = self.loss( + inputs, + reconstructions, + posterior, + 0, + self.global_step, + last_layer=self.get_last_layer(), + split="train", + ) + self.log( + "aeloss", + aeloss, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True, + ) + opt1.zero_grad() + self.manual_backward(aeloss) + self.clip_gradients(opt1, gradient_clip_val=1, gradient_clip_algorithm="norm") + opt1.step() + # ---- GAN Loss ---- + discloss, log_dict_disc = self.loss( + inputs, + reconstructions, + posterior, + 1, + self.global_step, + last_layer=self.get_last_layer(), + split="train", + ) + self.log( + "discloss", + discloss, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True, + ) + opt2.zero_grad() + self.manual_backward(discloss) + self.clip_gradients(opt2, gradient_clip_val=1, gradient_clip_algorithm="norm") + opt2.step() + self.log_dict( + {**log_dict_ae, **log_dict_disc}, + prog_bar=False, + logger=True, + on_step=True, + on_epoch=False, + ) + + def configure_optimizers(self): + from itertools import chain + + lr = self.learning_rate + modules_to_train = [ + self.encoder.named_parameters(), + self.decoder.named_parameters(), + self.post_quant_conv.named_parameters(), + self.quant_conv.named_parameters(), + ] + params_with_time = [] + params_without_time = [] + for name, param in chain(*modules_to_train): + if "time" in name: + params_with_time.append(param) + else: + params_without_time.append(param) + optimizers = [] + opt_ae = torch.optim.Adam( + [ + {"params": params_with_time, "lr": lr}, + {"params": params_without_time, "lr": lr}, + ], + lr=lr, + betas=(0.5, 0.9), + ) + optimizers.append(opt_ae) + + if hasattr(self.loss, "discriminator"): + opt_disc = torch.optim.Adam( + self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9) + ) + optimizers.append(opt_disc) + + return optimizers, [] + + def get_last_layer(self): + if hasattr(self.decoder.conv_out, "conv"): + return self.decoder.conv_out.conv.weight + else: + return self.decoder.conv_out.weight + + def blend_v( + self, a: torch.Tensor, b: torch.Tensor, blend_extent: int + ) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * ( + 1 - y / blend_extent + ) + b[:, :, :, y, :] * (y / blend_extent) + return b + + def blend_h( + self, a: torch.Tensor, b: torch.Tensor, blend_extent: int + ) -> torch.Tensor: + blend_extent = min(a.shape[4], b.shape[4], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * ( + 1 - x / blend_extent + ) + b[:, :, :, :, x] * (x / blend_extent) + return b + + def tiled_encode(self, x): + t = x.shape[2] + t_chunk_idx = [i for i in range(0, t, self.tile_sample_min_size_t-1)] + if len(t_chunk_idx) == 1 and t_chunk_idx[0] == 0: + t_chunk_start_end = [[0, t]] + else: + t_chunk_start_end = [[t_chunk_idx[i], t_chunk_idx[i+1]+1] for i in range(len(t_chunk_idx)-1)] + if t_chunk_start_end[-1][-1] > t: + t_chunk_start_end[-1][-1] = t + elif t_chunk_start_end[-1][-1] < t: + last_start_end = [t_chunk_idx[-1], t] + t_chunk_start_end.append(last_start_end) + moments = [] + for idx, (start, end) in enumerate(t_chunk_start_end): + chunk_x = x[:, :, start: end] + if idx != 0: + moment = self.tiled_encode2d(chunk_x, return_moments=True)[:, :, 1:] + else: + moment = self.tiled_encode2d(chunk_x, return_moments=True) + moments.append(moment) + moments = torch.cat(moments, dim=2) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def tiled_decode(self, x): + t = x.shape[2] + t_chunk_idx = [i for i in range(0, t, self.tile_latent_min_size_t-1)] + if len(t_chunk_idx) == 1 and t_chunk_idx[0] == 0: + t_chunk_start_end = [[0, t]] + else: + t_chunk_start_end = [[t_chunk_idx[i], t_chunk_idx[i+1]+1] for i in range(len(t_chunk_idx)-1)] + if t_chunk_start_end[-1][-1] > t: + t_chunk_start_end[-1][-1] = t + elif t_chunk_start_end[-1][-1] < t: + last_start_end = [t_chunk_idx[-1], t] + t_chunk_start_end.append(last_start_end) + dec_ = [] + for idx, (start, end) in enumerate(t_chunk_start_end): + chunk_x = x[:, :, start: end] + if idx != 0: + dec = self.tiled_decode2d(chunk_x)[:, :, 1:] + else: + dec = self.tiled_decode2d(chunk_x) + dec_.append(dec) + dec_ = torch.cat(dec_, dim=2) + return dec_ + + def tiled_encode2d(self, x, return_moments=False): + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. + rows = [] + for i in range(0, x.shape[3], overlap_size): + row = [] + for j in range(0, x.shape[4], overlap_size): + tile = x[ + :, + :, + :, + i : i + self.tile_sample_min_size, + j : j + self.tile_sample_min_size, + ] + tile = self.encoder(tile) + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=4)) + + moments = torch.cat(result_rows, dim=3) + posterior = DiagonalGaussianDistribution(moments) + if return_moments: + return moments + return posterior + + def tiled_decode2d(self, z): + + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + row_limit = self.tile_sample_min_size - blend_extent + + # Split z into overlapping 64x64 tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, z.shape[3], overlap_size): + row = [] + for j in range(0, z.shape[4], overlap_size): + tile = z[ + :, + :, + :, + i : i + self.tile_latent_min_size, + j : j + self.tile_latent_min_size, + ] + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=4)) + + dec = torch.cat(result_rows, dim=3) + return dec + + def enable_tiling(self, use_tiling: bool = True): + self.use_tiling = use_tiling + + def disable_tiling(self): + self.enable_tiling(False) + + def init_from_ckpt(self, path, ignore_keys=list(), remove_loss=False): + sd = torch.load(path, map_location="cpu") + print("init from " + path) + if "state_dict" in sd: + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + + def validation_step(self, batch, batch_idx): + + from ..utils.video_utils import tensor_to_video + inputs = self.get_input(batch, 'video') + latents = self.encode(inputs).sample() + video_recon = self.decode(latents) + for idx in range(len(video_recon)): + self.logger.log_video(f"recon {batch_idx} {idx}", [tensor_to_video(video_recon[idx])], fps=[10]) + + \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/causal_vqvae/__init__.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/causal_vqvae/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..116d8956482484564e3e7b9f1b8ccd423ca93819 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/causal_vqvae/__init__.py @@ -0,0 +1,20 @@ +from .configuration_causalvqvae import CausalVQVAEConfiguration +from .modeling_causalvqvae import CausalVQVAEModel +from .trainer_causalvqvae import CausalVQVAETrainer + + +from einops import rearrange +from torch import nn + +class CausalVQVAEModelWrapper(nn.Module): + def __init__(self, ckpt): + super(CausalVQVAEModelWrapper, self).__init__() + self.vqvae = CausalVQVAEModel.load_from_checkpoint(ckpt) + def encode(self, x): # b c t h w + x = self.vqvae.pre_vq_conv(self.vqvae.encoder(x)) + return x + def decode(self, x): + vq_output = self.vqvae.codebook(x) + x = self.vqvae.decoder(self.vqvae.post_vq_conv(vq_output['embeddings'])) + x = rearrange(x, 'b c t h w -> b t c h w').contiguous() + return x \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/causal_vqvae/configuration_causalvqvae.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/causal_vqvae/configuration_causalvqvae.py new file mode 100644 index 0000000000000000000000000000000000000000..3f18be533fa298aa30c201f16e949898f73e64b5 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/causal_vqvae/configuration_causalvqvae.py @@ -0,0 +1,30 @@ +from ..configuration_videobase import VideoBaseConfiguration +from typing import Union, Tuple + +class CausalVQVAEConfiguration(VideoBaseConfiguration): + def __init__( + self, + embedding_dim: int = 256, + n_codes: int = 2048, + n_hiddens: int = 240, + n_res_layers: int = 4, + resolution: int = 128, + sequence_length: int = 16, + time_downsample: int = 4, + spatial_downsample: int = 8, + no_pos_embd: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + + self.embedding_dim = embedding_dim + self.n_codes = n_codes + self.n_hiddens = n_hiddens + self.n_res_layers = n_res_layers + self.resolution = resolution + self.sequence_length = sequence_length + self.time_downsample = time_downsample + self.spatial_downsample = spatial_downsample + self.no_pos_embd = no_pos_embd + + self.hidden_size = n_hiddens diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/causal_vqvae/modeling_causalvqvae.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/causal_vqvae/modeling_causalvqvae.py new file mode 100644 index 0000000000000000000000000000000000000000..65f3f98a97f80b83782fc026070ea79fcf68004a --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/causal_vqvae/modeling_causalvqvae.py @@ -0,0 +1,848 @@ +from ..modeling_videobase import VideoBaseAE +import torch +from torch import nn, Tensor +import numpy as np +import torch.distributed as dist +import torch.nn.functional as F +import math +import os +import json +from typing import Tuple, Dict, Union +from .configuration_causalvqvae import CausalVQVAEConfiguration +from einops import rearrange, pack, unpack + +# Copied from https://github.com/wilson1yan/VideoGPT +def view_range(x, i, j, shape): + shape = tuple(shape) + + n_dims = len(x.shape) + if i < 0: + i = n_dims + i + + if j is None: + j = n_dims + elif j < 0: + j = n_dims + j + + assert 0 <= i < j <= n_dims + + x_shape = x.shape + target_shape = x_shape[:i] + shape + x_shape[j:] + return x.view(target_shape) + + +# Copied from https://github.com/wilson1yan/VideoGPT +def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True): + n_dims = len(x.shape) + if src_dim < 0: + src_dim = n_dims + src_dim + if dest_dim < 0: + dest_dim = n_dims + dest_dim + assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims + dims = list(range(n_dims)) + del dims[src_dim] + permutation = [] + ctr = 0 + for i in range(n_dims): + if i == dest_dim: + permutation.append(src_dim) + else: + permutation.append(dims[ctr]) + ctr += 1 + x = x.permute(permutation) + if make_contiguous: + x = x.contiguous() + return x + + +# Copied from https://github.com/wilson1yan/VideoGPT +def scaled_dot_product_attention(q, k, v, mask=None, attn_dropout=0.0, training=True): + # Performs scaled dot-product attention over the second to last dimension dn + + # (b, n_head, d1, ..., dn, d) + attn = torch.matmul(q, k.transpose(-1, -2)) + attn = attn / np.sqrt(q.shape[-1]) + if mask is not None: + attn = attn.masked_fill(mask == 0, float("-inf")) + attn_float = F.softmax(attn, dim=-1) + attn = attn_float.type_as(attn) # b x n_head x d1 x ... x dn x d + attn = F.dropout(attn, p=attn_dropout, training=training) + + a = torch.matmul(attn, v) # b x n_head x d1 x ... x dn x d + + return a + +def is_odd(n): + return not n % 2 == 0 + +def maybe_del_attr_(o, attr): + if hasattr(o, attr): + delattr(o, attr) + +def cast_tuple(t, length = 1): + return t if isinstance(t, tuple) else ((t,) * length) + +class SpatialDownsample2x(torch.nn.Module): + def __init__( + self, + chan_in, + chan_out, + kernel_size: Union[int, Tuple[int]] = (4,4), + stride: Union[int, Tuple[int]] = (2,2) + ): + super().__init__() + kernel_size = cast_tuple(kernel_size, 2) + stride = cast_tuple(stride, 2) + + self.chan_in = chan_in + self.chan_out = chan_out + self.kernel_size = kernel_size + + total_pad = tuple([k - s for k, s in zip(kernel_size, stride)]) + pad_input = [] + for p in total_pad[::-1]: + pad_input.append((p // 2 + p % 2, p // 2)) + pad_input = sum(pad_input, tuple()) + self.pad_input = pad_input + + self.conv = torch.nn.Conv2d(self.chan_in, self.chan_out, self.kernel_size, stride=stride) + + def forward(self, x): + x = F.pad(x, self.pad_input) + x = rearrange(x, "b c f h w -> b f c h w") + x, ps = pack([x], "* c h w") + x = self.conv(x) + x = unpack(x, ps, "* c h w")[0] + x = rearrange(x, "b f c h w -> b c f h w") + return x + +class SpatialUpsample2x(torch.nn.Module): + def __init__( + self, + chan_in, + chan_out, + kernel_size: Union[int, Tuple[int]] = (3,3), + stride: Union[int, Tuple[int]] = (1,1) + ): + super().__init__() + self.chan_in = chan_in + self.chan_out = chan_out + self.kernel_size = kernel_size + self.conv = torch.nn.Conv2d(self.chan_in, self.chan_out, self.kernel_size, stride=stride, padding=tuple([(k - 1) // 2 for k in kernel_size])) + + def forward(self, x): + x = rearrange(x, "b c f h w -> b f c h w") + x, ps = pack([x], "* c h w") + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + x = unpack(x, ps, "* c h w")[0] + x = rearrange(x, "b f c h w -> b c f h w") + return x + +class TimeDownsample2x(nn.Module): + def __init__( + self, + chan_in, + chan_out, + kernel_size: int = 4, + ): + super().__init__() + self.chan_in = chan_in + self.chan_out = chan_out + self.kernel_size = kernel_size + self.conv = CausalConv3d(chan_in, chan_out, kernel_size, stride=2) + + def forward(self, x): + return self.conv(x) + +class TimeUpsample2x(nn.Module): + def __init__( + self, + chan_in, + chan_out, + kernel_size: int = 3, + ): + super().__init__() + self.chan_in = chan_in + self.chan_out = chan_out + self.kernel_size = kernel_size + self.conv = CausalConv3d(chan_in, chan_out, kernel_size, stride=1) + + def forward(self, x): + x = rearrange(x, "b c f h w -> b c h w f") + x, ps = pack([x], "b * f") + if x.size(-1) > 1: + x = torch.concat((x[:,:,:1], F.interpolate(x[:,:,1:], scale_factor=2.0, mode="linear")), dim=-1) + else: + x = x + x = unpack(x, ps, "b * f")[0] + x = rearrange(x, "b c h w f -> b c f h w") + x = self.conv(x) + return x + +class CausalConv3d(nn.Module): + def __init__( + self, + chan_in, + chan_out, + kernel_size: Union[int, Tuple[int, int, int]], + **kwargs + ): + super().__init__() + kernel_size = cast_tuple(kernel_size, 3) + self.time_kernel_size = kernel_size[0] + stride = kwargs.pop('stride', 1) + stride = (stride, 1, 1) + total_pad = tuple([k - s for k, s in zip(kernel_size[1:], stride[1:])]) + pad_input = [] + for p in total_pad[::-1]: + pad_input.append((p // 2 + p % 2, p // 2)) + pad_input = sum(pad_input, tuple()) + pad_input += (0, 0) + self.padding = pad_input + self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride = stride, **kwargs) + + def forward(self, x): + x = F.pad(x, self.padding) + first_frame_pad = x[:, :, :1, : ,:].repeat((1,1,self.time_kernel_size - 1,1,1)) + x = torch.concatenate((first_frame_pad, x), dim=2) + return self.conv(x) + +# Modified from https://github.com/wilson1yan/VideoGPT +class AxialBlock(nn.Module): + def __init__(self, n_hiddens, n_head): + super().__init__() + kwargs = dict( + shape=(0,) * 3, + dim_q=n_hiddens, + dim_kv=n_hiddens, + n_head=n_head, + n_layer=1, + causal=False, + attn_type="axial", + ) + self.attn_w = MultiHeadAttention(attn_kwargs=dict(axial_dim=-2), **kwargs) + self.attn_h = MultiHeadAttention(attn_kwargs=dict(axial_dim=-3), **kwargs) + kwargs['causal'] = True + self.attn_t = MultiHeadAttention(attn_kwargs=dict(axial_dim=-4), **kwargs) + + def forward(self, x): + x = shift_dim(x, 1, -1) + x = self.attn_w(x, x, x) + self.attn_h(x, x, x) + self.attn_t(x, x, x) + x = shift_dim(x, -1, 1) + return x + +# Copied from https://github.com/wilson1yan/VideoGPT +class AttentionResidualBlock(nn.Module): + def __init__(self, n_hiddens, n_heads: int = 2): + super().__init__() + self.block = nn.Sequential( + nn.BatchNorm3d(n_hiddens), + nn.ReLU(), + CausalConv3d(n_hiddens, n_hiddens // 2, 3, bias=False), + nn.BatchNorm3d(n_hiddens // 2), + nn.ReLU(), + CausalConv3d(n_hiddens // 2, n_hiddens, 1, bias=False), + nn.BatchNorm3d(n_hiddens), + nn.ReLU(), + AxialBlock(n_hiddens, n_heads), + ) + + def forward(self, x): + return x + self.block(x) + +# Copied from https://github.com/wilson1yan/VideoGPT +class Codebook(nn.Module): + def __init__(self, n_codes, embedding_dim): + super().__init__() + self.register_buffer("embeddings", torch.randn(n_codes, embedding_dim)) + self.register_buffer("N", torch.zeros(n_codes)) + self.register_buffer("z_avg", self.embeddings.data.clone()) + + self.n_codes = n_codes + self.embedding_dim = embedding_dim + self._need_init = True + + def _tile(self, x): + d, ew = x.shape + if d < self.n_codes: + n_repeats = (self.n_codes + d - 1) // d + std = 0.01 / np.sqrt(ew) + x = x.repeat(n_repeats, 1) + x = x + torch.randn_like(x) * std + return x + + def _init_embeddings(self, z): + # z: [b, c, t, h, w] + self._need_init = False + flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) + y = self._tile(flat_inputs) + + d = y.shape[0] + _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes] + if dist.is_initialized(): + dist.broadcast(_k_rand, 0) + self.embeddings.data.copy_(_k_rand) + self.z_avg.data.copy_(_k_rand) + self.N.data.copy_(torch.ones(self.n_codes)) + + def forward(self, z): + # z: [b, c, t, h, w] + if self._need_init and self.training: + self._init_embeddings(z) + flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) + distances = ( + (flat_inputs**2).sum(dim=1, keepdim=True) + - 2 * flat_inputs @ self.embeddings.t() + + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True) + ) + + encoding_indices = torch.argmin(distances, dim=1) + encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as(flat_inputs) + encoding_indices = encoding_indices.view(z.shape[0], *z.shape[2:]) + + embeddings = F.embedding(encoding_indices, self.embeddings) + embeddings = shift_dim(embeddings, -1, 1) + + commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach()) + + # EMA codebook update + if self.training: + n_total = encode_onehot.sum(dim=0) + encode_sum = flat_inputs.t() @ encode_onehot + if dist.is_initialized(): + dist.all_reduce(n_total) + dist.all_reduce(encode_sum) + + self.N.data.mul_(0.99).add_(n_total, alpha=0.01) + self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01) + + n = self.N.sum() + weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n + encode_normalized = self.z_avg / weights.unsqueeze(1) + self.embeddings.data.copy_(encode_normalized) + + y = self._tile(flat_inputs) + _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes] + if dist.is_initialized(): + dist.broadcast(_k_rand, 0) + + usage = (self.N.view(self.n_codes, 1) >= 1).float() + self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage)) + + embeddings_st = (embeddings - z).detach() + z + + avg_probs = torch.mean(encode_onehot, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + return dict( + embeddings=embeddings_st, + encodings=encoding_indices, + commitment_loss=commitment_loss, + perplexity=perplexity, + ) + + def dictionary_lookup(self, encodings): + embeddings = F.embedding(encodings, self.embeddings) + return embeddings + +# Modified from https://github.com/wilson1yan/VideoGPT +class Encoder(nn.Module): + def __init__(self, n_hiddens, n_res_layers, time_downsample, spatial_downsample): + super().__init__() + spatial_downsample = int(math.log2(spatial_downsample)) + self.spatial_conv = nn.ModuleList() + for i in range(spatial_downsample): + in_channels = 3 if i == 0 else n_hiddens + conv = SpatialDownsample2x(in_channels, n_hiddens) + self.spatial_conv.append(conv) + self.spatial_res_stack = nn.Sequential( + *[AttentionResidualBlock(n_hiddens) for _ in range(n_res_layers)], + nn.BatchNorm3d(n_hiddens), + nn.ReLU(), + ) + time_downsample = int(math.log2(time_downsample)) + self.time_conv = nn.ModuleList() + for i in range(time_downsample): + conv = TimeDownsample2x(n_hiddens, n_hiddens) + self.time_conv.append(conv) + self.time_res_stack = nn.Sequential( + *[AttentionResidualBlock(n_hiddens) for _ in range(n_res_layers)], + nn.BatchNorm3d(n_hiddens), + nn.ReLU(), + ) + + def forward(self, x): + h = x + for conv in self.spatial_conv: + h = F.relu(conv(h)) + h = self.spatial_res_stack(h) + for conv in self.time_conv: + h = F.relu(conv(h)) + h = self.time_res_stack(h) + return h + +# Copied from https://github.com/wilson1yan/VideoGPT +class MultiHeadAttention(nn.Module): + def __init__( + self, shape, dim_q, dim_kv, n_head, n_layer, causal, attn_type, attn_kwargs + ): + super().__init__() + self.causal = causal + self.shape = shape + + self.d_k = dim_q // n_head + self.d_v = dim_kv // n_head + self.n_head = n_head + + self.w_qs = nn.Linear(dim_q, n_head * self.d_k, bias=False) # q + self.w_qs.weight.data.normal_(std=1.0 / np.sqrt(dim_q)) + + self.w_ks = nn.Linear(dim_kv, n_head * self.d_k, bias=False) # k + self.w_ks.weight.data.normal_(std=1.0 / np.sqrt(dim_kv)) + + self.w_vs = nn.Linear(dim_kv, n_head * self.d_v, bias=False) # v + self.w_vs.weight.data.normal_(std=1.0 / np.sqrt(dim_kv)) + + self.fc = nn.Linear(n_head * self.d_v, dim_q, bias=True) # c + self.fc.weight.data.normal_(std=1.0 / np.sqrt(dim_q * n_layer)) + + if attn_type == "full": + self.attn = FullAttention(shape, causal, **attn_kwargs) + elif attn_type == "axial": + self.attn = AxialAttention(len(shape), causal=causal, **attn_kwargs) + elif attn_type == "sparse": + self.attn = SparseAttention(shape, n_head, causal, **attn_kwargs) + + self.cache = None + + def forward(self, q, k, v, decode_step=None, decode_idx=None): + """Compute multi-head attention + Args + q, k, v: a [b, d1, ..., dn, c] tensor or + a [b, 1, ..., 1, c] tensor if decode_step is not None + + Returns + The output after performing attention + """ + + # compute k, q, v + d_k, d_v, n_head = self.d_k, self.d_v, self.n_head + q = view_range(self.w_qs(q), -1, None, (n_head, d_k)) + k = view_range(self.w_ks(k), -1, None, (n_head, d_k)) + v = view_range(self.w_vs(v), -1, None, (n_head, d_v)) + + # b x n_head x seq_len x d + # (b, *d_shape, n_head, d) -> (b, n_head, *d_shape, d) + q = shift_dim(q, -2, 1) + k = shift_dim(k, -2, 1) + v = shift_dim(v, -2, 1) + + # fast decoding + if decode_step is not None: + if decode_step == 0: + if self.causal: + k_shape = (q.shape[0], n_head, *self.shape, self.d_k) + v_shape = (q.shape[0], n_head, *self.shape, self.d_v) + self.cache = dict( + k=torch.zeros(k_shape, dtype=k.dtype, device=q.device), + v=torch.zeros(v_shape, dtype=v.dtype, device=q.device), + ) + else: + # cache only once in the non-causal case + self.cache = dict(k=k.clone(), v=v.clone()) + if self.causal: + idx = ( + slice(None, None), + slice(None, None), + *[slice(i, i + 1) for i in decode_idx], + ) + self.cache["k"][idx] = k + self.cache["v"][idx] = v + k, v = self.cache["k"], self.cache["v"] + + a = self.attn(q, k, v, decode_step, decode_idx) + + # (b, *d_shape, n_head, d) -> (b, *d_shape, n_head * d) + a = shift_dim(a, 1, -2).flatten(start_dim=-2) + a = self.fc(a) # (b x seq_len x embd_dim) + + return a + +# Copied from https://github.com/wilson1yan/VideoGPT +class Decoder(nn.Module): + def __init__(self, n_hiddens, n_res_layers, time_downsample, spatial_downsample): + super().__init__() + self.time_res_stack = nn.Sequential( + *[AttentionResidualBlock(n_hiddens) for _ in range(n_res_layers)], + nn.BatchNorm3d(n_hiddens), + nn.ReLU(), + ) + time_downsample = int(math.log2(time_downsample)) + self.time_conv = nn.ModuleList() + for i in range(time_downsample): + convt = TimeUpsample2x(n_hiddens, n_hiddens) + self.time_conv.append(convt) + self.spatial_res_stack = nn.Sequential( + *[AttentionResidualBlock(n_hiddens) for _ in range(n_res_layers)], + nn.BatchNorm3d(n_hiddens), + nn.ReLU(), + ) + spatial_downsample = int(math.log2(spatial_downsample)) + self.spatial_conv = nn.ModuleList() + for i in range(spatial_downsample): + out_channels = 3 if i == spatial_downsample - 1 else n_hiddens + convt = SpatialUpsample2x(n_hiddens, out_channels) + self.spatial_conv.append(convt) + + def forward(self, x): + h = self.time_res_stack(x) + for conv in self.time_conv: + h = F.relu(conv(h)) + h = self.spatial_res_stack(h) + for i, conv in enumerate(self.spatial_conv): + h = conv(h) + if i < len(self.spatial_conv) - 1: + h = F.relu(h) + return h + +# Copied from https://github.com/wilson1yan/VideoGPT +class FullAttention(nn.Module): + def __init__(self, shape, causal, attn_dropout): + super().__init__() + self.causal = causal + self.attn_dropout = attn_dropout + + seq_len = np.prod(shape) + if self.causal: + self.register_buffer("mask", torch.tril(torch.ones(seq_len, seq_len))) + + def forward(self, q, k, v, decode_step, decode_idx): + mask = self.mask if self.causal else None + if decode_step is not None and mask is not None: + mask = mask[[decode_step]] + + old_shape = q.shape[2:-1] + q = q.flatten(start_dim=2, end_dim=-2) + k = k.flatten(start_dim=2, end_dim=-2) + v = v.flatten(start_dim=2, end_dim=-2) + + out = scaled_dot_product_attention( + q, k, v, mask=mask, attn_dropout=self.attn_dropout, training=self.training + ) + + return view_range(out, 2, 3, old_shape) + +# Copied from https://github.com/wilson1yan/VideoGPT +class AxialAttention(nn.Module): + def __init__(self, n_dim, axial_dim, causal=False): + super().__init__() + if axial_dim < 0: + axial_dim = 2 + n_dim + 1 + axial_dim + else: + axial_dim += 2 # account for batch, head, dim + self.causal = causal + self.axial_dim = axial_dim + + def forward(self, q, k, v, decode_step, decode_idx): + # batch, head, frame, height, width, dim + q = shift_dim(q, self.axial_dim, -2).flatten(end_dim=-3) + k = shift_dim(k, self.axial_dim, -2).flatten(end_dim=-3) + v = shift_dim(v, self.axial_dim, -2) + + old_shape = list(v.shape) + v = v.flatten(end_dim=-3) + + if self.causal: + mask = torch.tril(torch.ones(q.shape[-2], q.shape[-2])) if self.causal else None + if decode_step is not None and mask is not None: + mask = mask[[decode_step]] + mask = mask.to(q.device) + else: + mask = None + + out = scaled_dot_product_attention(q, k, v, mask=mask, training=self.training) + out = out.view(*old_shape) + out = shift_dim(out, -2, self.axial_dim) + return out + +# Copied from https://github.com/wilson1yan/VideoGPT +class StridedSparsityConfig(object): + """ + Strided Sparse configuration specified in https://arxiv.org/abs/1904.10509 that + generalizes to arbitrary dimensions + """ + + def __init__(self, shape, n_head, causal, block, num_local_blocks): + self.n_head = n_head + self.shape = shape + self.causal = causal + self.block = block + self.num_local_blocks = num_local_blocks + + assert self.num_local_blocks >= 1, "Must have at least 1 local block" + assert self.seq_len % self.block == 0, "seq len must be divisible by block size" + + self._block_shape = self._compute_block_shape() + self._block_shape_cum = self._block_shape_cum_sizes() + + @property + def seq_len(self): + return np.prod(self.shape) + + @property + def num_blocks(self): + return self.seq_len // self.block + + def set_local_layout(self, layout): + num_blocks = self.num_blocks + for row in range(0, num_blocks): + end = min(row + self.num_local_blocks, num_blocks) + for col in range( + max(0, row - self.num_local_blocks), (row + 1 if self.causal else end) + ): + layout[:, row, col] = 1 + return layout + + def set_global_layout(self, layout): + num_blocks = self.num_blocks + n_dim = len(self._block_shape) + for row in range(num_blocks): + assert self._to_flattened_idx(self._to_unflattened_idx(row)) == row + cur_idx = self._to_unflattened_idx(row) + # no strided attention over last dim + for d in range(n_dim - 1): + end = self._block_shape[d] + for i in range(0, (cur_idx[d] + 1 if self.causal else end)): + new_idx = list(cur_idx) + new_idx[d] = i + new_idx = tuple(new_idx) + + col = self._to_flattened_idx(new_idx) + layout[:, row, col] = 1 + + return layout + + def make_layout(self): + layout = torch.zeros( + (self.n_head, self.num_blocks, self.num_blocks), dtype=torch.int64 + ) + layout = self.set_local_layout(layout) + layout = self.set_global_layout(layout) + return layout + + def make_sparse_attn_mask(self): + block_layout = self.make_layout() + assert block_layout.shape[1] == block_layout.shape[2] == self.num_blocks + + num_dense_blocks = block_layout.sum().item() + attn_mask = torch.ones(num_dense_blocks, self.block, self.block) + counter = 0 + for h in range(self.n_head): + for i in range(self.num_blocks): + for j in range(self.num_blocks): + elem = block_layout[h, i, j].item() + if elem == 1: + assert i >= j + if i == j: # need to mask within block on diagonals + attn_mask[counter] = torch.tril(attn_mask[counter]) + counter += 1 + assert counter == num_dense_blocks + + return attn_mask.unsqueeze(0) + + def get_non_block_layout_row(self, block_layout, row): + block_row = row // self.block + block_row = block_layout[:, [block_row]] # n_head x 1 x n_blocks + block_row = block_row.repeat_interleave(self.block, dim=-1) + block_row[:, :, row + 1 :] = 0.0 + return block_row + + ############# Helper functions ########################## + + def _compute_block_shape(self): + n_dim = len(self.shape) + cum_prod = 1 + for i in range(n_dim - 1, -1, -1): + cum_prod *= self.shape[i] + if cum_prod > self.block: + break + assert cum_prod % self.block == 0 + new_shape = (*self.shape[:i], cum_prod // self.block) + + assert np.prod(new_shape) == np.prod(self.shape) // self.block + + return new_shape + + def _block_shape_cum_sizes(self): + bs = np.flip(np.array(self._block_shape)) + return tuple(np.flip(np.cumprod(bs)[:-1])) + (1,) + + def _to_flattened_idx(self, idx): + assert len(idx) == len( + self._block_shape + ), f"{len(idx)} != {len(self._block_shape)}" + flat_idx = 0 + for i in range(len(self._block_shape)): + flat_idx += idx[i] * self._block_shape_cum[i] + return flat_idx + + def _to_unflattened_idx(self, flat_idx): + assert flat_idx < np.prod(self._block_shape) + idx = [] + for i in range(len(self._block_shape)): + idx.append(flat_idx // self._block_shape_cum[i]) + flat_idx %= self._block_shape_cum[i] + return tuple(idx) + +# Copied from https://github.com/wilson1yan/VideoGPT +class SparseAttention(nn.Module): + ops = dict() + attn_mask = dict() + block_layout = dict() + + def __init__( + self, shape, n_head, causal, num_local_blocks=4, block=32, attn_dropout=0.0 + ): # does not use attn_dropout + super().__init__() + self.causal = causal + self.shape = shape + + self.sparsity_config = StridedSparsityConfig( + shape=shape, + n_head=n_head, + causal=causal, + block=block, + num_local_blocks=num_local_blocks, + ) + + if self.shape not in SparseAttention.block_layout: + SparseAttention.block_layout[self.shape] = ( + self.sparsity_config.make_layout() + ) + if causal and self.shape not in SparseAttention.attn_mask: + SparseAttention.attn_mask[self.shape] = ( + self.sparsity_config.make_sparse_attn_mask() + ) + + def get_ops(self): + try: + from deepspeed.ops.sparse_attention import MatMul, Softmax + except: + raise Exception( + "Error importing deepspeed. Please install using `DS_BUILD_SPARSE_ATTN=1 pip install deepspeed`" + ) + if self.shape not in SparseAttention.ops: + sparsity_layout = self.sparsity_config.make_layout() + sparse_dot_sdd_nt = MatMul( + sparsity_layout, + self.sparsity_config.block, + "sdd", + trans_a=False, + trans_b=True, + ) + + sparse_dot_dsd_nn = MatMul( + sparsity_layout, + self.sparsity_config.block, + "dsd", + trans_a=False, + trans_b=False, + ) + + sparse_softmax = Softmax(sparsity_layout, self.sparsity_config.block) + + SparseAttention.ops[self.shape] = ( + sparse_dot_sdd_nt, + sparse_dot_dsd_nn, + sparse_softmax, + ) + return SparseAttention.ops[self.shape] + + def forward(self, q, k, v, decode_step, decode_idx): + if self.training and self.shape not in SparseAttention.ops: + self.get_ops() + + SparseAttention.block_layout[self.shape] = SparseAttention.block_layout[ + self.shape + ].to(q) + if self.causal: + SparseAttention.attn_mask[self.shape] = ( + SparseAttention.attn_mask[self.shape].to(q).type_as(q) + ) + attn_mask = SparseAttention.attn_mask[self.shape] if self.causal else None + + old_shape = q.shape[2:-1] + q = q.flatten(start_dim=2, end_dim=-2) + k = k.flatten(start_dim=2, end_dim=-2) + v = v.flatten(start_dim=2, end_dim=-2) + + if decode_step is not None: + mask = self.sparsity_config.get_non_block_layout_row( + SparseAttention.block_layout[self.shape], decode_step + ) + out = scaled_dot_product_attention( + q, k, v, mask=mask, training=self.training + ) + else: + if q.shape != k.shape or k.shape != v.shape: + raise Exception("SparseAttention only support self-attention") + sparse_dot_sdd_nt, sparse_dot_dsd_nn, sparse_softmax = self.get_ops() + scaling = float(q.shape[-1]) ** -0.5 + + attn_output_weights = sparse_dot_sdd_nt(q, k) + if attn_mask is not None: + attn_output_weights = attn_output_weights.masked_fill( + attn_mask == 0, float("-inf") + ) + attn_output_weights = sparse_softmax(attn_output_weights, scale=scaling) + + out = sparse_dot_dsd_nn(attn_output_weights, v) + + return view_range(out, 2, 3, old_shape) + +class CausalVQVAEModel(VideoBaseAE): + + def __init__(self, config: CausalVQVAEConfiguration): + super().__init__() + self.config = config + self.embedding_dim = config.embedding_dim + self.n_codes = config.n_codes + self.encoder = Encoder(config.n_hiddens, config.n_res_layers, config.time_downsample, config.spatial_downsample) + self.decoder = Decoder(config.n_hiddens, config.n_res_layers, config.time_downsample, config.spatial_downsample) + self.pre_vq_conv = CausalConv3d(config.n_hiddens, config.embedding_dim, 1) + self.post_vq_conv = CausalConv3d(config.embedding_dim, config.n_hiddens, 1) + self.codebook = Codebook(config.n_codes, config.embedding_dim) + + def forward(self, x): + z = self.pre_vq_conv(self.encoder(x)) + vq_output = self.codebook(z) + x_recon = self.decoder(self.post_vq_conv(vq_output["embeddings"])) + recon_loss = F.mse_loss(x_recon, x) / 0.06 + return recon_loss, x_recon, vq_output + + def encode(self, x: Tensor, include_embeddings: bool = False) -> Union[Tuple[Tensor, Tensor], Tensor]: + h = self.pre_vq_conv(self.encoder(x)) + vq_output: Dict[str, Tensor] = self.codebook(h) + if include_embeddings: + return vq_output["encodings"], vq_output["embeddings"] + else: + return vq_output["encodings"] + + def decode(self, encodings: Tensor) -> Tensor: + h = F.embedding(encodings, self.codebook.embeddings) + h = self.post_vq_conv(shift_dim(h, -1, 1)) + return self.decoder(h) + + @classmethod + def load_from_checkpoint(cls, model_path): + with open(os.path.join(model_path, "config.json"), "r") as file: + config = json.load(file) + state_dict = torch.load(os.path.join(model_path, "pytorch_model.bin"), map_location="cpu") + model = cls(config=CausalVQVAEConfiguration(**config)) + model.load_state_dict(state_dict) + return model + + @classmethod + def download_and_load_model(cls, model_name, cache_dir=None): + raise NotImplementedError() diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/causal_vqvae/trainer_causalvqvae.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/causal_vqvae/trainer_causalvqvae.py new file mode 100644 index 0000000000000000000000000000000000000000..f819bce363c4cfb10ec7bff3a3e754a30e09dec7 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/causal_vqvae/trainer_causalvqvae.py @@ -0,0 +1,21 @@ +from ..trainer_videobase import VideoBaseTrainer +import torch.nn.functional as F +from typing import Optional +import os +import torch +from transformers.utils import WEIGHTS_NAME +import json + +class CausalVQVAETrainer(VideoBaseTrainer): + + def compute_loss(self, model, inputs, return_outputs=False): + model = model.module + x = inputs.get("video") + x = x / 2 + z = model.pre_vq_conv(model.encoder(x)) + vq_output = model.codebook(z) + x_recon = model.decoder(model.post_vq_conv(vq_output["embeddings"])) + recon_loss = F.mse_loss(x_recon, x) / 0.06 + commitment_loss = vq_output['commitment_loss'] + loss = recon_loss + commitment_loss + return loss diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/configuration_videobase.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/configuration_videobase.py new file mode 100644 index 0000000000000000000000000000000000000000..f25a25c5b37da8aa8f9e939112ced0b9f861a4b6 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/configuration_videobase.py @@ -0,0 +1,44 @@ +import json +import yaml +from typing import TypeVar, Dict, Any +from diffusers import ConfigMixin + +T = TypeVar('T', bound='VideoBaseConfiguration') +class VideoBaseConfiguration(ConfigMixin): + config_name = "VideoBaseConfiguration" + _nested_config_fields: Dict[str, Any] = {} + + def __init__(self, **kwargs): + pass + + def to_dict(self) -> Dict[str, Any]: + d = {} + for key, value in vars(self).items(): + if isinstance(value, VideoBaseConfiguration): + d[key] = value.to_dict() # Serialize nested VideoBaseConfiguration instances + elif isinstance(value, tuple): + d[key] = list(value) + else: + d[key] = value + return d + + def to_yaml_file(self, yaml_path: str): + with open(yaml_path, 'w') as yaml_file: + yaml.dump(self.to_dict(), yaml_file, default_flow_style=False) + + @classmethod + def load_from_yaml(cls: T, yaml_path: str) -> T: + with open(yaml_path, 'r') as yaml_file: + config_dict = yaml.safe_load(yaml_file) + for field, field_type in cls._nested_config_fields.items(): + if field in config_dict: + config_dict[field] = field_type.load_from_dict(config_dict[field]) + return cls(**config_dict) + + @classmethod + def load_from_dict(cls: T, config_dict: Dict[str, Any]) -> T: + # Process nested configuration objects + for field, field_type in cls._nested_config_fields.items(): + if field in config_dict: + config_dict[field] = field_type.load_from_dict(config_dict[field]) + return cls(**config_dict) \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/dataset_videobase.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/dataset_videobase.py new file mode 100644 index 0000000000000000000000000000000000000000..32f842f63a310c9ee8e2d1dde3c3e695bdaa582a --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/dataset_videobase.py @@ -0,0 +1,107 @@ +import os.path as osp +import random +from glob import glob + +from torchvision import transforms +import numpy as np +import torch +import torch.utils.data as data +import torch.nn.functional as F +from torchvision.transforms import Lambda + +from ....dataset.transform import ToTensorVideo, CenterCropVideo +from ....utils.dataset_utils import DecordInit + +def TemporalRandomCrop(total_frames, size): + """ + Performs a random temporal crop on a video sequence. + + This function randomly selects a continuous frame sequence of length `size` from a video sequence. + `total_frames` indicates the total number of frames in the video sequence, and `size` represents the length of the frame sequence to be cropped. + + Parameters: + - total_frames (int): The total number of frames in the video sequence. + - size (int): The length of the frame sequence to be cropped. + + Returns: + - (int, int): A tuple containing two integers. The first integer is the starting frame index of the cropped sequence, + and the second integer is the ending frame index (inclusive) of the cropped sequence. + """ + rand_end = max(0, total_frames - size - 1) + begin_index = random.randint(0, rand_end) + end_index = min(begin_index + size, total_frames) + return begin_index, end_index + +def resize(x, resolution): + height, width = x.shape[-2:] + resolution = min(2 * resolution, height, width) + aspect_ratio = width / height + if width <= height: + new_width = resolution + new_height = int(resolution / aspect_ratio) + else: + new_height = resolution + new_width = int(resolution * aspect_ratio) + resized_x = F.interpolate(x, size=(new_height, new_width), mode='bilinear', align_corners=True, antialias=True) + return resized_x + +class VideoDataset(data.Dataset): + """ Generic dataset for videos files stored in folders + Returns BCTHW videos in the range [-0.5, 0.5] """ + video_exts = ['avi', 'mp4', 'webm'] + def __init__(self, video_folder, sequence_length, image_folder=None, train=True, resolution=64, sample_rate=1, dynamic_sample=True): + + self.train = train + self.sequence_length = sequence_length + self.sample_rate = sample_rate + self.resolution = resolution + self.v_decoder = DecordInit() + self.video_folder = video_folder + self.dynamic_sample = dynamic_sample + + self.transform = transforms.Compose([ + ToTensorVideo(), + # Lambda(lambda x: resize(x, self.resolution)), + CenterCropVideo(self.resolution), + Lambda(lambda x: 2.0 * x - 1.0) + ]) + print('Building datasets...') + self.samples = self._make_dataset() + + def _make_dataset(self): + samples = [] + samples += sum([glob(osp.join(self.video_folder, '**', f'*.{ext}'), recursive=True) + for ext in self.video_exts], []) + return samples + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + video_path = self.samples[idx] + try: + video = self.decord_read(video_path) + video = self.transform(video) # T C H W -> T C H W + video = video.transpose(0, 1) # T C H W -> C T H W + return dict(video=video, label="") + except Exception as e: + print(f'Error with {e}, {video_path}') + return self.__getitem__(random.randint(0, self.__len__()-1)) + + def decord_read(self, path): + decord_vr = self.v_decoder(path) + total_frames = len(decord_vr) + # Sampling video frames + if self.dynamic_sample: + sample_rate = random.randint(1, self.sample_rate) + else: + sample_rate = self.sample_rate + size = self.sequence_length * sample_rate + start_frame_ind, end_frame_ind = TemporalRandomCrop(total_frames, size) + # assert end_frame_ind - start_frame_ind >= self.num_frames + frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.sequence_length, dtype=int) + + video_data = decord_vr.get_batch(frame_indice).asnumpy() + video_data = torch.from_numpy(video_data) + video_data = video_data.permute(0, 3, 1, 2) # (T, H, W, C) -> (T C H W) + return video_data \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/losses/__init__.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..91e59504e7094dc422bba4f3d12f02305ce9b30f --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/losses/__init__.py @@ -0,0 +1 @@ +from .perceptual_loss import SimpleLPIPS, LPIPSWithDiscriminator, LPIPSWithDiscriminator3D \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/losses/discriminator.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/losses/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..247a610706f471664eb5e478128d3aa973150877 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/losses/discriminator.py @@ -0,0 +1,181 @@ +import functools +import torch.nn as nn +from ..modules.normalize import ActNorm +from ..modules.conv import CausalConv3d +from einops import rearrange + +def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find('BatchNorm') != -1: + nn.init.normal_(m.weight.data, 1.0, 0.02) + nn.init.constant_(m.bias.data, 0) + +def weights_init_conv(m): + if hasattr(m, 'conv'): + m = m.conv + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find('BatchNorm') != -1: + nn.init.normal_(m.weight.data, 1.0, 0.02) + nn.init.constant_(m.bias.data, 0) + +class NLayerDiscriminator(nn.Module): + """Defines a PatchGAN discriminator as in Pix2Pix + --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py + """ + def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): + """Construct a PatchGAN discriminator + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + super(NLayerDiscriminator, self).__init__() + if not use_actnorm: + norm_layer = nn.BatchNorm2d + else: + norm_layer = ActNorm + if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func != nn.BatchNorm2d + else: + use_bias = norm_layer != nn.BatchNorm2d + + kw = 4 + padw = 1 + sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2 ** n, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + nf_mult_prev = nf_mult + nf_mult = min(2 ** n_layers, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + sequence += [ + nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map + self.main = nn.Sequential(*sequence) + + def forward(self, input): + """Standard forward.""" + return self.main(input) + +class NLayerDiscriminator3D(nn.Module): + """Defines a 3D PatchGAN discriminator as in Pix2Pix but for 3D inputs.""" + def __init__(self, input_nc=1, ndf=64, n_layers=3, use_actnorm=False): + """ + Construct a 3D PatchGAN discriminator + + Parameters: + input_nc (int) -- the number of channels in input volumes + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + use_actnorm (bool) -- flag to use actnorm instead of batchnorm + """ + super(NLayerDiscriminator3D, self).__init__() + if not use_actnorm: + norm_layer = nn.BatchNorm3d + else: + raise NotImplementedError("Not implemented.") + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func != nn.BatchNorm3d + else: + use_bias = norm_layer != nn.BatchNorm3d + + kw = 3 + padw = 1 + sequence = [nn.Conv3d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2 ** n, 8) + sequence += [ + nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=(kw, kw, kw), stride=(2 if n==1 else 1,2,2), padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + nf_mult_prev = nf_mult + nf_mult = min(2 ** n_layers, 8) + sequence += [ + nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=(kw, kw, kw), stride=1, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + sequence += [nn.Conv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map + self.main = nn.Sequential(*sequence) + + def forward(self, input): + """Standard forward.""" + return self.main(input) + + + + + +# class NLayerDiscriminator3D(nn.Module): +# """Defines a 3D PatchGAN discriminator as in Pix2Pix but for 3D inputs.""" +# def __init__(self, input_nc=1, ndf=64, n_layers=3, use_actnorm=False): +# """ +# Construct a 3D PatchGAN discriminator + +# Parameters: +# input_nc (int) -- the number of channels in input volumes +# ndf (int) -- the number of filters in the last conv layer +# n_layers (int) -- the number of conv layers in the discriminator +# use_actnorm (bool) -- flag to use actnorm instead of batchnorm +# """ +# super(NLayerDiscriminator3D, self).__init__() +# if not use_actnorm: +# norm_layer = nn.BatchNorm3d +# else: +# raise NotImplementedError("Not implemented.") +# if type(norm_layer) == functools.partial: +# use_bias = norm_layer.func != nn.BatchNorm3d +# else: +# use_bias = norm_layer != nn.BatchNorm3d + +# kw = 4 +# padw = 1 +# sequence = [CausalConv3d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] +# nf_mult = 1 +# nf_mult_prev = 1 +# for n in range(1, n_layers): # gradually increase the number of filters +# nf_mult_prev = nf_mult +# nf_mult = min(2 ** n, 8) +# sequence += [ +# CausalConv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=(kw, kw, kw), stride=(2 if n==1 else 1,2,2), padding=padw, bias=use_bias), +# norm_layer(ndf * nf_mult), +# nn.LeakyReLU(0.2, True) +# ] + +# nf_mult_prev = nf_mult +# nf_mult = min(2 ** n_layers, 8) +# sequence += [ +# CausalConv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=(kw, kw, kw), stride=1, padding=padw, bias=use_bias), +# norm_layer(ndf * nf_mult), +# nn.LeakyReLU(0.2, True) +# ] + +# sequence += [CausalConv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map +# self.main = nn.Sequential(*sequence) + +# def forward(self, input): +# """Standard forward.""" +# return self.main(input) \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/losses/lpips.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/losses/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..8b7062cdd0e9b65e6eb268a94ab3fe139e074bb3 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/losses/lpips.py @@ -0,0 +1,120 @@ +"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" + +import torch +import torch.nn as nn +from torchvision import models +from collections import namedtuple +from .....utils.taming_download import get_ckpt_path + +class LPIPS(nn.Module): + # Learned perceptual metric + def __init__(self, use_dropout=True): + super().__init__() + self.scaling_layer = ScalingLayer() + self.chns = [64, 128, 256, 512, 512] # vg16 features + self.net = vgg16(pretrained=True, requires_grad=False) + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + self.load_from_pretrained() + for param in self.parameters(): + param.requires_grad = False + + def load_from_pretrained(self, name="vgg_lpips"): + ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips") + self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) + print("loaded pretrained LPIPS loss from {}".format(ckpt)) + + @classmethod + def from_pretrained(cls, name="vgg_lpips"): + if name != "vgg_lpips": + raise NotImplementedError + model = cls() + ckpt = get_ckpt_path(name) + model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) + return model + + def forward(self, input, target): + in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + feats0, feats1, diffs = {}, {}, {} + lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + for kk in range(len(self.chns)): + feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] + val = res[0] + for l in range(1, len(self.chns)): + val += res[l] + return val + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) + self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + """ A single linear layer which does a 1x1 conv """ + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + layers = [nn.Dropout(), ] if (use_dropout) else [] + layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] + self.model = nn.Sequential(*layers) + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = models.vgg16(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) + out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + return out + + +def normalize_tensor(x,eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) + return x/(norm_factor+eps) + + +def spatial_average(x, keepdim=True): + return x.mean([2,3],keepdim=keepdim) diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/losses/perceptual_loss.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/losses/perceptual_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..e4042d0528b6c175581e79934c46f468dc95f0c4 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/losses/perceptual_loss.py @@ -0,0 +1,414 @@ +import torch +from torch import nn +import torch.nn.functional as F +from .lpips import LPIPS +from einops import rearrange +from .discriminator import NLayerDiscriminator, weights_init, NLayerDiscriminator3D + + +def hinge_d_loss(logits_real, logits_fake): + loss_real = torch.mean(F.relu(1.0 - logits_real)) + loss_fake = torch.mean(F.relu(1.0 + logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +def vanilla_d_loss(logits_real, logits_fake): + d_loss = 0.5 * ( + torch.mean(torch.nn.functional.softplus(-logits_real)) + + torch.mean(torch.nn.functional.softplus(logits_fake)) + ) + return d_loss + + +def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): + assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] + loss_real = torch.mean(F.relu(1.0 - logits_real), dim=[1, 2, 3]) + loss_fake = torch.mean(F.relu(1.0 + logits_fake), dim=[1, 2, 3]) + loss_real = (weights * loss_real).sum() / weights.sum() + loss_fake = (weights * loss_fake).sum() / weights.sum() + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +def adopt_weight(weight, global_step, threshold=0, value=0.0): + if global_step < threshold: + weight = value + return weight + + +def measure_perplexity(predicted_indices, n_embed): + # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py + # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally + encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) + avg_probs = encodings.mean(0) + perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() + cluster_use = torch.sum(avg_probs > 0) + return perplexity, cluster_use + + +def l1(x, y): + return torch.abs(x - y) + + +def l2(x, y): + return torch.pow((x - y), 2) + + +class LPIPSWithDiscriminator(nn.Module): + def __init__( + self, + disc_start, + logvar_init=0.0, + kl_weight=1.0, + pixelloss_weight=1.0, + perceptual_weight=1.0, + # --- Discriminator Loss --- + disc_num_layers=3, + disc_in_channels=3, + disc_factor=1.0, + disc_weight=1.0, + use_actnorm=False, + disc_conditional=False, + disc_loss="hinge", + ): + + super().__init__() + assert disc_loss in ["hinge", "vanilla"] + self.kl_weight = kl_weight + self.pixel_weight = pixelloss_weight + self.perceptual_loss = LPIPS().eval() + self.perceptual_weight = perceptual_weight + self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) + + self.discriminator = NLayerDiscriminator( + input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm + ).apply(weights_init) + self.discriminator_iter_start = disc_start + self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.disc_conditional = disc_conditional + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + else: + nll_grads = torch.autograd.grad( + nll_loss, self.last_layer[0], retain_graph=True + )[0] + g_grads = torch.autograd.grad( + g_loss, self.last_layer[0], retain_graph=True + )[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward( + self, + inputs, + reconstructions, + posteriors, + optimizer_idx, + global_step, + split="train", + weights=None, + last_layer=None, + cond=None, + ): + inputs = rearrange(inputs, "b c t h w -> (b t) c h w").contiguous() + reconstructions = rearrange( + reconstructions, "b c t h w -> (b t) c h w" + ).contiguous() + rec_loss = torch.abs(inputs - reconstructions) + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss(inputs, reconstructions) + rec_loss = rec_loss + self.perceptual_weight * p_loss + nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar + weighted_nll_loss = nll_loss + if weights is not None: + weighted_nll_loss = weights * nll_loss + weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] + nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + kl_loss = posteriors.kl() + kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + + # GAN Part + if optimizer_idx == 0: + # generator update + if cond is None: + assert not self.disc_conditional + logits_fake = self.discriminator(reconstructions.contiguous()) + else: + assert self.disc_conditional + logits_fake = self.discriminator( + torch.cat((reconstructions.contiguous(), cond), dim=1) + ) + g_loss = -torch.mean(logits_fake) + + if self.disc_factor > 0.0: + try: + d_weight = self.calculate_adaptive_weight( + nll_loss, g_loss, last_layer=last_layer + ) + except RuntimeError: + assert not self.training + d_weight = torch.tensor(0.0) + else: + d_weight = torch.tensor(0.0) + + disc_factor = adopt_weight( + self.disc_factor, global_step, threshold=self.discriminator_iter_start + ) + loss = ( + weighted_nll_loss + + self.kl_weight * kl_loss + + d_weight * disc_factor * g_loss + ) + log = { + "{}/total_loss".format(split): loss.clone().detach().mean(), + "{}/logvar".format(split): self.logvar.detach(), + "{}/kl_loss".format(split): kl_loss.detach().mean(), + "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/d_weight".format(split): d_weight.detach(), + "{}/disc_factor".format(split): torch.tensor(disc_factor), + "{}/g_loss".format(split): g_loss.detach().mean(), + } + return loss, log + + if optimizer_idx == 1: + if cond is None: + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + else: + logits_real = self.discriminator( + torch.cat((inputs.contiguous().detach(), cond), dim=1) + ) + logits_fake = self.discriminator( + torch.cat((reconstructions.contiguous().detach(), cond), dim=1) + ) + + disc_factor = adopt_weight( + self.disc_factor, global_step, threshold=self.discriminator_iter_start + ) + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + log = { + "{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean(), + } + return d_loss, log + + +class LPIPSWithDiscriminator3D(nn.Module): + def __init__( + self, + disc_start, + logvar_init=0.0, + kl_weight=1.0, + pixelloss_weight=1.0, + perceptual_weight=1.0, + # --- Discriminator Loss --- + disc_num_layers=3, + disc_in_channels=3, + disc_factor=1.0, + disc_weight=1.0, + use_actnorm=False, + disc_conditional=False, + disc_loss="hinge", + ): + + super().__init__() + assert disc_loss in ["hinge", "vanilla"] + self.kl_weight = kl_weight + self.pixel_weight = pixelloss_weight + self.perceptual_loss = LPIPS().eval() + self.perceptual_weight = perceptual_weight + self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) + + self.discriminator = NLayerDiscriminator3D( + input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm + ).apply(weights_init) + self.discriminator_iter_start = disc_start + self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.disc_conditional = disc_conditional + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + else: + nll_grads = torch.autograd.grad( + nll_loss, self.last_layer[0], retain_graph=True + )[0] + g_grads = torch.autograd.grad( + g_loss, self.last_layer[0], retain_graph=True + )[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward( + self, + inputs, + reconstructions, + posteriors, + optimizer_idx, + global_step, + split="train", + weights=None, + last_layer=None, + cond=None, + ): + t = inputs.shape[2] + inputs = rearrange(inputs, "b c t h w -> (b t) c h w").contiguous() + reconstructions = rearrange( + reconstructions, "b c t h w -> (b t) c h w" + ).contiguous() + rec_loss = torch.abs(inputs - reconstructions) + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss(inputs, reconstructions) + rec_loss = rec_loss + self.perceptual_weight * p_loss + nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar + weighted_nll_loss = nll_loss + if weights is not None: + weighted_nll_loss = weights * nll_loss + weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] + nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + kl_loss = posteriors.kl() + kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + inputs = rearrange(inputs, "(b t) c h w -> b c t h w", t=t).contiguous() + reconstructions = rearrange( + reconstructions, "(b t) c h w -> b c t h w", t=t + ).contiguous() + # GAN Part + if optimizer_idx == 0: + # generator update + if cond is None: + assert not self.disc_conditional + logits_fake = self.discriminator(reconstructions) + else: + assert self.disc_conditional + logits_fake = self.discriminator( + torch.cat((reconstructions, cond), dim=1) + ) + g_loss = -torch.mean(logits_fake) + + if self.disc_factor > 0.0: + try: + d_weight = self.calculate_adaptive_weight( + nll_loss, g_loss, last_layer=last_layer + ) + except RuntimeError as e: + assert not self.training, print(e) + d_weight = torch.tensor(0.0) + else: + d_weight = torch.tensor(0.0) + + disc_factor = adopt_weight( + self.disc_factor, global_step, threshold=self.discriminator_iter_start + ) + loss = ( + weighted_nll_loss + + self.kl_weight * kl_loss + + d_weight * disc_factor * g_loss + ) + log = { + "{}/total_loss".format(split): loss.clone().detach().mean(), + "{}/logvar".format(split): self.logvar.detach(), + "{}/kl_loss".format(split): kl_loss.detach().mean(), + "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/d_weight".format(split): d_weight.detach(), + "{}/disc_factor".format(split): torch.tensor(disc_factor), + "{}/g_loss".format(split): g_loss.detach().mean(), + } + return loss, log + + if optimizer_idx == 1: + if cond is None: + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + else: + logits_real = self.discriminator( + torch.cat((inputs.contiguous().detach(), cond), dim=1) + ) + logits_fake = self.discriminator( + torch.cat((reconstructions.contiguous().detach(), cond), dim=1) + ) + + disc_factor = adopt_weight( + self.disc_factor, global_step, threshold=self.discriminator_iter_start + ) + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + log = { + "{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean(), + } + return d_loss, log + + +class SimpleLPIPS(nn.Module): + def __init__( + self, + logvar_init=0.0, + kl_weight=1.0, + pixelloss_weight=1.0, + perceptual_weight=1.0, + disc_loss="hinge", + ): + + super().__init__() + assert disc_loss in ["hinge", "vanilla"] + self.kl_weight = kl_weight + self.pixel_weight = pixelloss_weight + self.perceptual_loss = LPIPS().eval() + self.perceptual_weight = perceptual_weight + self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) + + def forward( + self, + inputs, + reconstructions, + posteriors, + split="train", + weights=None, + ): + inputs = rearrange(inputs, "b c t h w -> (b t) c h w").contiguous() + reconstructions = rearrange( + reconstructions, "b c t h w -> (b t) c h w" + ).contiguous() + rec_loss = torch.abs(inputs - reconstructions) + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss(inputs, reconstructions) + rec_loss = rec_loss + self.perceptual_weight * p_loss + nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar + weighted_nll_loss = nll_loss + if weights is not None: + weighted_nll_loss = weights * nll_loss + weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] + nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + kl_loss = posteriors.kl() + kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + loss = weighted_nll_loss + self.kl_weight * kl_loss + log = { + "{}/total_loss".format(split): loss.clone().detach().mean(), + "{}/logvar".format(split): self.logvar.detach(), + "{}/kl_loss".format(split): kl_loss.detach().mean(), + "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + } + if self.perceptual_weight > 0: + log.update({"{}/p_loss".format(split): p_loss.detach().mean()}) + return loss, log diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/modeling_videobase.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/modeling_videobase.py new file mode 100644 index 0000000000000000000000000000000000000000..0b2274ec52155a5079d93188b1493138ad554d37 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/modeling_videobase.py @@ -0,0 +1,80 @@ +import torch +from diffusers import ModelMixin, ConfigMixin +from torch import nn +import os +import json +import pytorch_lightning as pl +from diffusers.configuration_utils import ConfigMixin +from diffusers.models.modeling_utils import ModelMixin +from typing import Optional, Union +import glob + +class VideoBaseAE(nn.Module): + _supports_gradient_checkpointing = False + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + @classmethod + def load_from_checkpoint(cls, model_path): + with open(os.path.join(model_path, "config.json"), "r") as file: + config = json.load(file) + state_dict = torch.load(os.path.join(model_path, "pytorch_model.bin"), map_location="cpu") + if 'state_dict' in state_dict: + state_dict = state_dict['state_dict'] + model = cls(config=cls.CONFIGURATION_CLS(**config)) + model.load_state_dict(state_dict) + return model + + @classmethod + def download_and_load_model(cls, model_name, cache_dir=None): + pass + + def encode(self, x: torch.Tensor, *args, **kwargs): + pass + + def decode(self, encoding: torch.Tensor, *args, **kwargs): + pass + +class VideoBaseAE_PL(pl.LightningModule, ModelMixin, ConfigMixin): + config_name = "config.json" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def encode(self, x: torch.Tensor, *args, **kwargs): + pass + + def decode(self, encoding: torch.Tensor, *args, **kwargs): + pass + + @property + def num_training_steps(self) -> int: + """Total training steps inferred from datamodule and devices.""" + if self.trainer.max_steps: + return self.trainer.max_steps + + limit_batches = self.trainer.limit_train_batches + batches = len(self.train_dataloader()) + batches = min(batches, limit_batches) if isinstance(limit_batches, int) else int(limit_batches * batches) + + num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes) + if self.trainer.tpu_cores: + num_devices = max(num_devices, self.trainer.tpu_cores) + + effective_accum = self.trainer.accumulate_grad_batches * num_devices + return (batches // effective_accum) * self.trainer.max_epochs + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + ckpt_files = glob.glob(os.path.join(pretrained_model_name_or_path, '*.ckpt')) + if ckpt_files: + # Adapt to PyTorch Lightning + last_ckpt_file = ckpt_files[-1] + config_file = os.path.join(pretrained_model_name_or_path, cls.config_name) + model = cls.from_config(config_file) + print("init from {}".format(last_ckpt_file)) + model.init_from_ckpt(last_ckpt_file) + return model + else: + return super().from_pretrained(pretrained_model_name_or_path, **kwargs) \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/modules/__init__.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..61ca11dec7fa1467d7b0e8a4304662d0b5d977d7 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/modules/__init__.py @@ -0,0 +1,24 @@ +from .block import Block +from .attention import ( + AttnBlock3D, + AttnBlock3DFix, + AttnBlock, + LinAttnBlock, + LinearAttention, + TemporalAttnBlock +) +from .conv import CausalConv3d, Conv2d +from .normalize import GroupNorm, Normalize +from .resnet_block import ResnetBlock2D, ResnetBlock3D +from .updownsample import ( + SpatialDownsample2x, + SpatialUpsample2x, + TimeDownsample2x, + TimeUpsample2x, + Upsample, + Downsample, + TimeDownsampleRes2x, + TimeUpsampleRes2x, + TimeDownsampleResAdv2x, + TimeUpsampleResAdv2x +) diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/modules/attention.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/modules/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..c3b9bb19b83902541638409d5f817362dffc9c0f --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/modules/attention.py @@ -0,0 +1,230 @@ +import torch.nn as nn +from .normalize import Normalize +from .conv import CausalConv3d +import torch +import numpy as np +from einops import rearrange +from .block import Block +from .ops import video_to_image + +class LinearAttention(Block): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange( + qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 + ) + k = k.softmax(dim=-1) + context = torch.einsum("bhdn,bhen->bhde", k, v) + out = torch.einsum("bhde,bhdn->bhen", context, q) + out = rearrange( + out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w + ) + return self.to_out(out) + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock3D(Block): + """Compatible with old versions, there are issues, use with caution.""" + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) + self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) + self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) + self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, t, h, w = q.shape + q = q.reshape(b * t, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b * t, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b * t, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, t, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + +class AttnBlock3DFix(nn.Module): + """ + Thanks to https://github.com/PKU-YuanGroup/Open-Sora-Plan/pull/172. + """ + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) + self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) + self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) + self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + # q: (b c t h w) -> (b t c h w) -> (b*t c h*w) -> (b*t h*w c) + b, c, t, h, w = q.shape + q = q.permute(0, 2, 1, 3, 4) + q = q.reshape(b * t, c, h * w) + q = q.permute(0, 2, 1) + + # k: (b c t h w) -> (b t c h w) -> (b*t c h*w) + k = k.permute(0, 2, 1, 3, 4) + k = k.reshape(b * t, c, h * w) + + # w: (b*t hw hw) + w_ = torch.bmm(q, k) + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + # v: (b c t h w) -> (b t c h w) -> (bt c hw) + # w_: (bt hw hw) -> (bt hw hw) + v = v.permute(0, 2, 1, 3, 4) + v = v.reshape(b * t, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + + # h_: (b*t c hw) -> (b t c h w) -> (b c t h w) + h_ = h_.reshape(b, t, c, h, w) + h_ = h_.permute(0, 2, 1, 3 ,4) + + h_ = self.proj_out(h_) + + return x + h_ + + +class AttnBlock(Block): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + @video_to_image + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + + +class TemporalAttnBlock(Block): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv3d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv3d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv3d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, t, h, w = q.shape + q = rearrange(q, "b c t h w -> (b h w) t c") + k = rearrange(k, "b c t h w -> (b h w) c t") + v = rearrange(v, "b c t h w -> (b h w) c t") + w_ = torch.bmm(q, k) + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + w_ = w_.permute(0, 2, 1) + h_ = torch.bmm(v, w_) + h_ = rearrange(h_, "(b h w) c t -> b c t h w", h=h, w=w) + h_ = self.proj_out(h_) + + return x + h_ + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none", "vanilla3D"], f"attn_type {attn_type} unknown" + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + print(attn_type) + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "vanilla3D": + return AttnBlock3D(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/modules/block.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/modules/block.py new file mode 100644 index 0000000000000000000000000000000000000000..e93672d20b22cbaaa36b379473149c55f0f44001 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/modules/block.py @@ -0,0 +1,5 @@ +import torch.nn as nn + +class Block(nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/modules/conv.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/modules/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..5a4c8ae279095cf8f450cdaab32ba4d5f1774c09 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/modules/conv.py @@ -0,0 +1,98 @@ +import torch.nn as nn +from typing import Union, Tuple +import torch.nn.functional as F +import torch +from .block import Block +from .ops import cast_tuple +from einops import rearrange +from .ops import video_to_image + +class Conv2d(nn.Conv2d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int]] = 3, + stride: Union[int, Tuple[int]] = 1, + padding: Union[str, int, Tuple[int]] = 0, + dilation: Union[int, Tuple[int]] = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode, + device, + dtype, + ) + + @video_to_image + def forward(self, x): + return super().forward(x) + + +class CausalConv3d(nn.Module): + def __init__( + self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], init_method="random", **kwargs + ): + super().__init__() + self.kernel_size = cast_tuple(kernel_size, 3) + self.time_kernel_size = self.kernel_size[0] + self.chan_in = chan_in + self.chan_out = chan_out + stride = kwargs.pop("stride", 1) + padding = kwargs.pop("padding", 0) + padding = list(cast_tuple(padding, 3)) + padding[0] = 0 + stride = cast_tuple(stride, 3) + self.conv = nn.Conv3d(chan_in, chan_out, self.kernel_size, stride=stride, padding=padding) + self._init_weights(init_method) + + def _init_weights(self, init_method): + ks = torch.tensor(self.kernel_size) + if init_method == "avg": + assert ( + self.kernel_size[1] == 1 and self.kernel_size[2] == 1 + ), "only support temporal up/down sample" + assert self.chan_in == self.chan_out, "chan_in must be equal to chan_out" + weight = torch.zeros((self.chan_out, self.chan_in, *self.kernel_size)) + + eyes = torch.concat( + [ + torch.eye(self.chan_in).unsqueeze(-1) * 1/3, + torch.eye(self.chan_in).unsqueeze(-1) * 1/3, + torch.eye(self.chan_in).unsqueeze(-1) * 1/3, + ], + dim=-1, + ) + weight[:, :, :, 0, 0] = eyes + + self.conv.weight = nn.Parameter( + weight, + requires_grad=True, + ) + elif init_method == "zero": + self.conv.weight = nn.Parameter( + torch.zeros((self.chan_out, self.chan_in, *self.kernel_size)), + requires_grad=True, + ) + if self.conv.bias is not None: + nn.init.constant_(self.conv.bias, 0) + + def forward(self, x): + # 1 + 16 16 as video, 1 as image + first_frame_pad = x[:, :, :1, :, :].repeat( + (1, 1, self.time_kernel_size - 1, 1, 1) + ) # b c t h w + x = torch.concatenate((first_frame_pad, x), dim=2) # 3 + 16 + return self.conv(x) \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/modules/normalize.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/modules/normalize.py new file mode 100644 index 0000000000000000000000000000000000000000..7c8c05f0fa3459214dd077e09a157c588adb637b --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/modules/normalize.py @@ -0,0 +1,101 @@ +import torch +import torch.nn as nn +from .block import Block + +class GroupNorm(Block): + def __init__(self, num_channels, num_groups=32, eps=1e-6, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.norm = torch.nn.GroupNorm( + num_groups=num_groups, num_channels=num_channels, eps=1e-6, affine=True + ) + def forward(self, x): + return self.norm(x) + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm( + num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) + +class ActNorm(nn.Module): + def __init__(self, num_features, logdet=False, affine=True, + allow_reverse_init=False): + assert affine + super().__init__() + self.logdet = logdet + self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) + self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) + self.allow_reverse_init = allow_reverse_init + + self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) + + def initialize(self, input): + with torch.no_grad(): + flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) + mean = ( + flatten.mean(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + ) + std = ( + flatten.std(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + ) + + self.loc.data.copy_(-mean) + self.scale.data.copy_(1 / (std + 1e-6)) + + def forward(self, input, reverse=False): + if reverse: + return self.reverse(input) + if len(input.shape) == 2: + input = input[:,:,None,None] + squeeze = True + else: + squeeze = False + + _, _, height, width = input.shape + + if self.training and self.initialized.item() == 0: + self.initialize(input) + self.initialized.fill_(1) + + h = self.scale * (input + self.loc) + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + + if self.logdet: + log_abs = torch.log(torch.abs(self.scale)) + logdet = height*width*torch.sum(log_abs) + logdet = logdet * torch.ones(input.shape[0]).to(input) + return h, logdet + + return h + + def reverse(self, output): + if self.training and self.initialized.item() == 0: + if not self.allow_reverse_init: + raise RuntimeError( + "Initializing ActNorm in reverse direction is " + "disabled by default. Use allow_reverse_init=True to enable." + ) + else: + self.initialize(output) + self.initialized.fill_(1) + + if len(output.shape) == 2: + output = output[:,:,None,None] + squeeze = True + else: + squeeze = False + + h = output / self.scale - self.loc + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + return h diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/modules/ops.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/modules/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..fdd262ad7145eb343838561f119767534761d4d3 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/modules/ops.py @@ -0,0 +1,40 @@ +import torch +from einops import rearrange + +def video_to_image(func): + def wrapper(self, x, *args, **kwargs): + if x.dim() == 5: + t = x.shape[2] + x = rearrange(x, "b c t h w -> (b t) c h w") + x = func(self, x, *args, **kwargs) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + return x + return wrapper + +def nonlinearity(x): + return x * torch.sigmoid(x) + +def cast_tuple(t, length=1): + return t if isinstance(t, tuple) else ((t,) * length) + +def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True): + n_dims = len(x.shape) + if src_dim < 0: + src_dim = n_dims + src_dim + if dest_dim < 0: + dest_dim = n_dims + dest_dim + assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims + dims = list(range(n_dims)) + del dims[src_dim] + permutation = [] + ctr = 0 + for i in range(n_dims): + if i == dest_dim: + permutation.append(src_dim) + else: + permutation.append(dims[ctr]) + ctr += 1 + x = x.permute(permutation) + if make_contiguous: + x = x.contiguous() + return x \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/modules/quant.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/modules/quant.py new file mode 100644 index 0000000000000000000000000000000000000000..bb702cee5f594e73be3efe9f573f1fbb8032e70c --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/modules/quant.py @@ -0,0 +1,100 @@ +import torch +import torch.nn as nn +import torch.distributed as dist +import numpy as np +import torch.nn.functional as F +from .ops import shift_dim + +class Codebook(nn.Module): + def __init__(self, n_codes, embedding_dim): + super().__init__() + self.register_buffer("embeddings", torch.randn(n_codes, embedding_dim)) + self.register_buffer("N", torch.zeros(n_codes)) + self.register_buffer("z_avg", self.embeddings.data.clone()) + + self.n_codes = n_codes + self.embedding_dim = embedding_dim + self._need_init = True + + def _tile(self, x): + d, ew = x.shape + if d < self.n_codes: + n_repeats = (self.n_codes + d - 1) // d + std = 0.01 / np.sqrt(ew) + x = x.repeat(n_repeats, 1) + x = x + torch.randn_like(x) * std + return x + + def _init_embeddings(self, z): + # z: [b, c, t, h, w] + self._need_init = False + flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) + y = self._tile(flat_inputs) + + d = y.shape[0] + _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes] + if dist.is_initialized(): + dist.broadcast(_k_rand, 0) + self.embeddings.data.copy_(_k_rand) + self.z_avg.data.copy_(_k_rand) + self.N.data.copy_(torch.ones(self.n_codes)) + + def forward(self, z): + # z: [b, c, t, h, w] + if self._need_init and self.training: + self._init_embeddings(z) + flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) + distances = ( + (flat_inputs**2).sum(dim=1, keepdim=True) + - 2 * flat_inputs @ self.embeddings.t() + + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True) + ) + + encoding_indices = torch.argmin(distances, dim=1) + encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as(flat_inputs) + encoding_indices = encoding_indices.view(z.shape[0], *z.shape[2:]) + + embeddings = F.embedding(encoding_indices, self.embeddings) + embeddings = shift_dim(embeddings, -1, 1) + + commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach()) + + # EMA codebook update + if self.training: + n_total = encode_onehot.sum(dim=0) + encode_sum = flat_inputs.t() @ encode_onehot + if dist.is_initialized(): + dist.all_reduce(n_total) + dist.all_reduce(encode_sum) + + self.N.data.mul_(0.99).add_(n_total, alpha=0.01) + self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01) + + n = self.N.sum() + weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n + encode_normalized = self.z_avg / weights.unsqueeze(1) + self.embeddings.data.copy_(encode_normalized) + + y = self._tile(flat_inputs) + _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes] + if dist.is_initialized(): + dist.broadcast(_k_rand, 0) + + usage = (self.N.view(self.n_codes, 1) >= 1).float() + self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage)) + + embeddings_st = (embeddings - z).detach() + z + + avg_probs = torch.mean(encode_onehot, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + return dict( + embeddings=embeddings_st, + encodings=encoding_indices, + commitment_loss=commitment_loss, + perplexity=perplexity, + ) + + def dictionary_lookup(self, encodings): + embeddings = F.embedding(encodings, self.embeddings) + return embeddings \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/modules/resnet_block.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/modules/resnet_block.py new file mode 100644 index 0000000000000000000000000000000000000000..189766a5bfc9c1943dab284ef808496710b24994 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/modules/resnet_block.py @@ -0,0 +1,86 @@ +import torch +import torch.nn as nn +from einops import rearrange, pack, unpack +from .normalize import Normalize +from .ops import nonlinearity, video_to_image +from .conv import CausalConv3d +from .block import Block + +class ResnetBlock2D(Block): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + else: + self.nin_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + + @video_to_image + def forward(self, x): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + x = x + h + return x + +class ResnetBlock3D(Block): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = CausalConv3d(in_channels, out_channels, 3, padding=1) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = CausalConv3d(out_channels, out_channels, 3, padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = CausalConv3d(in_channels, out_channels, 3, padding=1) + else: + self.nin_shortcut = CausalConv3d(in_channels, out_channels, 1, padding=0) + + def forward(self, x): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + return x + h \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/modules/updownsample.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/modules/updownsample.py new file mode 100644 index 0000000000000000000000000000000000000000..9e3d489aeee6a02eb4d81b47125e70b8d954c003 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/modules/updownsample.py @@ -0,0 +1,236 @@ +from typing import Union, Tuple +import torch +import torch.nn as nn +import torch.nn.functional as F +from .resnet_block import ResnetBlock3D +from .attention import TemporalAttnBlock +from .normalize import Normalize +from .ops import cast_tuple, video_to_image +from .conv import CausalConv3d +from einops import rearrange +from .block import Block + +class Upsample(Block): + def __init__(self, in_channels, out_channels): + super().__init__() + self.with_conv = True + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + @video_to_image + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + +class Downsample(Block): + def __init__(self, in_channels, out_channels): + super().__init__() + self.with_conv = True + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=2, + padding=0) + @video_to_image + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + +class SpatialDownsample2x(Block): + def __init__( + self, + chan_in, + chan_out, + kernel_size: Union[int, Tuple[int]] = (3, 3), + stride: Union[int, Tuple[int]] = (2, 2), + ): + super().__init__() + kernel_size = cast_tuple(kernel_size, 2) + stride = cast_tuple(stride, 2) + self.chan_in = chan_in + self.chan_out = chan_out + self.kernel_size = kernel_size + self.conv = CausalConv3d( + self.chan_in, + self.chan_out, + (1,) + self.kernel_size, + stride=(1, ) + stride, + padding=0 + ) + + def forward(self, x): + pad = (0,1,0,1,0,0) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class SpatialUpsample2x(Block): + def __init__( + self, + chan_in, + chan_out, + kernel_size: Union[int, Tuple[int]] = (3, 3), + stride: Union[int, Tuple[int]] = (1, 1), + ): + super().__init__() + self.chan_in = chan_in + self.chan_out = chan_out + self.kernel_size = kernel_size + self.conv = CausalConv3d( + self.chan_in, + self.chan_out, + (1,) + self.kernel_size, + stride=(1, ) + stride, + padding=1 + ) + + def forward(self, x): + t = x.shape[2] + x = rearrange(x, "b c t h w -> b (c t) h w") + x = F.interpolate(x, scale_factor=(2,2), mode="nearest") + x = rearrange(x, "b (c t) h w -> b c t h w", t=t) + x = self.conv(x) + return x + +class TimeDownsample2x(Block): + def __init__( + self, + chan_in, + chan_out, + kernel_size: int = 3 + ): + super().__init__() + self.kernel_size = kernel_size + self.conv = nn.AvgPool3d((kernel_size,1,1), stride=(2,1,1)) + + def forward(self, x): + first_frame_pad = x[:, :, :1, :, :].repeat( + (1, 1, self.kernel_size - 1, 1, 1) + ) + x = torch.concatenate((first_frame_pad, x), dim=2) + return self.conv(x) + +class TimeUpsample2x(Block): + def __init__( + self, + chan_in, + chan_out + ): + super().__init__() + def forward(self, x): + if x.size(2) > 1: + x,x_= x[:,:,:1],x[:,:,1:] + x_= F.interpolate(x_, scale_factor=(2,1,1), mode='trilinear') + x = torch.concat([x, x_], dim=2) + return x + +class TimeDownsampleRes2x(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size: int = 3, + mix_factor: float = 2.0, + ): + super().__init__() + self.kernel_size = cast_tuple(kernel_size, 3) + self.avg_pool = nn.AvgPool3d((kernel_size,1,1), stride=(2,1,1)) + self.conv = nn.Conv3d( + in_channels, out_channels, self.kernel_size, stride=(2,1,1), padding=(0,1,1) + ) + self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor])) + + def forward(self, x): + alpha = torch.sigmoid(self.mix_factor) + first_frame_pad = x[:, :, :1, :, :].repeat( + (1, 1, self.kernel_size[0] - 1, 1, 1) + ) + x = torch.concatenate((first_frame_pad, x), dim=2) + return alpha * self.avg_pool(x) + (1 - alpha) * self.conv(x) + +class TimeUpsampleRes2x(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size: int = 3, + mix_factor: float = 2.0, + ): + super().__init__() + self.conv = CausalConv3d( + in_channels, out_channels, kernel_size, padding=1 + ) + self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor])) + + def forward(self, x): + alpha = torch.sigmoid(self.mix_factor) + if x.size(2) > 1: + x,x_= x[:,:,:1],x[:,:,1:] + x_= F.interpolate(x_, scale_factor=(2,1,1), mode='trilinear') + x = torch.concat([x, x_], dim=2) + return alpha * x + (1-alpha) * self.conv(x) + +class TimeDownsampleResAdv2x(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size: int = 3, + mix_factor: float = 1.5, + ): + super().__init__() + self.kernel_size = cast_tuple(kernel_size, 3) + self.avg_pool = nn.AvgPool3d((kernel_size,1,1), stride=(2,1,1)) + self.attn = TemporalAttnBlock(in_channels) + self.res = ResnetBlock3D(in_channels=in_channels, out_channels=in_channels, dropout=0.0) + self.conv = nn.Conv3d( + in_channels, out_channels, self.kernel_size, stride=(2,1,1), padding=(0,1,1) + ) + self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor])) + + def forward(self, x): + first_frame_pad = x[:, :, :1, :, :].repeat( + (1, 1, self.kernel_size[0] - 1, 1, 1) + ) + x = torch.concatenate((first_frame_pad, x), dim=2) + alpha = torch.sigmoid(self.mix_factor) + return alpha * self.avg_pool(x) + (1 - alpha) * self.conv(self.attn((self.res(x)))) + +class TimeUpsampleResAdv2x(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size: int = 3, + mix_factor: float = 1.5, + ): + super().__init__() + self.res = ResnetBlock3D(in_channels=in_channels, out_channels=in_channels, dropout=0.0) + self.attn = TemporalAttnBlock(in_channels) + self.norm = Normalize(in_channels=in_channels) + self.conv = CausalConv3d( + in_channels, out_channels, kernel_size, padding=1 + ) + self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor])) + + def forward(self, x): + if x.size(2) > 1: + x,x_= x[:,:,:1],x[:,:,1:] + x_= F.interpolate(x_, scale_factor=(2,1,1), mode='trilinear') + x = torch.concat([x, x_], dim=2) + alpha = torch.sigmoid(self.mix_factor) + return alpha * x + (1 - alpha) * self.conv(self.attn(self.res(x))) diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/trainer_videobase.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/trainer_videobase.py new file mode 100644 index 0000000000000000000000000000000000000000..b6c4b0c7d2a2fdec0488cfeedbfa9844416e4053 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/trainer_videobase.py @@ -0,0 +1,26 @@ +from transformers import Trainer +import torch.nn.functional as F +from typing import Optional +import os +import torch +from transformers.utils import WEIGHTS_NAME +import json + +class VideoBaseTrainer(Trainer): + + def _save(self, output_dir: Optional[str] = None, state_dict=None): + output_dir = output_dir if output_dir is not None else self.args.output_dir + os.makedirs(output_dir, exist_ok=True) + if state_dict is None: + state_dict = self.model.state_dict() + + # get model config + model_config = self.model.config.to_dict() + + # add more information + model_config['model'] = self.model.__class__.__name__ + + with open(os.path.join(output_dir, "config.json"), "w") as file: + json.dump(self.model.config.to_dict(), file) + torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) + torch.save(self.args, os.path.join(output_dir, "training_args.bin")) diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/utils/__init__.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/utils/distrib_utils.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/utils/distrib_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..760c0673fe5d8afa663eb1ea5cd7683dbf5dd9f8 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/utils/distrib_utils.py @@ -0,0 +1,42 @@ +import torch +import numpy as np + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1,2,3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/utils/module_utils.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/utils/module_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..07ce175f183895adbfcb4da34d405d1d308482c9 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/utils/module_utils.py @@ -0,0 +1,17 @@ +import importlib + +Module = str +MODULES_BASE = "videogen_hub.pipelines.opensora_plan.opensora.models.ae.videobase.modules." + +def resolve_str_to_obj(str_val, append=True): + if append: + str_val = MODULES_BASE + str_val + module_name, class_name = str_val.rsplit('.', 1) + module = importlib.import_module(module_name) + return getattr(module, class_name) + +def create_instance(module_class_str: str, **kwargs): + module_name, class_name = module_class_str.rsplit('.', 1) + module = importlib.import_module(module_name) + class_ = getattr(module, class_name) + return class_(**kwargs) \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/utils/scheduler_utils.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/utils/scheduler_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0650a3f79c3bc4f1bd5bf4a556995dc538b2b23a --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/utils/scheduler_utils.py @@ -0,0 +1,7 @@ +import torch + +def cosine_scheduler(step, max_steps, value_base=1, value_end=0): + step = torch.tensor(step) + cosine_value = 0.5 * (1 + torch.cos(torch.pi * step / max_steps)) + value = value_end + (value_base - value_end) * cosine_value + return value \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/utils/video_utils.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/utils/video_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a038fcca7526e6eab3f36f3f68d071f97e0357ec --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/utils/video_utils.py @@ -0,0 +1,10 @@ +import torch +import numpy as np + +def tensor_to_video(x): + x = x.detach().cpu() + x = torch.clamp(x, -1, 1) + x = (x + 1) / 2 + x = x.permute(1, 0, 2, 3).float().numpy() # c t h w -> + x = (255 * x).astype(np.uint8) + return x \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/vqvae/__init__.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/vqvae/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ec138ecc722fd5a8a540f363ed383dcb10e93695 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/vqvae/__init__.py @@ -0,0 +1,30 @@ +from einops import rearrange +from torch import nn + +from .configuration_vqvae import VQVAEConfiguration +from .modeling_vqvae import VQVAEModel +from .trainer_vqvae import VQVAETrainer + +videovqvae = [ + "bair_stride4x2x2", + "ucf101_stride4x4x4", + "kinetics_stride4x4x4", + "kinetics_stride2x4x4", +] +videovae = [] + +class VQVAEModelWrapper(nn.Module): + def __init__(self, ckpt='kinetics_stride4x4x4'): + super(VQVAEModelWrapper, self).__init__() + if ckpt in videovqvae: + self.vqvae = VQVAEModel.download_and_load_model(ckpt) + else: + self.vqvae = VQVAEModel.load_from_checkpoint(ckpt) + def encode(self, x): # b c t h w + x = self.vqvae.pre_vq_conv(self.vqvae.encoder(x)) + return x + def decode(self, x): + vq_output = self.vqvae.codebook(x) + x = self.vqvae.decoder(self.vqvae.post_vq_conv(vq_output['embeddings'])) + x = rearrange(x, 'b c t h w -> b t c h w').contiguous() + return x diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/vqvae/configuration_vqvae.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/vqvae/configuration_vqvae.py new file mode 100644 index 0000000000000000000000000000000000000000..90ac29cfa5f3c899b3c78b11868cfb30e4908812 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/vqvae/configuration_vqvae.py @@ -0,0 +1,33 @@ +from ..configuration_videobase import VideoBaseConfiguration +from typing import Union, Tuple + +class VQVAEConfiguration(VideoBaseConfiguration): + def __init__( + self, + embedding_dim: int = 256, + n_codes: int = 2048, + n_hiddens: int = 240, + n_res_layers: int = 4, + resolution: int = 128, + sequence_length: int = 16, + downsample: Union[Tuple[int, int, int], str] = (4, 4, 4), + no_pos_embd: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + + self.embedding_dim = embedding_dim + self.n_codes = n_codes + self.n_hiddens = n_hiddens + self.n_res_layers = n_res_layers + self.resolution = resolution + self.sequence_length = sequence_length + + if isinstance(downsample, str): + self.downsample = tuple(map(int, downsample.split(","))) + else: + self.downsample = downsample + + self.no_pos_embd = no_pos_embd + + self.hidden_size = n_hiddens diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/vqvae/modeling_vqvae.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/vqvae/modeling_vqvae.py new file mode 100644 index 0000000000000000000000000000000000000000..6a51677e6c947bc52ee991d6d7b0a62813583769 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/vqvae/modeling_vqvae.py @@ -0,0 +1,775 @@ +from ..modeling_videobase import VideoBaseAE +import torch +from torch import nn, Tensor +import numpy as np +import torch.distributed as dist +import torch.nn.functional as F +import math +import os +import json +from typing import Tuple, Dict, Union +from .configuration_vqvae import VQVAEConfiguration + + +# Copied from https://github.com/wilson1yan/VideoGPT +def view_range(x, i, j, shape): + shape = tuple(shape) + + n_dims = len(x.shape) + if i < 0: + i = n_dims + i + + if j is None: + j = n_dims + elif j < 0: + j = n_dims + j + + assert 0 <= i < j <= n_dims + + x_shape = x.shape + target_shape = x_shape[:i] + shape + x_shape[j:] + return x.view(target_shape) + + +# Copied from https://github.com/wilson1yan/VideoGPT +def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True): + n_dims = len(x.shape) + if src_dim < 0: + src_dim = n_dims + src_dim + if dest_dim < 0: + dest_dim = n_dims + dest_dim + assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims + dims = list(range(n_dims)) + del dims[src_dim] + permutation = [] + ctr = 0 + for i in range(n_dims): + if i == dest_dim: + permutation.append(src_dim) + else: + permutation.append(dims[ctr]) + ctr += 1 + x = x.permute(permutation) + if make_contiguous: + x = x.contiguous() + return x + + +# Copied from https://github.com/wilson1yan/VideoGPT +def scaled_dot_product_attention(q, k, v, mask=None, attn_dropout=0.0, training=True): + # Performs scaled dot-product attention over the second to last dimension dn + + # (b, n_head, d1, ..., dn, d) + attn = torch.matmul(q, k.transpose(-1, -2)) + attn = attn / np.sqrt(q.shape[-1]) + if mask is not None: + attn = attn.masked_fill(mask == 0, float("-inf")) + attn_float = F.softmax(attn, dim=-1) + attn = attn_float.type_as(attn) # b x n_head x d1 x ... x dn x d + attn = F.dropout(attn, p=attn_dropout, training=training) + + a = torch.matmul(attn, v) # b x n_head x d1 x ... x dn x d + + return a + + +# Copied from https://github.com/wilson1yan/VideoGPT +class AxialBlock(nn.Module): + def __init__(self, n_hiddens, n_head): + super().__init__() + kwargs = dict( + shape=(0,) * 3, + dim_q=n_hiddens, + dim_kv=n_hiddens, + n_head=n_head, + n_layer=1, + causal=False, + attn_type="axial", + ) + self.attn_w = MultiHeadAttention(attn_kwargs=dict(axial_dim=-2), **kwargs) + self.attn_h = MultiHeadAttention(attn_kwargs=dict(axial_dim=-3), **kwargs) + self.attn_t = MultiHeadAttention(attn_kwargs=dict(axial_dim=-4), **kwargs) + + def forward(self, x): + x = shift_dim(x, 1, -1) + x = self.attn_w(x, x, x) + self.attn_h(x, x, x) + self.attn_t(x, x, x) + x = shift_dim(x, -1, 1) + return x + + +# Copied from https://github.com/wilson1yan/VideoGPT +class AttentionResidualBlock(nn.Module): + def __init__(self, n_hiddens): + super().__init__() + self.block = nn.Sequential( + nn.BatchNorm3d(n_hiddens), + nn.ReLU(), + SamePadConv3d(n_hiddens, n_hiddens // 2, 3, bias=False), + nn.BatchNorm3d(n_hiddens // 2), + nn.ReLU(), + SamePadConv3d(n_hiddens // 2, n_hiddens, 1, bias=False), + nn.BatchNorm3d(n_hiddens), + nn.ReLU(), + AxialBlock(n_hiddens, 2), + ) + + def forward(self, x): + return x + self.block(x) + + +# Copied from https://github.com/wilson1yan/VideoGPT +class Codebook(nn.Module): + def __init__(self, n_codes, embedding_dim): + super().__init__() + self.register_buffer("embeddings", torch.randn(n_codes, embedding_dim)) + self.register_buffer("N", torch.zeros(n_codes)) + self.register_buffer("z_avg", self.embeddings.data.clone()) + + self.n_codes = n_codes + self.embedding_dim = embedding_dim + self._need_init = True + + def _tile(self, x): + d, ew = x.shape + if d < self.n_codes: + n_repeats = (self.n_codes + d - 1) // d + std = 0.01 / np.sqrt(ew) + x = x.repeat(n_repeats, 1) + x = x + torch.randn_like(x) * std + return x + + def _init_embeddings(self, z): + # z: [b, c, t, h, w] + self._need_init = False + flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) + y = self._tile(flat_inputs) + + d = y.shape[0] + _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes] + if dist.is_initialized(): + dist.broadcast(_k_rand, 0) + self.embeddings.data.copy_(_k_rand) + self.z_avg.data.copy_(_k_rand) + self.N.data.copy_(torch.ones(self.n_codes)) + + def forward(self, z): + # z: [b, c, t, h, w] + if self._need_init and self.training: + self._init_embeddings(z) + flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) + distances = ( + (flat_inputs**2).sum(dim=1, keepdim=True) + - 2 * flat_inputs @ self.embeddings.t() + + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True) + ) + + encoding_indices = torch.argmin(distances, dim=1) + encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as(flat_inputs) + encoding_indices = encoding_indices.view(z.shape[0], *z.shape[2:]) + + embeddings = F.embedding(encoding_indices, self.embeddings) + embeddings = shift_dim(embeddings, -1, 1) + + commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach()) + + # EMA codebook update + if self.training: + n_total = encode_onehot.sum(dim=0) + encode_sum = flat_inputs.t() @ encode_onehot + if dist.is_initialized(): + dist.all_reduce(n_total) + dist.all_reduce(encode_sum) + + self.N.data.mul_(0.99).add_(n_total, alpha=0.01) + self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01) + + n = self.N.sum() + weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n + encode_normalized = self.z_avg / weights.unsqueeze(1) + self.embeddings.data.copy_(encode_normalized) + + y = self._tile(flat_inputs) + _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes] + if dist.is_initialized(): + dist.broadcast(_k_rand, 0) + + usage = (self.N.view(self.n_codes, 1) >= 1).float() + self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage)) + + embeddings_st = (embeddings - z).detach() + z + + avg_probs = torch.mean(encode_onehot, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + return dict( + embeddings=embeddings_st, + encodings=encoding_indices, + commitment_loss=commitment_loss, + perplexity=perplexity, + ) + + def dictionary_lookup(self, encodings): + embeddings = F.embedding(encodings, self.embeddings) + return embeddings + + +# Copied from https://github.com/wilson1yan/VideoGPT +class Encoder(nn.Module): + def __init__(self, n_hiddens, n_res_layers, downsample): + super().__init__() + n_times_downsample = np.array([int(math.log2(d)) for d in downsample]) + self.convs = nn.ModuleList() + max_ds = n_times_downsample.max() + for i in range(max_ds): + in_channels = 3 if i == 0 else n_hiddens + stride = tuple([2 if d > 0 else 1 for d in n_times_downsample]) + conv = SamePadConv3d(in_channels, n_hiddens, 4, stride=stride) + self.convs.append(conv) + n_times_downsample -= 1 + self.conv_last = SamePadConv3d(in_channels, n_hiddens, kernel_size=3) + + self.res_stack = nn.Sequential( + *[AttentionResidualBlock(n_hiddens) for _ in range(n_res_layers)], + nn.BatchNorm3d(n_hiddens), + nn.ReLU(), + ) + + def forward(self, x): + h = x + for conv in self.convs: + h = F.relu(conv(h)) + h = self.conv_last(h) + h = self.res_stack(h) + return h + + +# Copied from https://github.com/wilson1yan/VideoGPT +class MultiHeadAttention(nn.Module): + def __init__( + self, shape, dim_q, dim_kv, n_head, n_layer, causal, attn_type, attn_kwargs + ): + super().__init__() + self.causal = causal + self.shape = shape + + self.d_k = dim_q // n_head + self.d_v = dim_kv // n_head + self.n_head = n_head + + self.w_qs = nn.Linear(dim_q, n_head * self.d_k, bias=False) # q + self.w_qs.weight.data.normal_(std=1.0 / np.sqrt(dim_q)) + + self.w_ks = nn.Linear(dim_kv, n_head * self.d_k, bias=False) # k + self.w_ks.weight.data.normal_(std=1.0 / np.sqrt(dim_kv)) + + self.w_vs = nn.Linear(dim_kv, n_head * self.d_v, bias=False) # v + self.w_vs.weight.data.normal_(std=1.0 / np.sqrt(dim_kv)) + + self.fc = nn.Linear(n_head * self.d_v, dim_q, bias=True) # c + self.fc.weight.data.normal_(std=1.0 / np.sqrt(dim_q * n_layer)) + + if attn_type == "full": + self.attn = FullAttention(shape, causal, **attn_kwargs) + elif attn_type == "axial": + assert not causal, "causal axial attention is not supported" + self.attn = AxialAttention(len(shape), **attn_kwargs) + elif attn_type == "sparse": + self.attn = SparseAttention(shape, n_head, causal, **attn_kwargs) + + self.cache = None + + def forward(self, q, k, v, decode_step=None, decode_idx=None): + """Compute multi-head attention + Args + q, k, v: a [b, d1, ..., dn, c] tensor or + a [b, 1, ..., 1, c] tensor if decode_step is not None + + Returns + The output after performing attention + """ + + # compute k, q, v + d_k, d_v, n_head = self.d_k, self.d_v, self.n_head + q = view_range(self.w_qs(q), -1, None, (n_head, d_k)) + k = view_range(self.w_ks(k), -1, None, (n_head, d_k)) + v = view_range(self.w_vs(v), -1, None, (n_head, d_v)) + + # b x n_head x seq_len x d + # (b, *d_shape, n_head, d) -> (b, n_head, *d_shape, d) + q = shift_dim(q, -2, 1) + k = shift_dim(k, -2, 1) + v = shift_dim(v, -2, 1) + + # fast decoding + if decode_step is not None: + if decode_step == 0: + if self.causal: + k_shape = (q.shape[0], n_head, *self.shape, self.d_k) + v_shape = (q.shape[0], n_head, *self.shape, self.d_v) + self.cache = dict( + k=torch.zeros(k_shape, dtype=k.dtype, device=q.device), + v=torch.zeros(v_shape, dtype=v.dtype, device=q.device), + ) + else: + # cache only once in the non-causal case + self.cache = dict(k=k.clone(), v=v.clone()) + if self.causal: + idx = ( + slice(None, None), + slice(None, None), + *[slice(i, i + 1) for i in decode_idx], + ) + self.cache["k"][idx] = k + self.cache["v"][idx] = v + k, v = self.cache["k"], self.cache["v"] + + a = self.attn(q, k, v, decode_step, decode_idx) + + # (b, *d_shape, n_head, d) -> (b, *d_shape, n_head * d) + a = shift_dim(a, 1, -2).flatten(start_dim=-2) + a = self.fc(a) # (b x seq_len x embd_dim) + + return a + + +# Copied from https://github.com/wilson1yan/VideoGPT +class Decoder(nn.Module): + def __init__(self, n_hiddens, n_res_layers, upsample): + super().__init__() + self.res_stack = nn.Sequential( + *[AttentionResidualBlock(n_hiddens) for _ in range(n_res_layers)], + nn.BatchNorm3d(n_hiddens), + nn.ReLU(), + ) + + n_times_upsample = np.array([int(math.log2(d)) for d in upsample]) + max_us = n_times_upsample.max() + self.convts = nn.ModuleList() + for i in range(max_us): + out_channels = 3 if i == max_us - 1 else n_hiddens + us = tuple([2 if d > 0 else 1 for d in n_times_upsample]) + convt = SamePadConvTranspose3d(n_hiddens, out_channels, 4, stride=us) + self.convts.append(convt) + n_times_upsample -= 1 + + def forward(self, x): + h = self.res_stack(x) + for i, convt in enumerate(self.convts): + h = convt(h) + if i < len(self.convts) - 1: + h = F.relu(h) + return h + + +# Copied from https://github.com/wilson1yan/VideoGPT +class SamePadConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True): + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * 3 + if isinstance(stride, int): + stride = (stride,) * 3 + + # assumes that the input shape is divisible by stride + total_pad = tuple([k - s for k, s in zip(kernel_size, stride)]) + pad_input = [] + for p in total_pad[::-1]: # reverse since F.pad starts from last dim + pad_input.append((p // 2 + p % 2, p // 2)) + pad_input = sum(pad_input, tuple()) + self.pad_input = pad_input + + self.conv = nn.Conv3d( + in_channels, out_channels, kernel_size, stride=stride, padding=0, bias=bias + ) + + def forward(self, x): + return self.conv(F.pad(x, self.pad_input)) + + +# Copied from https://github.com/wilson1yan/VideoGPT +class SamePadConvTranspose3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True): + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * 3 + if isinstance(stride, int): + stride = (stride,) * 3 + + total_pad = tuple([k - s for k, s in zip(kernel_size, stride)]) + pad_input = [] + for p in total_pad[::-1]: # reverse since F.pad starts from last dim + pad_input.append((p // 2 + p % 2, p // 2)) + pad_input = sum(pad_input, tuple()) + self.pad_input = pad_input + + self.convt = nn.ConvTranspose3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + bias=bias, + padding=tuple([k - 1 for k in kernel_size]), + ) + + def forward(self, x): + return self.convt(F.pad(x, self.pad_input)) + + +# Copied from https://github.com/wilson1yan/VideoGPT +class FullAttention(nn.Module): + def __init__(self, shape, causal, attn_dropout): + super().__init__() + self.causal = causal + self.attn_dropout = attn_dropout + + seq_len = np.prod(shape) + if self.causal: + self.register_buffer("mask", torch.tril(torch.ones(seq_len, seq_len))) + + def forward(self, q, k, v, decode_step, decode_idx): + mask = self.mask if self.causal else None + if decode_step is not None and mask is not None: + mask = mask[[decode_step]] + + old_shape = q.shape[2:-1] + q = q.flatten(start_dim=2, end_dim=-2) + k = k.flatten(start_dim=2, end_dim=-2) + v = v.flatten(start_dim=2, end_dim=-2) + + out = scaled_dot_product_attention( + q, k, v, mask=mask, attn_dropout=self.attn_dropout, training=self.training + ) + + return view_range(out, 2, 3, old_shape) + + +# Copied from https://github.com/wilson1yan/VideoGPT +class AxialAttention(nn.Module): + def __init__(self, n_dim, axial_dim): + super().__init__() + if axial_dim < 0: + axial_dim = 2 + n_dim + 1 + axial_dim + else: + axial_dim += 2 # account for batch, head, dim + self.axial_dim = axial_dim + + def forward(self, q, k, v, decode_step, decode_idx): + q = shift_dim(q, self.axial_dim, -2).flatten(end_dim=-3) + k = shift_dim(k, self.axial_dim, -2).flatten(end_dim=-3) + v = shift_dim(v, self.axial_dim, -2) + old_shape = list(v.shape) + v = v.flatten(end_dim=-3) + + out = scaled_dot_product_attention(q, k, v, training=self.training) + out = out.view(*old_shape) + out = shift_dim(out, -2, self.axial_dim) + return out + + +# Copied from https://github.com/wilson1yan/VideoGPT +class StridedSparsityConfig(object): + """ + Strided Sparse configuration specified in https://arxiv.org/abs/1904.10509 that + generalizes to arbitrary dimensions + """ + + def __init__(self, shape, n_head, causal, block, num_local_blocks): + self.n_head = n_head + self.shape = shape + self.causal = causal + self.block = block + self.num_local_blocks = num_local_blocks + + assert self.num_local_blocks >= 1, "Must have at least 1 local block" + assert self.seq_len % self.block == 0, "seq len must be divisible by block size" + + self._block_shape = self._compute_block_shape() + self._block_shape_cum = self._block_shape_cum_sizes() + + @property + def seq_len(self): + return np.prod(self.shape) + + @property + def num_blocks(self): + return self.seq_len // self.block + + def set_local_layout(self, layout): + num_blocks = self.num_blocks + for row in range(0, num_blocks): + end = min(row + self.num_local_blocks, num_blocks) + for col in range( + max(0, row - self.num_local_blocks), (row + 1 if self.causal else end) + ): + layout[:, row, col] = 1 + return layout + + def set_global_layout(self, layout): + num_blocks = self.num_blocks + n_dim = len(self._block_shape) + for row in range(num_blocks): + assert self._to_flattened_idx(self._to_unflattened_idx(row)) == row + cur_idx = self._to_unflattened_idx(row) + # no strided attention over last dim + for d in range(n_dim - 1): + end = self._block_shape[d] + for i in range(0, (cur_idx[d] + 1 if self.causal else end)): + new_idx = list(cur_idx) + new_idx[d] = i + new_idx = tuple(new_idx) + + col = self._to_flattened_idx(new_idx) + layout[:, row, col] = 1 + + return layout + + def make_layout(self): + layout = torch.zeros( + (self.n_head, self.num_blocks, self.num_blocks), dtype=torch.int64 + ) + layout = self.set_local_layout(layout) + layout = self.set_global_layout(layout) + return layout + + def make_sparse_attn_mask(self): + block_layout = self.make_layout() + assert block_layout.shape[1] == block_layout.shape[2] == self.num_blocks + + num_dense_blocks = block_layout.sum().item() + attn_mask = torch.ones(num_dense_blocks, self.block, self.block) + counter = 0 + for h in range(self.n_head): + for i in range(self.num_blocks): + for j in range(self.num_blocks): + elem = block_layout[h, i, j].item() + if elem == 1: + assert i >= j + if i == j: # need to mask within block on diagonals + attn_mask[counter] = torch.tril(attn_mask[counter]) + counter += 1 + assert counter == num_dense_blocks + + return attn_mask.unsqueeze(0) + + def get_non_block_layout_row(self, block_layout, row): + block_row = row // self.block + block_row = block_layout[:, [block_row]] # n_head x 1 x n_blocks + block_row = block_row.repeat_interleave(self.block, dim=-1) + block_row[:, :, row + 1 :] = 0.0 + return block_row + + ############# Helper functions ########################## + + def _compute_block_shape(self): + n_dim = len(self.shape) + cum_prod = 1 + for i in range(n_dim - 1, -1, -1): + cum_prod *= self.shape[i] + if cum_prod > self.block: + break + assert cum_prod % self.block == 0 + new_shape = (*self.shape[:i], cum_prod // self.block) + + assert np.prod(new_shape) == np.prod(self.shape) // self.block + + return new_shape + + def _block_shape_cum_sizes(self): + bs = np.flip(np.array(self._block_shape)) + return tuple(np.flip(np.cumprod(bs)[:-1])) + (1,) + + def _to_flattened_idx(self, idx): + assert len(idx) == len( + self._block_shape + ), f"{len(idx)} != {len(self._block_shape)}" + flat_idx = 0 + for i in range(len(self._block_shape)): + flat_idx += idx[i] * self._block_shape_cum[i] + return flat_idx + + def _to_unflattened_idx(self, flat_idx): + assert flat_idx < np.prod(self._block_shape) + idx = [] + for i in range(len(self._block_shape)): + idx.append(flat_idx // self._block_shape_cum[i]) + flat_idx %= self._block_shape_cum[i] + return tuple(idx) + + +# Copied from https://github.com/wilson1yan/VideoGPT +class SparseAttention(nn.Module): + ops = dict() + attn_mask = dict() + block_layout = dict() + + def __init__( + self, shape, n_head, causal, num_local_blocks=4, block=32, attn_dropout=0.0 + ): # does not use attn_dropout + super().__init__() + self.causal = causal + self.shape = shape + + self.sparsity_config = StridedSparsityConfig( + shape=shape, + n_head=n_head, + causal=causal, + block=block, + num_local_blocks=num_local_blocks, + ) + + if self.shape not in SparseAttention.block_layout: + SparseAttention.block_layout[self.shape] = ( + self.sparsity_config.make_layout() + ) + if causal and self.shape not in SparseAttention.attn_mask: + SparseAttention.attn_mask[self.shape] = ( + self.sparsity_config.make_sparse_attn_mask() + ) + + def get_ops(self): + try: + from deepspeed.ops.sparse_attention import MatMul, Softmax + except: + raise Exception( + "Error importing deepspeed. Please install using `DS_BUILD_SPARSE_ATTN=1 pip install deepspeed`" + ) + if self.shape not in SparseAttention.ops: + sparsity_layout = self.sparsity_config.make_layout() + sparse_dot_sdd_nt = MatMul( + sparsity_layout, + self.sparsity_config.block, + "sdd", + trans_a=False, + trans_b=True, + ) + + sparse_dot_dsd_nn = MatMul( + sparsity_layout, + self.sparsity_config.block, + "dsd", + trans_a=False, + trans_b=False, + ) + + sparse_softmax = Softmax(sparsity_layout, self.sparsity_config.block) + + SparseAttention.ops[self.shape] = ( + sparse_dot_sdd_nt, + sparse_dot_dsd_nn, + sparse_softmax, + ) + return SparseAttention.ops[self.shape] + + def forward(self, q, k, v, decode_step, decode_idx): + if self.training and self.shape not in SparseAttention.ops: + self.get_ops() + + SparseAttention.block_layout[self.shape] = SparseAttention.block_layout[ + self.shape + ].to(q) + if self.causal: + SparseAttention.attn_mask[self.shape] = ( + SparseAttention.attn_mask[self.shape].to(q).type_as(q) + ) + attn_mask = SparseAttention.attn_mask[self.shape] if self.causal else None + + old_shape = q.shape[2:-1] + q = q.flatten(start_dim=2, end_dim=-2) + k = k.flatten(start_dim=2, end_dim=-2) + v = v.flatten(start_dim=2, end_dim=-2) + + if decode_step is not None: + mask = self.sparsity_config.get_non_block_layout_row( + SparseAttention.block_layout[self.shape], decode_step + ) + out = scaled_dot_product_attention( + q, k, v, mask=mask, training=self.training + ) + else: + if q.shape != k.shape or k.shape != v.shape: + raise Exception("SparseAttention only support self-attention") + sparse_dot_sdd_nt, sparse_dot_dsd_nn, sparse_softmax = self.get_ops() + scaling = float(q.shape[-1]) ** -0.5 + + attn_output_weights = sparse_dot_sdd_nt(q, k) + if attn_mask is not None: + attn_output_weights = attn_output_weights.masked_fill( + attn_mask == 0, float("-inf") + ) + attn_output_weights = sparse_softmax(attn_output_weights, scale=scaling) + + out = sparse_dot_dsd_nn(attn_output_weights, v) + + return view_range(out, 2, 3, old_shape) + + +# Modified from https://github.com/wilson1yan/VideoGPT +class VQVAEModel(VideoBaseAE): + + DOWNLOADED_VQVAE = { + "bair_stride4x2x2": "1iIAYJ2Qqrx5Q94s5eIXQYJgAydzvT_8L", + "ucf101_stride4x4x4": "1uuB_8WzHP_bbBmfuaIV7PK_Itl3DyHY5", + "kinetics_stride4x4x4": "1DOvOZnFAIQmux6hG7pN_HkyJZy3lXbCB", + "kinetics_stride2x4x4": "1jvtjjtrtE4cy6pl7DK_zWFEPY3RZt2pB", + } + + def __init__(self, config: VQVAEConfiguration): + super().__init__() + self.config = config + self.embedding_dim = config.embedding_dim + self.n_codes = config.n_codes + self.encoder = Encoder(config.n_hiddens, config.n_res_layers, config.downsample) + self.decoder = Decoder(config.n_hiddens, config.n_res_layers, config.downsample) + self.pre_vq_conv = SamePadConv3d(config.n_hiddens, config.embedding_dim, 1) + self.post_vq_conv = SamePadConv3d(config.embedding_dim, config.n_hiddens, 1) + self.codebook = Codebook(config.n_codes, config.embedding_dim) + + def forward(self, x): + z = self.pre_vq_conv(self.encoder(x)) + vq_output = self.codebook(z) + x_recon = self.decoder(self.post_vq_conv(vq_output["embeddings"])) + recon_loss = F.mse_loss(x_recon, x) / 0.06 + return recon_loss, x_recon, vq_output + + def encode(self, x: Tensor, include_embeddings: bool = False) -> Union[Tuple[Tensor, Tensor], Tensor]: + h = self.pre_vq_conv(self.encoder(x)) + vq_output: Dict[str, Tensor] = self.codebook(h) + if include_embeddings: + return vq_output["encodings"], vq_output["embeddings"] + else: + return vq_output["encodings"] + + def decode(self, encodings: Tensor) -> Tensor: + h = F.embedding(encodings, self.codebook.embeddings) + h = self.post_vq_conv(shift_dim(h, -1, 1)) + return self.decoder(h) + + @classmethod + def load_from_checkpoint(cls, model_path): + if not os.path.isdir(model_path): + """model downloaded from internet""" + model_cpkt = torch.load(model_path) + # Compatible with old videogpt model formats. + if "hyper_parameters" in model_cpkt: + hyper_parameters = vars(model_cpkt.get("hyper_parameters").get("args")) + state_dict = model_cpkt.get("state_dict") + model = cls(config=VQVAEConfiguration(**hyper_parameters)) + model.load_state_dict(state_dict) + return model + else: + raise RuntimeError("Model checkpoint has a wrong format.") + else: + with open(os.path.join(model_path, "config.json"), "r") as file: + config = json.load(file) + state_dict = torch.load(os.path.join(model_path, "pytorch_model.bin"), map_location="cpu") + model = cls(config=VQVAEConfiguration(**config)) + model.load_state_dict(state_dict) + return model + + @classmethod + def download_and_load_model(cls, model_name, cache_dir=None): + from .....utils.downloader import gdown_download + path = gdown_download( + cls.DOWNLOADED_VQVAE[model_name], model_name, cache_dir=cache_dir + ) + return cls.load_from_checkpoint(path) diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/vqvae/trainer_vqvae.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/vqvae/trainer_vqvae.py new file mode 100644 index 0000000000000000000000000000000000000000..df3f866eeeea3c50d8de372608ddfd795586efc6 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/ae/videobase/vqvae/trainer_vqvae.py @@ -0,0 +1,22 @@ +from ..trainer_videobase import VideoBaseTrainer +import torch.nn.functional as F +from typing import Optional +import os +import torch +from transformers.utils import WEIGHTS_NAME +import json + +class VQVAETrainer(VideoBaseTrainer): + + def compute_loss(self, model, inputs, return_outputs=False): + model = model.module + x = inputs.get("video") + x = x / 2 + z = model.pre_vq_conv(model.encoder(x)) + vq_output = model.codebook(z) + x_recon = model.decoder(model.post_vq_conv(vq_output["embeddings"])) + recon_loss = F.mse_loss(x_recon, x) / 0.06 + commitment_loss = vq_output['commitment_loss'] + loss = recon_loss + commitment_loss + return loss + diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/captioner/__init__.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/captioner/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/captioner/caption_refiner/README.md b/src/videogen_hub/pipelines/opensora_plan/opensora/models/captioner/caption_refiner/README.md new file mode 100644 index 0000000000000000000000000000000000000000..cf1ae266a19d1a80a609b6b3a5e789bfa4956f55 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/captioner/caption_refiner/README.md @@ -0,0 +1,38 @@ +# Refiner for Video Caption + +Transform the short caption annotations from video datasets into the long and detailed caption annotations. + +* Add detailed description for background scene. +* Add detailed description for object attributes, including color, material, pose. +* Add detailed description for object-level spatial relationship. + +## 🛠️ Extra Requirements and Installation + +* openai == 0.28.0 +* jsonlines == 4.0.0 +* nltk == 3.8.1 +* Install the LLaMA-Accessory: + +you also need to download the weight of SPHINX to ./ckpt/ folder + +## 🗝️ Refining + +The refining instruction is in [demo_for_refiner.py](demo_for_refiner.py). + +```bash +python demo_for_refiner.py --root_path $path_to_repo$ --api_key $openai_api_key$ +``` + +### Refining Demos + +```bash +[original caption]: A red mustang parked in a showroom with american flags hanging from the ceiling. +``` + +```bash +[refine caption]: This scene depicts a red Mustang parked in a showroom with American flags hanging from the ceiling. The showroom likely serves as a space for showcasing and purchasing cars, and the Mustang is displayed prominently near the flags and ceiling. The scene also features a large window and other objects. Overall, it seems to take place in a car show or dealership. +``` + +- [ ] Add GPT-3.5-Turbo for caption summarization. ⌛ [WIP] +- [ ] Add LLAVA-1.6. ⌛ [WIP] +- [ ] More descriptions. ⌛ [WIP] \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/captioner/caption_refiner/__init__.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/captioner/caption_refiner/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/captioner/caption_refiner/caption_refiner.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/captioner/caption_refiner/caption_refiner.py new file mode 100644 index 0000000000000000000000000000000000000000..23952f6d9ce504945151619e4e6295360db67d00 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/captioner/caption_refiner/caption_refiner.py @@ -0,0 +1,122 @@ +import itertools +import numpy as np +from PIL import Image +from PIL import ImageSequence +from nltk import pos_tag, word_tokenize + +from LLaMA2_Accessory.SPHINX import SPHINXModel +from gpt_combinator import caption_summary + +class CaptionRefiner(): + def __init__(self, sample_num, add_detect=True, add_pos=True, add_attr=True, + openai_api_key=None, openai_api_base=None, + ): + self.sample_num = sample_num + self.ADD_DETECTION_OBJ = add_detect + self.ADD_POS = add_pos + self.ADD_ATTR = add_attr + self.openai_api_key = openai_api_key + self.openai_api_base =openai_api_base + + def video_load_split(self, video_path=None): + frame_img_list, sampled_img_list = [], [] + + if ".gif" in video_path: + img = Image.open(video_path) + # process every frame in GIF from to + for frame in ImageSequence.Iterator(img): + frame_np = np.array(frame.copy().convert('RGB').getdata(),dtype=np.uint8).reshape(frame.size[1],frame.size[0],3) + frame_img = Image.fromarray(np.uint8(frame_np)) + frame_img_list.append(frame_img) + elif ".mp4" in video_path: + pass + + # sample frames from the mp4/gif + for i in range(0, len(frame_img_list), int(len(frame_img_list)/self.sample_num)): + sampled_img_list.append(frame_img_list[i]) + + return sampled_img_list # [, ...] + + def caption_refine(self, video_path, org_caption, model_path): + sampled_imgs = self.video_load_split(video_path) + + model = SPHINXModel.from_pretrained( + pretrained_path=model_path, + with_visual=True + ) + + existing_objects, scene_description = [], [] + text = word_tokenize(org_caption) + existing_objects = [word for word,tag in pos_tag(text) if tag in ["NN", "NNS", "NNP"]] + if self.ADD_DETECTION_OBJ: + # Detect the objects and scene in the sampled images + + qas = [["Where is this scene in the picture most likely to take place?", None]] + sc_response = model.generate_response(qas, sampled_imgs[0], max_gen_len=1024, temperature=0.9, top_p=0.5, seed=0) + scene_description.append(sc_response) + + # # Lacking accuracy + # for img in sampled_imgs: + # qas = [["Please detect the objects in the image.", None]] + # response = model.generate_response(qas, img, max_gen_len=1024, temperature=0.9, top_p=0.5, seed=0) + # print(response) + + object_attrs = [] + if self.ADD_ATTR: + # Detailed Description for all the objects in the sampled images + for obj in existing_objects: + obj_attr = [] + for img in sampled_imgs: + qas = [["Please describe the attribute of the {}, including color, position, etc".format(obj), None]] + response = model.generate_response(qas, img, max_gen_len=1024, temperature=0.9, top_p=0.5, seed=0) + obj_attr.append(response) + object_attrs.append({obj : obj_attr}) + + space_relations = [] + if self.ADD_POS: + obj_pairs = list(itertools.combinations(existing_objects, 2)) + # Description for the relationship between each object in the sample images + for obj_pair in obj_pairs: + qas = [["What is the spatial relationship between the {} and the {}? Please describe in lease than twenty words".format(obj_pair[0], obj_pair[1]), None]] + response = model.generate_response(qas, img, max_gen_len=1024, temperature=0.9, top_p=0.5, seed=0) + space_relations.append(response) + + return dict( + org_caption = org_caption, + scene_description = scene_description, + existing_objects = existing_objects, + object_attrs = object_attrs, + space_relations = space_relations, + ) + + def gpt_summary(self, total_captions): + # combine all captions into a detailed long caption + detailed_caption = "" + + if "org_caption" in total_captions.keys(): + detailed_caption += "In summary, "+ total_captions['org_caption'] + + if "scene_description" in total_captions.keys(): + detailed_caption += "We first describe the whole scene. "+total_captions['scene_description'][-1] + + if "existing_objects" in total_captions.keys(): + tmp_sentence = "There are multiple objects in the video, including " + for obj in total_captions['existing_objects']: + tmp_sentence += obj+", " + detailed_caption += tmp_sentence + + # if "object_attrs" in total_captions.keys(): + # caption_summary( + # caption_list="", + # api_key=self.openai_api_key, + # api_base=self.openai_api_base, + # ) + + if "space_relations" in total_captions.keys(): + tmp_sentence = "As for the spatial relationship. " + for sentence in total_captions['space_relations']: tmp_sentence += sentence + detailed_caption += tmp_sentence + + detailed_caption = caption_summary(detailed_caption, self.open_api_key, self.open_api_base) + + return detailed_caption \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/captioner/caption_refiner/dataset/__init__.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/captioner/caption_refiner/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/captioner/caption_refiner/dataset/test_videos/__init__.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/captioner/caption_refiner/dataset/test_videos/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/captioner/caption_refiner/dataset/test_videos/captions.json b/src/videogen_hub/pipelines/opensora_plan/opensora/models/captioner/caption_refiner/dataset/test_videos/captions.json new file mode 100644 index 0000000000000000000000000000000000000000..098a352f2e3a6eaebf4ccf7885bb7b2718d44176 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/captioner/caption_refiner/dataset/test_videos/captions.json @@ -0,0 +1 @@ +{"video1.gif": "A red mustang parked in a showroom with american flags hanging from the ceiling.", "video2.gif": "An aerial view of a city with a river running through it."} \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/captioner/caption_refiner/demo_for_refiner.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/captioner/caption_refiner/demo_for_refiner.py new file mode 100644 index 0000000000000000000000000000000000000000..c7c0bfc5eb5b42b7da0e3f9ce13373c69cb39bae --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/captioner/caption_refiner/demo_for_refiner.py @@ -0,0 +1,28 @@ +import argparse +from caption_refiner import CaptionRefiner +from gpt_combinator import caption_summary, caption_qa + +def parse_args(): + parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3") + parser.add_argument("--root_path", required=True, help="The path to repo.") + parser.add_argument("--api_key", required=True, help="OpenAI API key.") + args = parser.parse_args() + return args + +if __name__ == "__main__": + args = parse_args() + myrefiner = CaptionRefiner( + sample_num=6, add_detect=True, add_pos=True, add_attr=True, + openai_api_key = args.api_key, + openai_api_base = "https://one-api.bltcy.top/v1", + ) + + results = myrefiner.caption_refine( + video_path="./dataset/test_videos/video1.gif", + org_caption="A red mustang parked in a showroom with american flags hanging from the ceiling.", + model_path = args.root_path + "/ckpts/SPHINX-Tiny", + ) + + final_caption = myrefiner.gpt_summary(results) + + print(final_caption) diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/captioner/caption_refiner/gpt_combinator.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/captioner/caption_refiner/gpt_combinator.py new file mode 100644 index 0000000000000000000000000000000000000000..c0a6f0dff9b2b198533c9741028a75961720fc6e --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/captioner/caption_refiner/gpt_combinator.py @@ -0,0 +1,93 @@ +import openai +import ast + +def caption_qa(caption_list, api_key, api_base): + openai.api_key = api_key + openai.api_base = api_base + + question = "What is the color of a red apple" + answer = "red" + pred = "green" + try: + # Compute the correctness score + completion = openai.ChatCompletion.create( + model="gpt-3.5-turbo", + # model="gpt-4", + # model="gpt-4-vision-compatible", + messages=[ + { + "role": "system", + "content": + "You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs. " + "Your task is to compare the predicted answer with the correct answer and determine if they match meaningfully. Here's how you can accomplish the task:" + "------" + "##INSTRUCTIONS: " + "- Focus on the meaningful match between the predicted answer and the correct answer.\n" + "- Consider synonyms or paraphrases as valid matches.\n" + "- Evaluate the correctness of the prediction compared to the answer." + }, + { + "role": "user", + "content": + "Please evaluate the following video-based question-answer pair:\n\n" + f"Question: {question}\n" + f"Correct Answer: {answer}\n" + f"Predicted Answer: {pred}\n\n" + "Provide your evaluation only as a yes/no and score where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. " + "Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING." + "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " + "For example, your response should look like this: {'pred': 'yes', 'score': 4.8}." + } + ] + ) + # Convert response to a Python dictionary. + response_message = completion["choices"][0]["message"]["content"] + response_dict = ast.literal_eval(response_message) + print(response_dict) + + except Exception as e: + print(f"Error processing file : {e}") + + +def caption_summary(long_caption, api_key, api_base): + """ + apply GPT3-Turbo as the combination for original caption and the prompted captions for a video + """ + openai.api_key = api_key + openai.api_base = api_base + + try: + # Compute the correctness score + completion = openai.ChatCompletion.create( + model="gpt-3.5-turbo", + messages=[ + { + "role": "system", + "content": + "You are an intelligent chatbot designed for summarizing from a long sentence. " + }, + { + "role": "user", + "content": + "Please summarize the following sentences. Make it shorter than 70 words." + f"the long sentence: {long_caption}\n" + "Provide your summarization with less than 70 words. " + "DO NOT PROVIDE ANY OTHER TEXT OR EXPLANATION. Only provide the summary sentence. " + } + ] + ) + # "Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING." + # "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " + # "For example, your response should look like this: {'summary': 'your summary sentence'}." + + # Convert response to a Python dictionary. + response_message = completion["choices"][0]["message"]["content"] + response_dict = ast.literal_eval(response_message) + + except Exception as e: + print(f"Error processing file : {e}") + + return response_dict + +if __name__ == "__main__": + caption_summary() \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/__init__.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bf3abffa98328aa734ae824b3902765debec8587 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/__init__.py @@ -0,0 +1,7 @@ + +from .latte.modeling_latte import Latte_models + +Diffusion_models = {} +Diffusion_models.update(Latte_models) + + \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/diffusion/__init__.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..04b2bd3d875d0e9ea9e0059b5f7dc3cea30795dc --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/diffusion/__init__.py @@ -0,0 +1,87 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +from .respace import SpacedDiffusion, space_timesteps, SpacedDiffusion_T + + +def create_diffusion( + timestep_respacing, + noise_schedule="linear", + use_kl=False, + sigma_small=False, + predict_xstart=False, + learn_sigma=True, + # learn_sigma=False, + rescale_learned_sigmas=False, + diffusion_steps=1000 +): + from . import gaussian_diffusion as gd + betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) + if use_kl: + loss_type = gd.LossType.RESCALED_KL + elif rescale_learned_sigmas: + loss_type = gd.LossType.RESCALED_MSE + else: + loss_type = gd.LossType.MSE + if timestep_respacing is None or timestep_respacing == "": + timestep_respacing = [diffusion_steps] + return SpacedDiffusion( + use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), + betas=betas, + model_mean_type=( + gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X + ), + model_var_type=( + ( + gd.ModelVarType.FIXED_LARGE + if not sigma_small + else gd.ModelVarType.FIXED_SMALL + ) + if not learn_sigma + else gd.ModelVarType.LEARNED_RANGE + ), + loss_type=loss_type + # rescale_timesteps=rescale_timesteps, + ) + +def create_diffusion_T( + timestep_respacing, + noise_schedule="linear", + use_kl=False, + sigma_small=False, + predict_xstart=False, + learn_sigma=True, + # learn_sigma=False, + rescale_learned_sigmas=False, + diffusion_steps=1000 +): + from . import gaussian_diffusion_t2v as gd + betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) + if use_kl: + loss_type = gd.LossType.RESCALED_KL + elif rescale_learned_sigmas: + loss_type = gd.LossType.RESCALED_MSE + else: + loss_type = gd.LossType.MSE + if timestep_respacing is None or timestep_respacing == "": + timestep_respacing = [diffusion_steps] + return SpacedDiffusion_T( + use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), + betas=betas, + model_mean_type=( + gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X + ), + model_var_type=( + ( + gd.ModelVarType.FIXED_LARGE + if not sigma_small + else gd.ModelVarType.FIXED_SMALL + ) + if not learn_sigma + else gd.ModelVarType.LEARNED_RANGE + ), + loss_type=loss_type + # rescale_timesteps=rescale_timesteps, + ) diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/diffusion/diffusion_utils.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/diffusion/diffusion_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e493a6a3ecb91e553a53cc7eadee5cc0d1753060 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/diffusion/diffusion_utils.py @@ -0,0 +1,88 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +import torch as th +import numpy as np + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, th.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for th.exp(). + logvar1, logvar2 = [ + x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + th.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * th.exp(-logvar2) + ) + + +def approx_standard_normal_cdf(x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. + """ + return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) + + +def continuous_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a continuous Gaussian distribution. + :param x: the targets + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + centered_x = x - means + inv_stdv = th.exp(-log_scales) + normalized_x = centered_x * inv_stdv + log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) + return log_probs + + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. + :param x: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + assert x.shape == means.shape == log_scales.shape + centered_x = x - means + inv_stdv = th.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = approx_standard_normal_cdf(min_in) + log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = th.where( + x < -0.999, + log_cdf_plus, + th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == x.shape + return log_probs diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/diffusion/gaussian_diffusion.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/diffusion/gaussian_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..7fc3d43a7fb627e0738d46c0f03cf1ab29b9258f --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/diffusion/gaussian_diffusion.py @@ -0,0 +1,881 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + + +import math + +import numpy as np +import torch as th +import enum + +from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +class ModelMeanType(enum.Enum): + """ + Which type of output the model predicts. + """ + + PREVIOUS_X = enum.auto() # the model predicts x_{t-1} + START_X = enum.auto() # the model predicts x_0 + EPSILON = enum.auto() # the model predicts epsilon + + +class ModelVarType(enum.Enum): + """ + What is used as the model's output variance. + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. + """ + + LEARNED = enum.auto() + FIXED_SMALL = enum.auto() + FIXED_LARGE = enum.auto() + LEARNED_RANGE = enum.auto() + + +class LossType(enum.Enum): + MSE = enum.auto() # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = ( + enum.auto() + ) # use raw MSE loss (with RESCALED_KL when learning variances) + KL = enum.auto() # use the variational lower-bound + RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + + +def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + warmup_time = int(num_diffusion_timesteps * warmup_frac) + betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64) + return betas + + +def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): + """ + This is the deprecated API for creating beta schedules. + See get_named_beta_schedule() for the new library of schedules. + """ + if beta_schedule == "quad": + betas = ( + np.linspace( + beta_start ** 0.5, + beta_end ** 0.5, + num_diffusion_timesteps, + dtype=np.float64, + ) + ** 2 + ) + elif beta_schedule == "linear": + betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "warmup10": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) + elif beta_schedule == "warmup50": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) + elif beta_schedule == "const": + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 + betas = 1.0 / np.linspace( + num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 + ) + else: + raise NotImplementedError(beta_schedule) + assert betas.shape == (num_diffusion_timesteps,) + return betas + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + """ + Get a pre-defined beta schedule for the given name. + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. + scale = 1000 / num_diffusion_timesteps + return get_beta_schedule( + "linear", + beta_start=scale * 0.0001, + beta_end=scale * 0.02, + num_diffusion_timesteps=num_diffusion_timesteps, + ) + elif schedule_name == "squaredcos_cap_v2": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +class GaussianDiffusion: + """ + Utilities for training and sampling diffusion models. + Original ported from this codebase: + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + """ + + def __init__( + self, + *, + betas, + model_mean_type, + model_var_type, + loss_type + ): + + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + self.loss_type = loss_type + + # Use float64 for accuracy. + betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = ( + betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = np.log( + np.append(self.posterior_variance[1], self.posterior_variance[1:]) + ) if len(self.posterior_variance) > 1 else np.array([]) + + self.posterior_mean_coef1 = ( + betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + self.posterior_mean_coef2 = ( + (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) + ) + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + In other words, sample from q(x_t | x_0). + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. + """ + if noise is None: + noise = th.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. + """ + if model_kwargs is None: + model_kwargs = {} + + B, F, C = x.shape[:3] + assert t.shape == (B,) + model_output = model(x, t, **model_kwargs) + # try: + # model_output = model_output.sample # for tav unet + # except: + # model_output = model(x, t, **model_kwargs) + if isinstance(model_output, tuple): + model_output, extra = model_output + else: + extra = None + + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + assert model_output.shape == (B, F, C * 2, *x.shape[3:]) + model_output, model_var_values = th.split(model_output, C, dim=2) + min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) + max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = th.exp(model_log_variance) + else: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. + ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart( + self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) + ) + model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) + + assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + "extra": extra, + } + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + ) + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). + """ + gradient = cond_fn(x, t, **model_kwargs) + new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() + return new_mean + + def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + See condition_mean() for details on cond_fn. + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). + """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t) + return out + + def p_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + ): + """ + Sample x_{t-1} from the model at the given timestep. + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + noise = th.randn_like(x) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + if cond_fn is not None: + out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) + sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model. + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). + :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. + :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :return: a non-differentiable batch of samples. + """ + final = None + for sample in self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + ): + final = sample + return final["sample"] + + def p_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + Arguments are the same as p_sample_loop(). + Returns a generator over dicts, where each dict is the return value of + p_sample(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.p_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + ) + yield out + img = out["sample"] + + def ddim_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t-1} from the model using DDIM. + Same usage as p_sample(). + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + sigma = ( + eta + * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) + * th.sqrt(1 - alpha_bar / alpha_bar_prev) + ) + # Equation 12. + noise = th.randn_like(x) + mean_pred = ( + out["pred_xstart"] * th.sqrt(alpha_bar_prev) + + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps + ) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def ddim_reverse_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t+1} from the model using DDIM reverse ODE. + """ + assert eta == 0.0, "Reverse ODE only for deterministic path" + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x + - out["pred_xstart"] + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) + alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) + + # Equation 12. reversed + mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps + + return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} + + def ddim_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Generate samples from the model using DDIM. + Same usage as p_sample_loop(). + """ + final = None + for sample in self.ddim_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, + ): + final = sample + return final["sample"] + + def ddim_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Use DDIM to sample from the model and yield intermediate samples from + each timestep of DDIM. + Same usage as p_sample_loop_progressive(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.ddim_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + eta=eta, + ) + yield out + img = out["sample"] + + def _vb_terms_bpd( + self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None + ): + """ + Get a term for the variational lower-bound. + The resulting units are bits (rather than nats, as one might expect). + This allows for comparison to other papers. + :return: a dict with the following keys: + - 'output': a shape [N] tensor of NLLs or KLs. + - 'pred_xstart': the x_0 predictions. + """ + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + ) + out = self.p_mean_variance( + model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs + ) + kl = normal_kl( + true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] + ) + kl = mean_flat(kl) / np.log(2.0) + + decoder_nll = -discretized_gaussian_log_likelihood( + x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] + ) + assert decoder_nll.shape == x_start.shape + decoder_nll = mean_flat(decoder_nll) / np.log(2.0) + + # At the first timestep return the decoder NLL, + # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) + output = th.where((t == 0), decoder_nll, kl) + return {"output": output, "pred_xstart": out["pred_xstart"]} + + def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): + """ + Compute training losses for a single timestep. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. + """ + if model_kwargs is None: + model_kwargs = {} + if noise is None: + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start, t, noise=noise) + + terms = {} + + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + terms["loss"] = self._vb_terms_bpd( + model=model, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + model_kwargs=model_kwargs, + )["output"] + if self.loss_type == LossType.RESCALED_KL: + terms["loss"] *= self.num_timesteps + elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + model_output = model(x_t, t, **model_kwargs) + # try: + # model_output = model(x_t, t, **model_kwargs).sample # for tav unet + # except: + # model_output = model(x_t, t, **model_kwargs) + + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, F, C = x_t.shape[:3] + assert model_output.shape == (B, F, C * 2, *x_t.shape[3:]) + model_output, model_var_values = th.split(model_output, C, dim=2) + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = th.cat([model_output.detach(), model_var_values], dim=2) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. + terms["vb"] *= self.num_timesteps / 1000.0 + + target = { + ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + )[0], + ModelMeanType.START_X: x_start, + ModelMeanType.EPSILON: noise, + }[self.model_mean_type] + assert model_output.shape == target.shape == x_start.shape + terms["mse"] = mean_flat((target - model_output) ** 2) + if "vb" in terms: + terms["loss"] = terms["mse"] + terms["vb"] + else: + terms["loss"] = terms["mse"] + else: + raise NotImplementedError(self.loss_type) + + return terms + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl( + mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 + ) + return mean_flat(kl_prior) / np.log(2.0) + + def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): + """ + Compute the entire variational lower-bound, measured in bits-per-dim, + as well as other related quantities. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param clip_denoised: if True, clip denoised samples. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - total_bpd: the total variational lower-bound, per batch element. + - prior_bpd: the prior term in the lower-bound. + - vb: an [N x T] tensor of terms in the lower-bound. + - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. + - mse: an [N x T] tensor of epsilon MSEs for each timestep. + """ + device = x_start.device + batch_size = x_start.shape[0] + + vb = [] + xstart_mse = [] + mse = [] + for t in list(range(self.num_timesteps))[::-1]: + t_batch = th.tensor([t] * batch_size, device=device) + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) + # Calculate VLB term at the current timestep + with th.no_grad(): + out = self._vb_terms_bpd( + model, + x_start=x_start, + x_t=x_t, + t=t_batch, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + ) + vb.append(out["output"]) + xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) + eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) + mse.append(mean_flat((eps - noise) ** 2)) + + vb = th.stack(vb, dim=1) + xstart_mse = th.stack(xstart_mse, dim=1) + mse = th.stack(mse, dim=1) + + prior_bpd = self._prior_bpd(x_start) + total_bpd = vb.sum(dim=1) + prior_bpd + return { + "total_bpd": total_bpd, + "prior_bpd": prior_bpd, + "vb": vb, + "xstart_mse": xstart_mse, + "mse": mse, + } + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. + """ + res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res + th.zeros(broadcast_shape, device=timesteps.device) diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/diffusion/gaussian_diffusion_t2v.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/diffusion/gaussian_diffusion_t2v.py new file mode 100644 index 0000000000000000000000000000000000000000..2dfe4d99d24e9bbd5b4e507e9f1c5939fbb44055 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/diffusion/gaussian_diffusion_t2v.py @@ -0,0 +1,904 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + + +import math + +import numpy as np +import torch as th +import enum + +from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +class ModelMeanType(enum.Enum): + """ + Which type of output the model predicts. + """ + + PREVIOUS_X = enum.auto() # the model predicts x_{t-1} + START_X = enum.auto() # the model predicts x_0 + EPSILON = enum.auto() # the model predicts epsilon + + +class ModelVarType(enum.Enum): + """ + What is used as the model's output variance. + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. + """ + + LEARNED = enum.auto() + FIXED_SMALL = enum.auto() + FIXED_LARGE = enum.auto() + LEARNED_RANGE = enum.auto() + + +class LossType(enum.Enum): + MSE = enum.auto() # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = ( + enum.auto() + ) # use raw MSE loss (with RESCALED_KL when learning variances) + KL = enum.auto() # use the variational lower-bound + RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + + +def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + warmup_time = int(num_diffusion_timesteps * warmup_frac) + betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64) + return betas + + +def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): + """ + This is the deprecated API for creating beta schedules. + See get_named_beta_schedule() for the new library of schedules. + """ + if beta_schedule == "quad": + betas = ( + np.linspace( + beta_start ** 0.5, + beta_end ** 0.5, + num_diffusion_timesteps, + dtype=np.float64, + ) + ** 2 + ) + elif beta_schedule == "linear": + betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "warmup10": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) + elif beta_schedule == "warmup50": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) + elif beta_schedule == "const": + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 + betas = 1.0 / np.linspace( + num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 + ) + else: + raise NotImplementedError(beta_schedule) + assert betas.shape == (num_diffusion_timesteps,) + return betas + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + """ + Get a pre-defined beta schedule for the given name. + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. + scale = 1000 / num_diffusion_timesteps + return get_beta_schedule( + "linear", + beta_start=scale * 0.0001, + beta_end=scale * 0.02, + num_diffusion_timesteps=num_diffusion_timesteps, + ) + elif schedule_name == "squaredcos_cap_v2": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +class GaussianDiffusion_T: + """ + Utilities for training and sampling diffusion models. + Original ported from this codebase: + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + """ + + def __init__( + self, + *, + betas, + model_mean_type, + model_var_type, + loss_type + ): + + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + self.loss_type = loss_type + + # Use float64 for accuracy. + betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = ( + betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = np.log( + np.append(self.posterior_variance[1], self.posterior_variance[1:]) + ) if len(self.posterior_variance) > 1 else np.array([]) + + self.posterior_mean_coef1 = ( + betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + self.posterior_mean_coef2 = ( + (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) + ) + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + In other words, sample from q(x_t | x_0). + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. + """ + if noise is None: + noise = th.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. + """ + if model_kwargs is None: + model_kwargs = {} + + #B, F, C = x.shape[:3] + B, C, F = x.shape[:3] + assert t.shape == (B,) + model_output = model(x, t, **model_kwargs) + + try: + model_output.shape + except: + model_output = model_output[0] + # try: + # model_output = model_output.sample # for tav unet + # except: + # model_output = model(x, t, **model_kwargs) + if isinstance(model_output, tuple): + model_output, extra = model_output + else: + extra = None + + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + #assert model_output.shape == (B, F, C * 2, *x.shape[3:]) + #model_output, model_var_values = th.split(model_output, C, dim=2) + #the output shape of uncondition or class condition latte is not the same as the latte_t2v + #BFCHW vs BCFHW + assert model_output.shape == (B, C * 2, F, *x.shape[3:]), f'model_output.shape ({model_output.shape}), != {(B, C * 2, F, *x.shape[3:])}' + model_output, model_var_values = th.split(model_output, C, dim=1) + min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) + max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = th.exp(model_log_variance) + else: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. + ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart( + self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) + ) + model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) + + assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + "extra": extra, + } + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + ) + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). + """ + gradient = cond_fn(x, t, **model_kwargs) + new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() + return new_mean + + def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + See condition_mean() for details on cond_fn. + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). + """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t) + return out + + def p_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + ): + """ + Sample x_{t-1} from the model at the given timestep. + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + noise = th.randn_like(x) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + if cond_fn is not None: + out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) + sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model. + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). + :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. + :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :return: a non-differentiable batch of samples. + """ + final = None + for sample in self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + ): + final = sample + return final["sample"] + + def p_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + Arguments are the same as p_sample_loop(). + Returns a generator over dicts, where each dict is the return value of + p_sample(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.p_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + ) + yield out + img = out["sample"] + + def ddim_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t-1} from the model using DDIM. + Same usage as p_sample(). + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + sigma = ( + eta + * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) + * th.sqrt(1 - alpha_bar / alpha_bar_prev) + ) + # Equation 12. + noise = th.randn_like(x) + mean_pred = ( + out["pred_xstart"] * th.sqrt(alpha_bar_prev) + + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps + ) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def ddim_reverse_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t+1} from the model using DDIM reverse ODE. + """ + assert eta == 0.0, "Reverse ODE only for deterministic path" + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x + - out["pred_xstart"] + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) + alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) + + # Equation 12. reversed + mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps + + return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} + + def ddim_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Generate samples from the model using DDIM. + Same usage as p_sample_loop(). + """ + final = None + for sample in self.ddim_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, + ): + final = sample + return final["sample"] + + def ddim_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Use DDIM to sample from the model and yield intermediate samples from + each timestep of DDIM. + Same usage as p_sample_loop_progressive(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.ddim_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + eta=eta, + ) + yield out + img = out["sample"] + + def _vb_terms_bpd( + self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None, mask=1.0, + ): + """ + Get a term for the variational lower-bound. + The resulting units are bits (rather than nats, as one might expect). + This allows for comparison to other papers. + :return: a dict with the following keys: + - 'output': a shape [N] tensor of NLLs or KLs. + - 'pred_xstart': the x_0 predictions. + """ + # import ipdb;ipdb.set_trace() + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + ) + out = self.p_mean_variance( + model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs + ) + kl = normal_kl( + true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] + ) + kl = mean_flat(kl * mask) / np.log(2.0) + + decoder_nll = -discretized_gaussian_log_likelihood( + x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] + ) + assert decoder_nll.shape == x_start.shape + decoder_nll = mean_flat(decoder_nll * mask) / np.log(2.0) + + # At the first timestep return the decoder NLL, + # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) + output = th.where((t == 0), decoder_nll, kl) + return {"output": output, "pred_xstart": out["pred_xstart"]} + + def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): + """ + Compute training losses for a single timestep. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. + """ + if model_kwargs is None: + model_kwargs = {} + mask = 1.0 + else: + mask = model_kwargs['attention_mask'].unsqueeze(1) # b t h w -> b 1 t h w + + if noise is None: + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start, t, noise=noise) + + terms = {} + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + terms["loss"] = self._vb_terms_bpd( + model=model, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + model_kwargs=model_kwargs, + )["output"] + if self.loss_type == LossType.RESCALED_KL: + terms["loss"] *= self.num_timesteps + elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + model_output = model(x_t, t, **model_kwargs) + # try: + # model_output = model(x_t, t, **model_kwargs).sample # for tav unet + # except: + # model_output = model(x_t, t, **model_kwargs) + + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + #B, F, C = x_t.shape[:3] + #assert model_output.shape == (B, F, C * 2, *x_t.shape[3:]) + #the output shape of uncondition or class condition latte is not the same as the latte_t2v + #BFCHW vs BCFHW + B, C, F = x_t.shape[:3] + assert model_output[0].shape == (B, C * 2, F, *x_t.shape[3:]) + #model_output, model_var_values = th.split(model_output, C, dim=2) + model_output, model_var_values = th.split(model_output[0], C, dim=1) + + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + #frozen_out = th.cat([model_output.detach(), model_var_values], dim=2) + frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + mask=mask + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. + terms["vb"] *= self.num_timesteps / 1000.0 + + target = { + ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + )[0], + ModelMeanType.START_X: x_start, + ModelMeanType.EPSILON: noise, + }[self.model_mean_type] + assert model_output.shape == target.shape == x_start.shape + terms["mse"] = mean_flat(((target - model_output) ** 2) * mask) + # import ipdb;ipdb.set_trace() + if "vb" in terms: + terms["loss"] = terms["mse"] + terms["vb"] + else: + terms["loss"] = terms["mse"] + else: + raise NotImplementedError(self.loss_type) + + return terms + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl( + mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 + ) + return mean_flat(kl_prior) / np.log(2.0) + + def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): + """ + Compute the entire variational lower-bound, measured in bits-per-dim, + as well as other related quantities. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param clip_denoised: if True, clip denoised samples. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - total_bpd: the total variational lower-bound, per batch element. + - prior_bpd: the prior term in the lower-bound. + - vb: an [N x T] tensor of terms in the lower-bound. + - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. + - mse: an [N x T] tensor of epsilon MSEs for each timestep. + """ + device = x_start.device + batch_size = x_start.shape[0] + + vb = [] + xstart_mse = [] + mse = [] + for t in list(range(self.num_timesteps))[::-1]: + t_batch = th.tensor([t] * batch_size, device=device) + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) + # Calculate VLB term at the current timestep + with th.no_grad(): + out = self._vb_terms_bpd( + model, + x_start=x_start, + x_t=x_t, + t=t_batch, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + ) + vb.append(out["output"]) + xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) + eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) + mse.append(mean_flat((eps - noise) ** 2)) + + vb = th.stack(vb, dim=1) + xstart_mse = th.stack(xstart_mse, dim=1) + mse = th.stack(mse, dim=1) + + prior_bpd = self._prior_bpd(x_start) + total_bpd = vb.sum(dim=1) + prior_bpd + return { + "total_bpd": total_bpd, + "prior_bpd": prior_bpd, + "vb": vb, + "xstart_mse": xstart_mse, + "mse": mse, + } + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. + """ + res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res + th.zeros(broadcast_shape, device=timesteps.device) diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/diffusion/respace.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/diffusion/respace.py new file mode 100644 index 0000000000000000000000000000000000000000..aed6ed77f3dd6d38f15e450058ce7fc13d5dc3dc --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/diffusion/respace.py @@ -0,0 +1,198 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +import torch +import numpy as np +import torch as th + +from .gaussian_diffusion import GaussianDiffusion +from .gaussian_diffusion_t2v import GaussianDiffusion_T + + +def space_timesteps(num_timesteps, section_counts): + """ + Create a list of timesteps to use from an original diffusion process, + given the number of timesteps we want to take from equally-sized portions + of the original process. + For example, if there's 300 timesteps and the section counts are [10,15,20] + then the first 100 timesteps are strided to be 10 timesteps, the second 100 + are strided to be 15 timesteps, and the final 100 are strided to be 20. + If the stride is a string starting with "ddim", then the fixed striding + from the DDIM paper is used, and only one section is allowed. + :param num_timesteps: the number of diffusion steps in the original + process to divide up. + :param section_counts: either a list of numbers, or a string containing + comma-separated numbers, indicating the step count + per section. As a special case, use "ddimN" where N + is a number of steps to use the striding from the + DDIM paper. + :return: a set of diffusion steps from the original process to use. + """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError( + f"cannot create exactly {num_timesteps} steps with an integer stride" + ) + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError( + f"cannot divide section of {size} steps into {section_count}" + ) + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +class SpacedDiffusion(GaussianDiffusion): + """ + A diffusion process which can skip steps in a base diffusion process. + :param use_timesteps: a collection (sequence or set) of timesteps from the + original diffusion process to retain. + :param kwargs: the kwargs to create the base diffusion process. + """ + + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + + base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = np.array(new_betas) + super().__init__(**kwargs) + + def p_mean_variance( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + # @torch.compile + def training_losses( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().training_losses(self._wrap_model(model), *args, **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) + + def _wrap_model(self, model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel( + model, self.timestep_map, self.original_num_steps + ) + + def _scale_timesteps(self, t): + # Scaling is done by the wrapped model. + return t + + +class _WrappedModel: + def __init__(self, model, timestep_map, original_num_steps): + self.model = model + self.timestep_map = timestep_map + # self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def __call__(self, x, ts, **kwargs): + map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) + new_ts = map_tensor[ts] + # if self.rescale_timesteps: + # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) + return self.model(x, new_ts, **kwargs) + +class SpacedDiffusion_T(GaussianDiffusion_T): + """ + A diffusion process which can skip steps in a base diffusion process. + :param use_timesteps: a collection (sequence or set) of timesteps from the + original diffusion process to retain. + :param kwargs: the kwargs to create the base diffusion process. + """ + + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + + base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = np.array(new_betas) + super().__init__(**kwargs) + + def p_mean_variance( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + # @torch.compile + def training_losses( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().training_losses(self._wrap_model(model), *args, **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) + + def _wrap_model(self, model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel( + model, self.timestep_map, self.original_num_steps + ) + + def _scale_timesteps(self, t): + # Scaling is done by the wrapped model. + return t + + +class _WrappedModel: + def __init__(self, model, timestep_map, original_num_steps): + self.model = model + self.timestep_map = timestep_map + # self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def __call__(self, x, ts, **kwargs): + map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) + new_ts = map_tensor[ts] + # if self.rescale_timesteps: + # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) + return self.model(x, new_ts, **kwargs) \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/diffusion/timestep_sampler.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/diffusion/timestep_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..a3f369847677d8dbaaadb8297691b1be92cf189f --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/diffusion/timestep_sampler.py @@ -0,0 +1,150 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +from abc import ABC, abstractmethod + +import numpy as np +import torch as th +import torch.distributed as dist + + +def create_named_schedule_sampler(name, diffusion): + """ + Create a ScheduleSampler from a library of pre-defined samplers. + :param name: the name of the sampler. + :param diffusion: the diffusion object to sample for. + """ + if name == "uniform": + return UniformSampler(diffusion) + elif name == "loss-second-moment": + return LossSecondMomentResampler(diffusion) + else: + raise NotImplementedError(f"unknown schedule sampler: {name}") + + +class ScheduleSampler(ABC): + """ + A distribution over timesteps in the diffusion process, intended to reduce + variance of the objective. + By default, samplers perform unbiased importance sampling, in which the + objective's mean is unchanged. + However, subclasses may override sample() to change how the resampled + terms are reweighted, allowing for actual changes in the objective. + """ + + @abstractmethod + def weights(self): + """ + Get a numpy array of weights, one per diffusion step. + The weights needn't be normalized, but must be positive. + """ + + def sample(self, batch_size, device): + """ + Importance-sample timesteps for a batch. + :param batch_size: the number of timesteps. + :param device: the torch device to save to. + :return: a tuple (timesteps, weights): + - timesteps: a tensor of timestep indices. + - weights: a tensor of weights to scale the resulting losses. + """ + w = self.weights() + p = w / np.sum(w) + indices_np = np.random.choice(len(p), size=(batch_size,), p=p) + indices = th.from_numpy(indices_np).long().to(device) + weights_np = 1 / (len(p) * p[indices_np]) + weights = th.from_numpy(weights_np).float().to(device) + return indices, weights + + +class UniformSampler(ScheduleSampler): + def __init__(self, diffusion): + self.diffusion = diffusion + self._weights = np.ones([diffusion.num_timesteps]) + + def weights(self): + return self._weights + + +class LossAwareSampler(ScheduleSampler): + def update_with_local_losses(self, local_ts, local_losses): + """ + Update the reweighting using losses from a model. + Call this method from each rank with a batch of timesteps and the + corresponding losses for each of those timesteps. + This method will perform synchronization to make sure all of the ranks + maintain the exact same reweighting. + :param local_ts: an integer Tensor of timesteps. + :param local_losses: a 1D Tensor of losses. + """ + batch_sizes = [ + th.tensor([0], dtype=th.int32, device=local_ts.device) + for _ in range(dist.get_world_size()) + ] + dist.all_gather( + batch_sizes, + th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), + ) + + # Pad all_gather batches to be the maximum batch size. + batch_sizes = [x.item() for x in batch_sizes] + max_bs = max(batch_sizes) + + timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] + loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] + dist.all_gather(timestep_batches, local_ts) + dist.all_gather(loss_batches, local_losses) + timesteps = [ + x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] + ] + losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] + self.update_with_all_losses(timesteps, losses) + + @abstractmethod + def update_with_all_losses(self, ts, losses): + """ + Update the reweighting using losses from a model. + Sub-classes should override this method to update the reweighting + using losses from the model. + This method directly updates the reweighting without synchronizing + between workers. It is called by update_with_local_losses from all + ranks with identical arguments. Thus, it should have deterministic + behavior to maintain state across workers. + :param ts: a list of int timesteps. + :param losses: a list of float losses, one per timestep. + """ + + +class LossSecondMomentResampler(LossAwareSampler): + def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): + self.diffusion = diffusion + self.history_per_term = history_per_term + self.uniform_prob = uniform_prob + self._loss_history = np.zeros( + [diffusion.num_timesteps, history_per_term], dtype=np.float64 + ) + self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) + + def weights(self): + if not self._warmed_up(): + return np.ones([self.diffusion.num_timesteps], dtype=np.float64) + weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) + weights /= np.sum(weights) + weights *= 1 - self.uniform_prob + weights += self.uniform_prob / len(weights) + return weights + + def update_with_all_losses(self, ts, losses): + for t, loss in zip(ts, losses): + if self._loss_counts[t] == self.history_per_term: + # Shift out the oldest loss term. + self._loss_history[t, :-1] = self._loss_history[t, 1:] + self._loss_history[t, -1] = loss + else: + self._loss_history[t, self._loss_counts[t]] = loss + self._loss_counts[t] += 1 + + def _warmed_up(self): + return (self._loss_counts == self.history_per_term).all() diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/latte/__init__.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/latte/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/latte/modeling_latte.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/latte/modeling_latte.py new file mode 100644 index 0000000000000000000000000000000000000000..578e44abeaa79417cd2a78788a4dab9ab9447b50 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/latte/modeling_latte.py @@ -0,0 +1,679 @@ +import torch + +import os +import json + +from dataclasses import dataclass +from einops import rearrange, repeat +from typing import Any, Dict, Optional, Tuple +from diffusers.models import Transformer2DModel +from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate +from diffusers.models.embeddings import get_1d_sincos_pos_embed_from_grid, ImagePositionalEmbeddings +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear + +import torch +import torch.nn.functional as F +from torch import nn + +# from opensora_plan.models.diffusion.utils.pos_embed import get_1d_sincos_pos_embed, PositionGetter1D, PositionGetter2D +# from opensora_plan.models.diffusion.latte.modules import PatchEmbed, BasicTransformerBlock, BasicTransformerBlock_, AdaLayerNormSingle, \ +# Transformer3DModelOutput, CaptionProjection +from ..utils.pos_embed import get_1d_sincos_pos_embed, PositionGetter1D, PositionGetter2D +from .modules import PatchEmbed, BasicTransformerBlock, BasicTransformerBlock_, AdaLayerNormSingle, \ + Transformer3DModelOutput, CaptionProjection + + +class LatteT2V(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + """ + A 2D Transformer model for image-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + num_vector_embeds (`int`, *optional*): + The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. + + During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + patch_size_t: int = 1, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + attention_type: str = "default", + caption_channels: int = None, + video_length: int = 16, + attention_mode: str = 'flash', + use_rope: bool = False, + model_max_length: int = 300, + rope_scaling_type: str = 'linear', + compress_kv_factor: int = 1, + interpolation_scale_1d: float = None, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + self.video_length = video_length + self.use_rope = use_rope + self.model_max_length = model_max_length + self.compress_kv_factor = compress_kv_factor + self.num_layers = num_layers + self.config.hidden_size = model_max_length + + assert not (self.compress_kv_factor != 1 and use_rope), "Can not both enable compressing kv and using rope" + + conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv + linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear + + # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = (in_channels is not None) and (patch_size is None) + self.is_input_vectorized = num_vector_embeds is not None + # self.is_input_patches = in_channels is not None and patch_size is not None + self.is_input_patches = True + + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = ( + f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" + " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." + " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" + " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" + " would be very nice if you could open a Pull request for the `transformer/config.json` file" + ) + deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) + norm_type = "ada_norm" + + # 2. Define input layers + assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" + + self.height = sample_size[0] + self.width = sample_size[1] + + self.patch_size = patch_size + interpolation_scale_2d = self.config.sample_size[0] // 64 # => 64 (= 512 pixart) has interpolation scale 1 + interpolation_scale_2d = max(interpolation_scale_2d, 1) + self.pos_embed = PatchEmbed( + height=sample_size[0], + width=sample_size[1], + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + interpolation_scale=interpolation_scale_2d, + ) + + + # define temporal positional embedding + if interpolation_scale_1d is None: + if self.config.video_length % 2 == 1: + interpolation_scale_1d = (self.config.video_length - 1) // 16 # => 16 (= 16 Latte) has interpolation scale 1 + else: + interpolation_scale_1d = self.config.video_length // 16 # => 16 (= 16 Latte) has interpolation scale 1 + # interpolation_scale_1d = self.config.video_length // 5 # + interpolation_scale_1d = max(interpolation_scale_1d, 1) + temp_pos_embed = get_1d_sincos_pos_embed(inner_dim, video_length, interpolation_scale=interpolation_scale_1d) # 1152 hidden size + self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False) + + rope_scaling = None + if self.use_rope: + self.position_getter_2d = PositionGetter2D() + self.position_getter_1d = PositionGetter1D() + rope_scaling = dict(type=rope_scaling_type, factor_2d=interpolation_scale_2d, factor_1d=interpolation_scale_1d) + + # 3. Define transformers blocks, spatial attention + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + double_self_attention=double_self_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + attention_type=attention_type, + attention_mode=attention_mode, + use_rope=use_rope, + rope_scaling=rope_scaling, + compress_kv_factor=(compress_kv_factor, compress_kv_factor) if d >= num_layers // 2 and compress_kv_factor != 1 else None, # follow pixart-sigma, apply in second-half layers + ) + for d in range(num_layers) + ] + ) + + # Define temporal transformers blocks + self.temporal_transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock_( # one attention + inner_dim, + num_attention_heads, # num_attention_heads + attention_head_dim, # attention_head_dim 72 + dropout=dropout, + cross_attention_dim=None, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + double_self_attention=False, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + attention_type=attention_type, + attention_mode=attention_mode, + use_rope=use_rope, + rope_scaling=rope_scaling, + compress_kv_factor=(compress_kv_factor, ) if d >= num_layers // 2 and compress_kv_factor != 1 else None, # follow pixart-sigma, apply in second-half layers + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + if self.is_input_continuous: + # TODO: should use out_channels for continuous projections + if use_linear_projection: + self.proj_out = linear_cls(inner_dim, in_channels) + else: + self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + elif self.is_input_patches and norm_type != "ada_norm_single": + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) + self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + elif self.is_input_patches and norm_type == "ada_norm_single": + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim ** 0.5) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + + # 5. PixArt-Alpha blocks. + self.adaln_single = None + self.use_additional_conditions = False + if norm_type == "ada_norm_single": + # self.use_additional_conditions = self.config.sample_size[0] == 128 # False, 128 -> 1024 + # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use + # additional conditions until we find better name + self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions) + + self.caption_projection = None + if caption_channels is not None: + self.caption_projection = CaptionProjection(in_features=caption_channels, hidden_size=inner_dim) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + self.gradient_checkpointing = value + + def make_position(self, b, t, use_image_num, h, w, device): + pos_hw = self.position_getter_2d(b*(t+use_image_num), h, w, device) # fake_b = b*(t+use_image_num) + pos_t = self.position_getter_1d(b*h*w, t, device) # fake_b = b*h*w + return pos_hw, pos_t + + def make_attn_mask(self, attention_mask, frame, dtype): + attention_mask = rearrange(attention_mask, 'b t h w -> (b t) 1 (h w)') + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(dtype)) * -10000.0 + attention_mask = attention_mask.to(self.dtype) + return attention_mask + + def vae_to_diff_mask(self, attention_mask, use_image_num): + dtype = attention_mask.dtype + # b, t+use_image_num, h, w, assume t as channel + # this version do not use 3d patch embedding + attention_mask = F.max_pool2d(attention_mask, kernel_size=(self.patch_size, self.patch_size), stride=(self.patch_size, self.patch_size)) + attention_mask = attention_mask.bool().to(dtype) + return attention_mask + + def forward( + self, + hidden_states: torch.Tensor, + timestep: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + use_image_num: int = 0, + enable_temporal_attentions: bool = True, + return_dict: bool = True, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, frame, channel, height, width)` if continuous): + Input `hidden_states`. + encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + input_batch_size, c, frame, h, w = hidden_states.shape + frame = frame - use_image_num # 20-4=16 + hidden_states = rearrange(hidden_states, 'b c f h w -> (b f) c h w').contiguous() + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is None: + attention_mask = torch.ones((input_batch_size, frame+use_image_num, h, w), device=hidden_states.device, dtype=hidden_states.dtype) + attention_mask = self.vae_to_diff_mask(attention_mask, use_image_num) + dtype = attention_mask.dtype + attention_mask_compress = F.max_pool2d(attention_mask.float(), kernel_size=self.compress_kv_factor, stride=self.compress_kv_factor) + attention_mask_compress = attention_mask_compress.to(dtype) + + attention_mask = self.make_attn_mask(attention_mask, frame, hidden_states.dtype) + attention_mask_compress = self.make_attn_mask(attention_mask_compress, frame, hidden_states.dtype) + + # 1 + 4, 1 -> video condition, 4 -> image condition + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: # ndim == 2 means no image joint + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + encoder_attention_mask = repeat(encoder_attention_mask, 'b 1 l -> (b f) 1 l', f=frame).contiguous() + encoder_attention_mask = encoder_attention_mask.to(self.dtype) + elif encoder_attention_mask is not None and encoder_attention_mask.ndim == 3: # ndim == 3 means image joint + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask_video = encoder_attention_mask[:, :1, ...] + encoder_attention_mask_video = repeat(encoder_attention_mask_video, 'b 1 l -> b (1 f) l', + f=frame).contiguous() + encoder_attention_mask_image = encoder_attention_mask[:, 1:, ...] + encoder_attention_mask = torch.cat([encoder_attention_mask_video, encoder_attention_mask_image], dim=1) + encoder_attention_mask = rearrange(encoder_attention_mask, 'b n l -> (b n) l').contiguous().unsqueeze(1) + encoder_attention_mask = encoder_attention_mask.to(self.dtype) + + # Retrieve lora scale. + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + # 1. Input + if self.is_input_patches: # here + height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size + hw = (height, width) + num_patches = height * width + + hidden_states = self.pos_embed(hidden_states.to(self.dtype)) # alrady add positional embeddings + + if self.adaln_single is not None: + if self.use_additional_conditions and added_cond_kwargs is None: + raise ValueError( + "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." + ) + # batch_size = hidden_states.shape[0] + batch_size = input_batch_size + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + + # 2. Blocks + if self.caption_projection is not None: + batch_size = hidden_states.shape[0] + encoder_hidden_states = self.caption_projection(encoder_hidden_states.to(self.dtype)) # 3 120 1152 + + if use_image_num != 0 and self.training: + encoder_hidden_states_video = encoder_hidden_states[:, :1, ...] + encoder_hidden_states_video = repeat(encoder_hidden_states_video, 'b 1 t d -> b (1 f) t d', f=frame).contiguous() + encoder_hidden_states_image = encoder_hidden_states[:, 1:, ...] + encoder_hidden_states = torch.cat([encoder_hidden_states_video, encoder_hidden_states_image], dim=1) + encoder_hidden_states_spatial = rearrange(encoder_hidden_states, 'b f t d -> (b f) t d').contiguous() + else: + encoder_hidden_states_spatial = repeat(encoder_hidden_states, 'b 1 t d -> (b f) t d', f=frame).contiguous() + + # prepare timesteps for spatial and temporal block + timestep_spatial = repeat(timestep, 'b d -> (b f) d', f=frame + use_image_num).contiguous() + timestep_temp = repeat(timestep, 'b d -> (b p) d', p=num_patches).contiguous() + + pos_hw, pos_t = None, None + if self.use_rope: + pos_hw, pos_t = self.make_position(input_batch_size, frame, use_image_num, height, width, hidden_states.device) + + for i, (spatial_block, temp_block) in enumerate(zip(self.transformer_blocks, self.temporal_transformer_blocks)): + + if self.training and self.gradient_checkpointing: + hidden_states = torch.utils.checkpoint.checkpoint( + spatial_block, + hidden_states, + attention_mask_compress if i >= self.num_layers // 2 else attention_mask, + encoder_hidden_states_spatial, + encoder_attention_mask, + timestep_spatial, + cross_attention_kwargs, + class_labels, + pos_hw, + pos_hw, + hw, + use_reentrant=False, + ) + + if enable_temporal_attentions: + hidden_states = rearrange(hidden_states, '(b f) t d -> (b t) f d', b=input_batch_size).contiguous() + + if use_image_num != 0: # image-video joitn training + hidden_states_video = hidden_states[:, :frame, ...] + hidden_states_image = hidden_states[:, frame:, ...] + + # if i == 0 and not self.use_rope: + if i == 0: + hidden_states_video = hidden_states_video + self.temp_pos_embed + + hidden_states_video = torch.utils.checkpoint.checkpoint( + temp_block, + hidden_states_video, + None, # attention_mask + None, # encoder_hidden_states + None, # encoder_attention_mask + timestep_temp, + cross_attention_kwargs, + class_labels, + pos_t, + pos_t, + (frame, ), + use_reentrant=False, + ) + + hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1) + hidden_states = rearrange(hidden_states, '(b t) f d -> (b f) t d', + b=input_batch_size).contiguous() + + else: + # if i == 0 and not self.use_rope: + if i == 0: + hidden_states = hidden_states + self.temp_pos_embed + + hidden_states = torch.utils.checkpoint.checkpoint( + temp_block, + hidden_states, + None, # attention_mask + None, # encoder_hidden_states + None, # encoder_attention_mask + timestep_temp, + cross_attention_kwargs, + class_labels, + pos_t, + pos_t, + (frame, ), + use_reentrant=False, + ) + + hidden_states = rearrange(hidden_states, '(b t) f d -> (b f) t d', + b=input_batch_size).contiguous() + else: + hidden_states = spatial_block( + hidden_states, + attention_mask_compress if i >= self.num_layers // 2 else attention_mask, + encoder_hidden_states_spatial, + encoder_attention_mask, + timestep_spatial, + cross_attention_kwargs, + class_labels, + pos_hw, + pos_hw, + hw, + ) + + if enable_temporal_attentions: + # b c f h w, f = 16 + 4 + hidden_states = rearrange(hidden_states, '(b f) t d -> (b t) f d', b=input_batch_size).contiguous() + + if use_image_num != 0 and self.training: + hidden_states_video = hidden_states[:, :frame, ...] + hidden_states_image = hidden_states[:, frame:, ...] + + # if i == 0 and not self.use_rope: + # hidden_states_video = hidden_states_video + self.temp_pos_embed + + hidden_states_video = temp_block( + hidden_states_video, + None, # attention_mask + None, # encoder_hidden_states + None, # encoder_attention_mask + timestep_temp, + cross_attention_kwargs, + class_labels, + pos_t, + pos_t, + (frame, ), + ) + + hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1) + hidden_states = rearrange(hidden_states, '(b t) f d -> (b f) t d', + b=input_batch_size).contiguous() + + else: + # if i == 0 and not self.use_rope: + if i == 0: + hidden_states = hidden_states + self.temp_pos_embed + + hidden_states = temp_block( + hidden_states, + None, # attention_mask + None, # encoder_hidden_states + None, # encoder_attention_mask + timestep_temp, + cross_attention_kwargs, + class_labels, + pos_t, + pos_t, + (frame, ), + ) + + hidden_states = rearrange(hidden_states, '(b t) f d -> (b f) t d', + b=input_batch_size).contiguous() + + if self.is_input_patches: + if self.config.norm_type != "ada_norm_single": + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + hidden_states = self.proj_out_2(hidden_states) + elif self.config.norm_type == "ada_norm_single": + embedded_timestep = repeat(embedded_timestep, 'b d -> (b f) d', f=frame + use_image_num).contiguous() + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + + # unpatchify + if self.adaln_single is None: + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + output = rearrange(output, '(b f) c h w -> b c f h w', b=input_batch_size).contiguous() + + if not return_dict: + return (output,) + + return Transformer3DModelOutput(sample=output) + + @classmethod + def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, **kwargs): + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + + config_file = os.path.join(pretrained_model_path, 'config.json') + if not os.path.isfile(config_file): + raise RuntimeError(f"{config_file} does not exist") + with open(config_file, "r") as f: + config = json.load(f) + + model = cls.from_config(config, **kwargs) + return model + +# depth = num_layers * 2 +def LatteT2V_XL_122(**kwargs): + return LatteT2V(num_layers=28, attention_head_dim=72, num_attention_heads=16, patch_size_t=1, patch_size=2, + norm_type="ada_norm_single", caption_channels=4096, cross_attention_dim=1152, **kwargs) +def LatteT2V_D64_XL_122(**kwargs): + return LatteT2V(num_layers=28, attention_head_dim=64, num_attention_heads=18, patch_size_t=1, patch_size=2, + norm_type="ada_norm_single", caption_channels=4096, cross_attention_dim=1152, **kwargs) + +Latte_models = { + "LatteT2V-XL/122": LatteT2V_XL_122, + "LatteT2V-D64-XL/122": LatteT2V_D64_XL_122, +} + +if __name__ == '__main__': + from opensora.models.ae import ae_channel_config, ae_stride_config + from opensora.models.ae import getae, getae_wrapper + from opensora.models.ae.videobase import CausalVQVAEModelWrapper, CausalVAEModelWrapper + + args = type('args', (), + { + 'ae': 'CausalVAEModel_4x8x8', + 'attention_mode': 'xformers', + 'use_rope': False, + 'model_max_length': 300, + 'max_image_size': 512, + 'num_frames': 65, + 'use_image_num': 16, + 'compress_kv_factor': 1 + } + ) + b = 2 + c = 4 + cond_c = 4096 + num_timesteps = 1000 + ae_stride_t, ae_stride_h, ae_stride_w = ae_stride_config[args.ae] + latent_size = (args.max_image_size // ae_stride_h, args.max_image_size // ae_stride_w) + if getae_wrapper(args.ae) == CausalVQVAEModelWrapper or getae_wrapper(args.ae) == CausalVAEModelWrapper: + args.video_length = video_length = (args.num_frames - 1) // ae_stride_t + 1 + else: + video_length = args.num_frames // ae_stride_t + + device = torch.device('cuda:6') + model = LatteT2V_D64_XL_122( + in_channels=ae_channel_config[args.ae], + out_channels=ae_channel_config[args.ae] * 2, + # caption_channels=4096, + # cross_attention_dim=1152, + attention_bias=True, + sample_size=latent_size, + num_vector_embeds=None, + activation_fn="gelu-approximate", + num_embeds_ada_norm=1000, + use_linear_projection=False, + only_cross_attention=False, + double_self_attention=False, + upcast_attention=False, + # norm_type="ada_norm_single", + norm_elementwise_affine=False, + norm_eps=1e-6, + attention_type='default', + video_length=video_length, + attention_mode=args.attention_mode, + compress_kv_factor=args.compress_kv_factor, + use_rope=args.use_rope, + model_max_length=args.model_max_length, + ).to(device) + # try: + # ckpt = torch.load(r"t2v.pt", map_location='cpu')['model'] + # model.load_state_dict(ckpt) + # except Exception as e: + # print(e) + print(model) + + x = torch.randn(b, c, 1+(args.num_frames-1)//ae_stride_t+args.use_image_num, args.max_image_size//ae_stride_h, args.max_image_size//ae_stride_w).to(device) + cond = torch.randn(b, 1+args.use_image_num, args.model_max_length, cond_c).to(device) + attn_mask = torch.randint(0, 2, (b, 1+args.use_image_num, args.max_image_size//ae_stride_h//2, args.max_image_size//ae_stride_w//2)).to(device) # B L or B 1+num_images L + cond_mask = torch.randint(0, 2, (b, 1+args.use_image_num, args.model_max_length)).to(device) # B L or B 1+num_images L + timestep = torch.randint(0, 1000, (b,), device=device) + model_kwargs = dict(hidden_states=x, encoder_hidden_states=cond, attention_mask=attn_mask, + encoder_attention_mask=cond_mask, use_image_num=args.use_image_num, timestep=timestep) + with torch.no_grad(): + output = model(**model_kwargs) + # print(output) \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/latte/modules.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/latte/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..3b0e5c4fa3135a591a7a6b34e90a641c585ae3cd --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/latte/modules.py @@ -0,0 +1,1729 @@ +from importlib import import_module + +import numpy as np +from typing import Any, Dict, Optional, Tuple, Callable +from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_xformers_available +from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear + +import torch +import torch.nn.functional as F +from torch import nn +from diffusers.utils.torch_utils import maybe_allow_in_graph +from diffusers.models.embeddings import SinusoidalPositionalEmbedding, TimestepEmbedding, Timesteps +from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormZero +from diffusers.models.attention_processor import SpatialNorm, LORA_ATTENTION_PROCESSORS, \ + CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0, \ + AttnAddedKVProcessor, AttnAddedKVProcessor2_0, SlicedAttnAddedKVProcessor, XFormersAttnAddedKVProcessor, \ + LoRAAttnAddedKVProcessor, LoRAXFormersAttnProcessor, XFormersAttnProcessor, LoRAAttnProcessor2_0, LoRAAttnProcessor, \ + AttnProcessor, SlicedAttnProcessor, logger +from diffusers.models.activations import GEGLU, GELU, ApproximateGELU + +from dataclasses import dataclass + +from ..utils.pos_embed import get_2d_sincos_pos_embed, RoPE1D, RoPE2D, LinearScalingRoPE2D, LinearScalingRoPE1D + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +class CombinedTimestepSizeEmbeddings(nn.Module): + """ + For PixArt-Alpha. + + Reference: + https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29 + """ + + def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False): + super().__init__() + + self.outdim = size_emb_dim + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + self.use_additional_conditions = use_additional_conditions + if use_additional_conditions: + self.use_additional_conditions = True + self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) + self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) + + def apply_condition(self, size: torch.Tensor, batch_size: int, embedder: nn.Module): + if size.ndim == 1: + size = size[:, None] + + if size.shape[0] != batch_size: + size = size.repeat(batch_size // size.shape[0], 1) + if size.shape[0] != batch_size: + raise ValueError(f"`batch_size` should be {size.shape[0]} but found {batch_size}.") + + current_batch_size, dims = size.shape[0], size.shape[1] + size = size.reshape(-1) + size_freq = self.additional_condition_proj(size).to(size.dtype) + + size_emb = embedder(size_freq) + size_emb = size_emb.reshape(current_batch_size, dims * self.outdim) + return size_emb + + def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) + + if self.use_additional_conditions: + resolution = self.apply_condition(resolution, batch_size=batch_size, embedder=self.resolution_embedder) + aspect_ratio = self.apply_condition( + aspect_ratio, batch_size=batch_size, embedder=self.aspect_ratio_embedder + ) + conditioning = timesteps_emb + torch.cat([resolution, aspect_ratio], dim=1) + else: + conditioning = timesteps_emb + + return conditioning + +class CaptionProjection(nn.Module): + """ + Projects caption embeddings. Also handles dropout for classifier-free guidance. + + Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py + """ + + def __init__(self, in_features, hidden_size, num_tokens=120): + super().__init__() + self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True) + self.act_1 = nn.GELU(approximate="tanh") + self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True) + self.register_buffer("y_embedding", nn.Parameter(torch.randn(num_tokens, in_features) / in_features**0.5)) + + def forward(self, caption, force_drop_ids=None): + hidden_states = self.linear_1(caption) + hidden_states = self.act_1(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding""" + + def __init__( + self, + height=224, + width=224, + patch_size=16, + in_channels=3, + embed_dim=768, + layer_norm=False, + flatten=True, + bias=True, + interpolation_scale=1, + ): + super().__init__() + + num_patches = (height // patch_size) * (width // patch_size) + self.flatten = flatten + self.layer_norm = layer_norm + + self.proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias + ) + if layer_norm: + self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) + else: + self.norm = None + + self.patch_size = patch_size + # See: + # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161 + self.height, self.width = height // patch_size, width // patch_size + + self.base_size = height // patch_size + self.interpolation_scale = interpolation_scale + pos_embed = get_2d_sincos_pos_embed( + embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale + ) + self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False) + + def forward(self, latent): + height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size + + latent = self.proj(latent) + if self.flatten: + latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC + if self.layer_norm: + latent = self.norm(latent) + # Interpolate positional embeddings if needed. + # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160) + if self.height != height or self.width != width: + # raise ValueError + pos_embed = get_2d_sincos_pos_embed( + embed_dim=self.pos_embed.shape[-1], + grid_size=(height, width), + base_size=self.base_size, + interpolation_scale=self.interpolation_scale, + ) + pos_embed = torch.from_numpy(pos_embed) + pos_embed = pos_embed.float().unsqueeze(0).to(latent.device) + else: + pos_embed = self.pos_embed + return (latent + pos_embed).to(latent.dtype) + + +@maybe_allow_in_graph +class Attention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): + The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): + The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): + The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + upcast_attention (`bool`, *optional*, defaults to False): + Set to `True` to upcast the attention computation to `float32`. + upcast_softmax (`bool`, *optional*, defaults to False): + Set to `True` to upcast the softmax computation to `float32`. + cross_attention_norm (`str`, *optional*, defaults to `None`): + The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. + cross_attention_norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the group norm in the cross attention. + added_kv_proj_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the added key and value projections. If `None`, no projection is used. + norm_num_groups (`int`, *optional*, defaults to `None`): + The number of groups to use for the group norm in the attention. + spatial_norm_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the spatial normalization. + out_bias (`bool`, *optional*, defaults to `True`): + Set to `True` to use a bias in the output linear layer. + scale_qk (`bool`, *optional*, defaults to `True`): + Set to `True` to scale the query and key by `1 / sqrt(dim_head)`. + only_cross_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if + `added_kv_proj_dim` is not `None`. + eps (`float`, *optional*, defaults to 1e-5): + An additional value added to the denominator in group normalization that is used for numerical stability. + rescale_output_factor (`float`, *optional*, defaults to 1.0): + A factor to rescale the output by dividing it with this value. + residual_connection (`bool`, *optional*, defaults to `False`): + Set to `True` to add the residual connection to the output. + _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`): + Set to `True` if the attention block is loaded from a deprecated state dict. + processor (`AttnProcessor`, *optional*, defaults to `None`): + The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and + `AttnProcessor` otherwise. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor"] = None, + attention_mode: str = 'xformers', + use_rope: bool = False, + rope_scaling: Optional[Dict] = None, + compress_kv_factor: Optional[Tuple] = None, + ): + super().__init__() + self.inner_dim = dim_head * heads + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection + self.dropout = dropout + self.use_rope = use_rope + self.rope_scaling = rope_scaling + self.compress_kv_factor = compress_kv_factor + + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + self.heads = heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + + self.added_kv_proj_dim = added_kv_proj_dim + self.only_cross_attention = only_cross_attention + + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." + ) + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) + else: + self.group_norm = None + + if spatial_norm_dim is not None: + self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim) + else: + self.spatial_norm = None + + if cross_attention_norm is None: + self.norm_cross = None + elif cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(self.cross_attention_dim) + elif cross_attention_norm == "group_norm": + if self.added_kv_proj_dim is not None: + # The given `encoder_hidden_states` are initially of shape + # (batch_size, seq_len, added_kv_proj_dim) before being projected + # to (batch_size, seq_len, cross_attention_dim). The norm is applied + # before the projection, so we need to use `added_kv_proj_dim` as + # the number of channels for the group norm. + norm_cross_num_channels = added_kv_proj_dim + else: + norm_cross_num_channels = self.cross_attention_dim + + self.norm_cross = nn.GroupNorm( + num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True + ) + else: + raise ValueError( + f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" + ) + + if USE_PEFT_BACKEND: + linear_cls = nn.Linear + else: + linear_cls = LoRACompatibleLinear + + assert not (self.use_rope and (self.compress_kv_factor is not None)), "Can not both enable compressing kv and using rope" + if self.compress_kv_factor is not None: + self._init_compress() + + self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias) + + if not self.only_cross_attention: + # only relevant for the `AddedKVProcessor` classes + self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias) + self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias) + else: + self.to_k = None + self.to_v = None + + if self.added_kv_proj_dim is not None: + self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim) + self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(linear_cls(self.inner_dim, query_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + if processor is None: + processor = ( + AttnProcessor2_0(self.inner_dim, attention_mode, use_rope, rope_scaling=rope_scaling, compress_kv_factor=compress_kv_factor) if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) + self.set_processor(processor) + + def set_use_memory_efficient_attention_xformers( + self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None + ) -> None: + r""" + Set whether to use memory efficient attention from `xformers` or not. + + Args: + use_memory_efficient_attention_xformers (`bool`): + Whether to use memory efficient attention from `xformers` or not. + attention_op (`Callable`, *optional*): + The attention operation to use. Defaults to `None` which uses the default attention operation from + `xformers`. + """ + is_lora = hasattr(self, "processor") and isinstance( + self.processor, + LORA_ATTENTION_PROCESSORS, + ) + is_custom_diffusion = hasattr(self, "processor") and isinstance( + self.processor, + (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0), + ) + is_added_kv_processor = hasattr(self, "processor") and isinstance( + self.processor, + ( + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + SlicedAttnAddedKVProcessor, + XFormersAttnAddedKVProcessor, + LoRAAttnAddedKVProcessor, + ), + ) + + if use_memory_efficient_attention_xformers: + if is_added_kv_processor and (is_lora or is_custom_diffusion): + raise NotImplementedError( + f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}" + ) + if not is_xformers_available(): + raise ModuleNotFoundError( + ( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers" + ), + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" + " only available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + + if is_lora: + # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers + # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0? + processor = LoRAXFormersAttnProcessor( + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + rank=self.processor.rank, + attention_op=attention_op, + ) + processor.load_state_dict(self.processor.state_dict()) + processor.to(self.processor.to_q_lora.up.weight.device) + elif is_custom_diffusion: + processor = CustomDiffusionXFormersAttnProcessor( + train_kv=self.processor.train_kv, + train_q_out=self.processor.train_q_out, + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + attention_op=attention_op, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_custom_diffusion"): + processor.to(self.processor.to_k_custom_diffusion.weight.device) + elif is_added_kv_processor: + # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP + # which uses this type of cross attention ONLY because the attention mask of format + # [0, ..., -10.000, ..., 0, ...,] is not supported + # throw warning + logger.info( + "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation." + ) + processor = XFormersAttnAddedKVProcessor(attention_op=attention_op) + else: + processor = XFormersAttnProcessor(attention_op=attention_op) + else: + if is_lora: + attn_processor_class = ( + LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor + ) + processor = attn_processor_class( + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + rank=self.processor.rank, + ) + processor.load_state_dict(self.processor.state_dict()) + processor.to(self.processor.to_q_lora.up.weight.device) + elif is_custom_diffusion: + attn_processor_class = ( + CustomDiffusionAttnProcessor2_0 + if hasattr(F, "scaled_dot_product_attention") + else CustomDiffusionAttnProcessor + ) + processor = attn_processor_class( + train_kv=self.processor.train_kv, + train_q_out=self.processor.train_q_out, + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_custom_diffusion"): + processor.to(self.processor.to_k_custom_diffusion.weight.device) + else: + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + processor = ( + AttnProcessor2_0() + if hasattr(F, "scaled_dot_product_attention") and self.scale_qk + else AttnProcessor() + ) + + self.set_processor(processor) + + def set_attention_slice(self, slice_size: int) -> None: + r""" + Set the slice size for attention computation. + + Args: + slice_size (`int`): + The slice size for attention computation. + """ + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + if slice_size is not None and self.added_kv_proj_dim is not None: + processor = SlicedAttnAddedKVProcessor(slice_size) + elif slice_size is not None: + processor = SlicedAttnProcessor(slice_size) + elif self.added_kv_proj_dim is not None: + processor = AttnAddedKVProcessor() + else: + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + processor = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) + + self.set_processor(processor) + + def set_processor(self, processor: "AttnProcessor", _remove_lora: bool = False) -> None: + r""" + Set the attention processor to use. + + Args: + processor (`AttnProcessor`): + The attention processor to use. + _remove_lora (`bool`, *optional*, defaults to `False`): + Set to `True` to remove LoRA layers from the model. + """ + if not USE_PEFT_BACKEND and hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None: + deprecate( + "set_processor to offload LoRA", + "0.26.0", + "In detail, removing LoRA layers via calling `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.", + ) + # TODO(Patrick, Sayak) - this can be deprecated once PEFT LoRA integration is complete + # We need to remove all LoRA layers + # Don't forget to remove ALL `_remove_lora` from the codebase + for module in self.modules(): + if hasattr(module, "set_lora_layer"): + module.set_lora_layer(None) + + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + + def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor": + r""" + Get the attention processor in use. + + Args: + return_deprecated_lora (`bool`, *optional*, defaults to `False`): + Set to `True` to return the deprecated LoRA attention processor. + + Returns: + "AttentionProcessor": The attention processor in use. + """ + if not return_deprecated_lora: + return self.processor + + # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible + # serialization format for LoRA Attention Processors. It should be deleted once the integration + # with PEFT is completed. + is_lora_activated = { + name: module.lora_layer is not None + for name, module in self.named_modules() + if hasattr(module, "lora_layer") + } + + # 1. if no layer has a LoRA activated we can return the processor as usual + if not any(is_lora_activated.values()): + return self.processor + + # If doesn't apply LoRA do `add_k_proj` or `add_v_proj` + is_lora_activated.pop("add_k_proj", None) + is_lora_activated.pop("add_v_proj", None) + # 2. else it is not posssible that only some layers have LoRA activated + if not all(is_lora_activated.values()): + raise ValueError( + f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}" + ) + + # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor + non_lora_processor_cls_name = self.processor.__class__.__name__ + lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name) + + hidden_size = self.inner_dim + + # now create a LoRA attention processor from the LoRA layers + if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]: + kwargs = { + "cross_attention_dim": self.cross_attention_dim, + "rank": self.to_q.lora_layer.rank, + "network_alpha": self.to_q.lora_layer.network_alpha, + "q_rank": self.to_q.lora_layer.rank, + "q_hidden_size": self.to_q.lora_layer.out_features, + "k_rank": self.to_k.lora_layer.rank, + "k_hidden_size": self.to_k.lora_layer.out_features, + "v_rank": self.to_v.lora_layer.rank, + "v_hidden_size": self.to_v.lora_layer.out_features, + "out_rank": self.to_out[0].lora_layer.rank, + "out_hidden_size": self.to_out[0].lora_layer.out_features, + } + + if hasattr(self.processor, "attention_op"): + kwargs["attention_op"] = self.processor.attention_op + + lora_processor = lora_processor_cls(hidden_size, **kwargs) + lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict()) + lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict()) + lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict()) + lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict()) + elif lora_processor_cls == LoRAAttnAddedKVProcessor: + lora_processor = lora_processor_cls( + hidden_size, + cross_attention_dim=self.add_k_proj.weight.shape[0], + rank=self.to_q.lora_layer.rank, + network_alpha=self.to_q.lora_layer.network_alpha, + ) + lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict()) + lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict()) + lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict()) + lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict()) + + # only save if used + if self.add_k_proj.lora_layer is not None: + lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict()) + lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict()) + else: + lora_processor.add_k_proj_lora = None + lora_processor.add_v_proj_lora = None + else: + raise ValueError(f"{lora_processor_cls} does not exist.") + + return lora_processor + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + **cross_attention_kwargs, + ) -> torch.Tensor: + r""" + The forward method of the `Attention` class. + + Args: + hidden_states (`torch.Tensor`): + The hidden states of the query. + encoder_hidden_states (`torch.Tensor`, *optional*): + The hidden states of the encoder. + attention_mask (`torch.Tensor`, *optional*): + The attention mask to use. If `None`, no mask is applied. + **cross_attention_kwargs: + Additional keyword arguments to pass along to the cross attention. + + Returns: + `torch.Tensor`: The output of the attention layer. + """ + # The `Attention` class can call different attention processors / attention functions + # here we simply pass along all tensors to the selected processor class + # For standard processors that are defined here, `**cross_attention_kwargs` is empty + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor: + r""" + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads` + is the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor: + r""" + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is + the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is + reshaped to `[batch_size * heads, seq_len, dim // heads]`. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3) + + if out_dim == 3: + tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) + + return tensor + + def get_attention_scores( + self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None + ) -> torch.Tensor: + r""" + Compute the attention scores. + + Args: + query (`torch.Tensor`): The query tensor. + key (`torch.Tensor`): The key tensor. + attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied. + + Returns: + `torch.Tensor`: The attention probabilities/scores. + """ + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + if attention_mask is None: + baddbmm_input = torch.empty( + query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device + ) + beta = 0 + else: + baddbmm_input = attention_mask + beta = 1 + + attention_scores = torch.baddbmm( + baddbmm_input, + query, + key.transpose(-1, -2), + beta=beta, + alpha=self.scale, + ) + del baddbmm_input + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + del attention_scores + + attention_probs = attention_probs.to(dtype) + + return attention_probs + + def prepare_attention_mask( + self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3 + ) -> torch.Tensor: + r""" + Prepare the attention mask for the attention computation. + + Args: + attention_mask (`torch.Tensor`): + The attention mask to prepare. + target_length (`int`): + The target length of the attention mask. This is the length of the attention mask after padding. + batch_size (`int`): + The batch size, which is used to repeat the attention mask. + out_dim (`int`, *optional*, defaults to `3`): + The output dimension of the attention mask. Can be either `3` or `4`. + + Returns: + `torch.Tensor`: The prepared attention mask. + """ + head_size = self.heads + if attention_mask is None: + return attention_mask + + current_length: int = attention_mask.shape[-1] + if current_length != target_length: + if attention_mask.device.type == "mps": + # HACK: MPS: Does not support padding by greater than dimension of input tensor. + # Instead, we can manually construct the padding tensor. + padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length) + padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) + attention_mask = torch.cat([attention_mask, padding], dim=2) + else: + # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: + # we want to instead pad by (0, remaining_length), where remaining_length is: + # remaining_length: int = target_length - current_length + # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave(head_size, dim=0) + elif out_dim == 4: + attention_mask = attention_mask.unsqueeze(1) + attention_mask = attention_mask.repeat_interleave(head_size, dim=1) + + return attention_mask + + def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + r""" + Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the + `Attention` class. + + Args: + encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder. + + Returns: + `torch.Tensor`: The normalized encoder hidden states. + """ + assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states" + + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + # Group norm norms along the channels dimension and expects + # input to be in the shape of (N, C, *). In this case, we want + # to norm along the hidden dimension, so we need to move + # (batch_size, sequence_length, hidden_size) -> + # (batch_size, hidden_size, sequence_length) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + else: + assert False + + return encoder_hidden_states + + def _init_compress(self): + if len(self.compress_kv_factor) == 2: + self.sr = nn.Conv2d(self.inner_dim, self.inner_dim, groups=self.inner_dim, kernel_size=self.compress_kv_factor, stride=self.compress_kv_factor) + self.sr.weight.data.fill_(1/self.compress_kv_factor[0]**2) + elif len(self.compress_kv_factor) == 1: + self.kernel_size = self.compress_kv_factor[0] + self.sr = nn.Conv1d(self.inner_dim, self.inner_dim, groups=self.inner_dim, kernel_size=self.compress_kv_factor[0], stride=self.compress_kv_factor[0]) + self.sr.weight.data.fill_(1/self.compress_kv_factor[0]) + self.sr.bias.data.zero_() + self.norm = nn.LayerNorm(self.inner_dim) + +class AttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self, dim=1152, attention_mode='xformers', use_rope=False, rope_scaling=None, compress_kv_factor=None): + self.dim = dim + self.attention_mode = attention_mode + self.use_rope = use_rope + self.rope_scaling = rope_scaling + self.compress_kv_factor = compress_kv_factor + if self.use_rope: + self._init_rope() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + + def _init_rope(self): + if self.rope_scaling is None: + self.rope2d = RoPE2D() + self.rope1d = RoPE1D() + else: + scaling_type = self.rope_scaling["type"] + scaling_factor_2d = self.rope_scaling["factor_2d"] + scaling_factor_1d = self.rope_scaling["factor_1d"] + if scaling_type == "linear": + self.rope2d = LinearScalingRoPE2D(scaling_factor=scaling_factor_2d) + self.rope1d = LinearScalingRoPE1D(scaling_factor=scaling_factor_1d) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + scale: float = 1.0, + position_q: Optional[torch.LongTensor] = None, + position_k: Optional[torch.LongTensor] = None, + last_shape: Tuple[int] = None, + ) -> torch.FloatTensor: + residual = hidden_states + + args = () if USE_PEFT_BACKEND else (scale,) + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + + + if self.compress_kv_factor is not None: + batch_size = hidden_states.shape[0] + if len(last_shape) == 2: + encoder_hidden_states = hidden_states.permute(0, 2, 1).reshape(batch_size, self.dim, *last_shape) + encoder_hidden_states = attn.sr(encoder_hidden_states).reshape(batch_size, self.dim, -1).permute(0, 2, 1) + elif len(last_shape) == 1: + encoder_hidden_states = hidden_states.permute(0, 2, 1) + if last_shape[0] % 2 == 1: + first_frame_pad = encoder_hidden_states[:, :, :1].repeat((1, 1, attn.kernel_size - 1)) + encoder_hidden_states = torch.concatenate((first_frame_pad, encoder_hidden_states), dim=2) + encoder_hidden_states = attn.sr(encoder_hidden_states).permute(0, 2, 1) + else: + raise NotImplementedError(f'NotImplementedError with last_shape {last_shape}') + + encoder_hidden_states = attn.norm(encoder_hidden_states) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + args = () if USE_PEFT_BACKEND else (scale,) + query = attn.to_q(hidden_states, *args) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + + + key = attn.to_k(encoder_hidden_states, *args) + value = attn.to_v(encoder_hidden_states, *args) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if self.use_rope: + # require the shape of (batch_size x nheads x ntokens x dim) + if position_q.ndim == 3: + query = self.rope2d(query, position_q) + elif position_q.ndim == 2: + query = self.rope1d(query, position_q) + else: + raise NotImplementedError + if position_k.ndim == 3: + key = self.rope2d(key, position_k) + elif position_k.ndim == 2: + key = self.rope1d(key, position_k) + else: + raise NotImplementedError + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + if self.attention_mode == 'flash': + assert attention_mask is None or torch.all(attention_mask.bool()), 'flash-attn do not support attention_mask' + with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False): + hidden_states = F.scaled_dot_product_attention( + query, key, value, dropout_p=0.0, is_causal=False + ) + elif self.attention_mode == 'xformers': + with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=False, enable_mem_efficient=True): + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + elif self.attention_mode == 'math': + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + else: + raise NotImplementedError(f'Found attention_mode: {self.attention_mode}') + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states, *args) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + +@maybe_allow_in_graph +class GatedSelfAttentionDense(nn.Module): + r""" + A gated self-attention dense layer that combines visual features and object features. + + Parameters: + query_dim (`int`): The number of channels in the query. + context_dim (`int`): The number of channels in the context. + n_heads (`int`): The number of heads to use for attention. + d_head (`int`): The number of channels in each head. + """ + + def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int): + super().__init__() + + # we need a linear projection since we need cat visual feature and obj feature + self.linear = nn.Linear(context_dim, query_dim) + + self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head) + self.ff = FeedForward(query_dim, activation_fn="geglu") + + self.norm1 = nn.LayerNorm(query_dim) + self.norm2 = nn.LayerNorm(query_dim) + + self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0))) + self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0))) + + self.enabled = True + + def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor: + if not self.enabled: + return x + + n_visual = x.shape[1] + objs = self.linear(objs) + + x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :] + x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x)) + + return x + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim) + if activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh") + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim) + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(linear_cls(inner_dim, dim_out)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear) + for module in self.net: + if isinstance(module, compatible_cls): + hidden_states = module(hidden_states, scale) + else: + hidden_states = module(hidden_states) + return hidden_states + + +@maybe_allow_in_graph +class BasicTransformerBlock_(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single' + norm_eps: float = 1e-5, + final_dropout: bool = False, + attention_type: str = "default", + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + attention_mode: str = "xformers", + use_rope: bool = False, + rope_scaling: Optional[Dict] = None, + compress_kv_factor: Optional[Tuple] = None, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + self.use_ada_layer_norm_single = norm_type == "ada_norm_single" + self.use_layer_norm = norm_type == "layer_norm" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." + ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) + else: + self.pos_embed = None + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + attention_mode=attention_mode, + use_rope=use_rope, + rope_scaling=rope_scaling, + compress_kv_factor=compress_kv_factor, + ) + + # # 2. Cross-Attn + # if cross_attention_dim is not None or double_self_attention: + # # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # # the second cross attention block. + # self.norm2 = ( + # AdaLayerNorm(dim, num_embeds_ada_norm) + # if self.use_ada_layer_norm + # else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + # ) + # self.attn2 = Attention( + # query_dim=dim, + # cross_attention_dim=cross_attention_dim if not double_self_attention else None, + # heads=num_attention_heads, + # dim_head=attention_head_dim, + # dropout=dropout, + # bias=attention_bias, + # upcast_attention=upcast_attention, + # ) # is self-attn if encoder_hidden_states is none + # else: + # self.norm2 = None + # self.attn2 = None + + # 3. Feed-forward + # if not self.use_ada_layer_norm_single: + # self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + + # 4. Fuser + if attention_type == "gated" or attention_type == "gated-text-image": + self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) + + # 5. Scale-shift for PixArt-Alpha. + if self.use_ada_layer_norm_single: + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim ** 0.5) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + position_q: Optional[torch.LongTensor] = None, + position_k: Optional[torch.LongTensor] = None, + frame: int = None, + ) -> torch.FloatTensor: + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + elif self.use_layer_norm: + norm_hidden_states = self.norm1(hidden_states) + elif self.use_ada_layer_norm_single: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + norm_hidden_states = norm_hidden_states.squeeze(1) + else: + raise ValueError("Incorrect norm used") + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + # 1. Retrieve lora scale. + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + # 2. Prepare GLIGEN inputs + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + position_q=position_q, + position_k=position_k, + last_shape=frame, + **cross_attention_kwargs, + ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.use_ada_layer_norm_single: + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 2.5 GLIGEN Control + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + + # # 3. Cross-Attention + # if self.attn2 is not None: + # if self.use_ada_layer_norm: + # norm_hidden_states = self.norm2(hidden_states, timestep) + # elif self.use_ada_layer_norm_zero or self.use_layer_norm: + # norm_hidden_states = self.norm2(hidden_states) + # elif self.use_ada_layer_norm_single: + # # For PixArt norm2 isn't applied here: + # # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 + # norm_hidden_states = hidden_states + # else: + # raise ValueError("Incorrect norm") + + # if self.pos_embed is not None and self.use_ada_layer_norm_single is False: + # norm_hidden_states = self.pos_embed(norm_hidden_states) + + # attn_output = self.attn2( + # norm_hidden_states, + # encoder_hidden_states=encoder_hidden_states, + # attention_mask=encoder_attention_mask, + # **cross_attention_kwargs, + # ) + # hidden_states = attn_output + hidden_states + + # 4. Feed-forward + # if not self.use_ada_layer_norm_single: + # norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self.use_ada_layer_norm_single: + # norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = self.norm3(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." + ) + + num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size + ff_output = torch.cat( + [ + self.ff(hid_slice, scale=lora_scale) + for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim) + ], + dim=self._chunk_dim, + ) + else: + ff_output = self.ff(norm_hidden_states, scale=lora_scale) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + elif self.use_ada_layer_norm_single: + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + +@maybe_allow_in_graph +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single' + norm_eps: float = 1e-5, + final_dropout: bool = False, + attention_type: str = "default", + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + attention_mode: str = "xformers", + use_rope: bool = False, + rope_scaling: Optional[Dict] = None, + compress_kv_factor: Optional[Tuple] = None, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + self.use_ada_layer_norm_single = norm_type == "ada_norm_single" + self.use_layer_norm = norm_type == "layer_norm" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." + ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) + else: + self.pos_embed = None + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + attention_mode=attention_mode, + use_rope=use_rope, + rope_scaling=rope_scaling, + compress_kv_factor=compress_kv_factor, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + ) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + attention_mode=attention_mode, # only xformers support attention_mask + use_rope=False, # do not position in cross attention + compress_kv_factor=None, + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + if not self.use_ada_layer_norm_single: + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + ) + + # 4. Fuser + if attention_type == "gated" or attention_type == "gated-text-image": + self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) + + # 5. Scale-shift for PixArt-Alpha. + if self.use_ada_layer_norm_single: + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + position_q: Optional[torch.LongTensor] = None, + position_k: Optional[torch.LongTensor] = None, + hw: Tuple[int, int] = None, + ) -> torch.FloatTensor: + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + elif self.use_layer_norm: + norm_hidden_states = self.norm1(hidden_states) + elif self.use_ada_layer_norm_single: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + norm_hidden_states = norm_hidden_states.squeeze(1) + else: + raise ValueError("Incorrect norm used") + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + # 1. Retrieve lora scale. + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + # 2. Prepare GLIGEN inputs + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + position_q=position_q, + position_k=position_k, + last_shape=hw, + **cross_attention_kwargs, + ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.use_ada_layer_norm_single: + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 2.5 GLIGEN Control + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + + # 3. Cross-Attention + if self.attn2 is not None: + if self.use_ada_layer_norm: + norm_hidden_states = self.norm2(hidden_states, timestep) + elif self.use_ada_layer_norm_zero or self.use_layer_norm: + norm_hidden_states = self.norm2(hidden_states) + elif self.use_ada_layer_norm_single: + # For PixArt norm2 isn't applied here: + # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 + norm_hidden_states = hidden_states + else: + raise ValueError("Incorrect norm") + + if self.pos_embed is not None and self.use_ada_layer_norm_single is False: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_q=None, # cross attn do not need relative position + position_k=None, + last_shape=None, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + if not self.use_ada_layer_norm_single: + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self.use_ada_layer_norm_single: + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + ff_output = _chunked_feed_forward( + self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale + ) + else: + ff_output = self.ff(norm_hidden_states, scale=lora_scale) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + elif self.use_ada_layer_norm_single: + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + +class AdaLayerNormSingle(nn.Module): + r""" + Norm layer adaptive layer norm single (adaLN-single). + + As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + use_additional_conditions (`bool`): To use additional conditions for normalization or not. + """ + + def __init__(self, embedding_dim: int, use_additional_conditions: bool = False): + super().__init__() + + self.emb = CombinedTimestepSizeEmbeddings( + embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions + ) + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) + + def forward( + self, + timestep: torch.Tensor, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + batch_size: int = None, + hidden_dtype: Optional[torch.dtype] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # No modulation happening here. + embedded_timestep = self.emb(timestep, batch_size=batch_size, hidden_dtype=hidden_dtype, resolution=None, + aspect_ratio=None) + return self.linear(self.silu(embedded_timestep)), embedded_timestep + + +@dataclass +class Transformer3DModelOutput(BaseOutput): + """ + The output of [`Transformer2DModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability + distributions for the unnoised latent pixels. + """ + + sample: torch.FloatTensor diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/transport/__init__.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/transport/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..db68edd20c9716e74ef1c853e968227efe45be29 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/transport/__init__.py @@ -0,0 +1,63 @@ +from .transport import Transport, ModelType, WeightType, PathType, Sampler + +def create_transport( + path_type='Linear', + prediction="velocity", + loss_weight=None, + train_eps=None, + sample_eps=None, +): + """function for creating Transport object + **Note**: model prediction defaults to velocity + Args: + - path_type: type of path to use; default to linear + - learn_score: set model prediction to score + - learn_noise: set model prediction to noise + - velocity_weighted: weight loss by velocity weight + - likelihood_weighted: weight loss by likelihood weight + - train_eps: small epsilon for avoiding instability during training + - sample_eps: small epsilon for avoiding instability during sampling + """ + + if prediction == "noise": + model_type = ModelType.NOISE + elif prediction == "score": + model_type = ModelType.SCORE + else: + model_type = ModelType.VELOCITY + + if loss_weight == "velocity": + loss_type = WeightType.VELOCITY + elif loss_weight == "likelihood": + loss_type = WeightType.LIKELIHOOD + else: + loss_type = WeightType.NONE + + path_choice = { + "Linear": PathType.LINEAR, + "GVP": PathType.GVP, + "VP": PathType.VP, + } + + path_type = path_choice[path_type] + + if (path_type in [PathType.VP]): + train_eps = 1e-5 if train_eps is None else train_eps + sample_eps = 1e-3 if train_eps is None else sample_eps + elif (path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY): + train_eps = 1e-3 if train_eps is None else train_eps + sample_eps = 1e-3 if train_eps is None else sample_eps + else: # velocity & [GVP, LINEAR] is stable everywhere + train_eps = 0 + sample_eps = 0 + + # create flow state + state = Transport( + model_type=model_type, + path_type=path_type, + loss_type=loss_type, + train_eps=train_eps, + sample_eps=sample_eps, + ) + + return state \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/transport/integrators.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/transport/integrators.py new file mode 100644 index 0000000000000000000000000000000000000000..adf7c7b4c50b6ff6c63973e0ddaa65b9759274c0 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/transport/integrators.py @@ -0,0 +1,117 @@ +import numpy as np +import torch as th +import torch.nn as nn +from torchdiffeq import odeint +from functools import partial +from tqdm import tqdm + +class sde: + """SDE solver class""" + def __init__( + self, + drift, + diffusion, + *, + t0, + t1, + num_steps, + sampler_type, + ): + assert t0 < t1, "SDE sampler has to be in forward time" + + self.num_timesteps = num_steps + self.t = th.linspace(t0, t1, num_steps) + self.dt = self.t[1] - self.t[0] + self.drift = drift + self.diffusion = diffusion + self.sampler_type = sampler_type + + def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs): + w_cur = th.randn(x.size()).to(x) + t = th.ones(x.size(0)).to(x) * t + dw = w_cur * th.sqrt(self.dt) + drift = self.drift(x, t, model, **model_kwargs) + diffusion = self.diffusion(x, t) + mean_x = x + drift * self.dt + x = mean_x + th.sqrt(2 * diffusion) * dw + return x, mean_x + + def __Heun_step(self, x, _, t, model, **model_kwargs): + w_cur = th.randn(x.size()).to(x) + dw = w_cur * th.sqrt(self.dt) + t_cur = th.ones(x.size(0)).to(x) * t + diffusion = self.diffusion(x, t_cur) + xhat = x + th.sqrt(2 * diffusion) * dw + K1 = self.drift(xhat, t_cur, model, **model_kwargs) + xp = xhat + self.dt * K1 + K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs) + return xhat + 0.5 * self.dt * (K1 + K2), xhat # at last time point we do not perform the heun step + + def __forward_fn(self): + """TODO: generalize here by adding all private functions ending with steps to it""" + sampler_dict = { + "Euler": self.__Euler_Maruyama_step, + "Heun": self.__Heun_step, + } + + try: + sampler = sampler_dict[self.sampler_type] + except: + raise NotImplementedError("Smapler type not implemented.") + + return sampler + + def sample(self, init, model, **model_kwargs): + """forward loop of sde""" + x = init + mean_x = init + samples = [] + sampler = self.__forward_fn() + for ti in self.t[:-1]: + with th.no_grad(): + x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs) + samples.append(x) + + return samples + +class ode: + """ODE solver class""" + def __init__( + self, + drift, + *, + t0, + t1, + sampler_type, + num_steps, + atol, + rtol, + ): + assert t0 < t1, "ODE sampler has to be in forward time" + + self.drift = drift + self.t = th.linspace(t0, t1, num_steps) + self.atol = atol + self.rtol = rtol + self.sampler_type = sampler_type + + def sample(self, x, model, **model_kwargs): + + device = x[0].device if isinstance(x, tuple) else x.device + def _fn(t, x): + t = th.ones(x[0].size(0)).to(device) * t if isinstance(x, tuple) else th.ones(x.size(0)).to(device) * t + model_output = self.drift(x, t, model, **model_kwargs) + return model_output + + t = self.t.to(device) + atol = [self.atol] * len(x) if isinstance(x, tuple) else [self.atol] + rtol = [self.rtol] * len(x) if isinstance(x, tuple) else [self.rtol] + samples = odeint( + _fn, + x, + t, + method=self.sampler_type, + atol=atol, + rtol=rtol + ) + return samples \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/transport/path.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/transport/path.py new file mode 100644 index 0000000000000000000000000000000000000000..156a7b0dea03497a85306ebbeedfe4fbedf87c27 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/transport/path.py @@ -0,0 +1,192 @@ +import torch as th +import numpy as np +from functools import partial + +def expand_t_like_x(t, x): + """Function to reshape time t to broadcastable dimension of x + Args: + t: [batch_dim,], time vector + x: [batch_dim,...], data point + """ + dims = [1] * (len(x.size()) - 1) + t = t.view(t.size(0), *dims) + return t + + +#################### Coupling Plans #################### + +class ICPlan: + """Linear Coupling Plan""" + def __init__(self, sigma=0.0): + self.sigma = sigma + + def compute_alpha_t(self, t): + """Compute the data coefficient along the path""" + return t, 1 + + def compute_sigma_t(self, t): + """Compute the noise coefficient along the path""" + return 1 - t, -1 + + def compute_d_alpha_alpha_ratio_t(self, t): + """Compute the ratio between d_alpha and alpha""" + return 1 / t + + def compute_drift(self, x, t): + """We always output sde according to score parametrization; """ + t = expand_t_like_x(t, x) + alpha_ratio = self.compute_d_alpha_alpha_ratio_t(t) + sigma_t, d_sigma_t = self.compute_sigma_t(t) + drift = alpha_ratio * x + diffusion = alpha_ratio * (sigma_t ** 2) - sigma_t * d_sigma_t + + return -drift, diffusion + + def compute_diffusion(self, x, t, form="constant", norm=1.0): + """Compute the diffusion term of the SDE + Args: + x: [batch_dim, ...], data point + t: [batch_dim,], time vector + form: str, form of the diffusion term + norm: float, norm of the diffusion term + """ + t = expand_t_like_x(t, x) + choices = { + "constant": norm, + "SBDM": norm * self.compute_drift(x, t)[1], + "sigma": norm * self.compute_sigma_t(t)[0], + "linear": norm * (1 - t), + "decreasing": 0.25 * (norm * th.cos(np.pi * t) + 1) ** 2, + "inccreasing-decreasing": norm * th.sin(np.pi * t) ** 2, + } + + try: + diffusion = choices[form] + except KeyError: + raise NotImplementedError(f"Diffusion form {form} not implemented") + + return diffusion + + def get_score_from_velocity(self, velocity, x, t): + """Wrapper function: transfrom velocity prediction model to score + Args: + velocity: [batch_dim, ...] shaped tensor; velocity model output + x: [batch_dim, ...] shaped tensor; x_t data point + t: [batch_dim,] time tensor + """ + t = expand_t_like_x(t, x) + alpha_t, d_alpha_t = self.compute_alpha_t(t) + sigma_t, d_sigma_t = self.compute_sigma_t(t) + mean = x + reverse_alpha_ratio = alpha_t / d_alpha_t + var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t + score = (reverse_alpha_ratio * velocity - mean) / var + return score + + def get_noise_from_velocity(self, velocity, x, t): + """Wrapper function: transfrom velocity prediction model to denoiser + Args: + velocity: [batch_dim, ...] shaped tensor; velocity model output + x: [batch_dim, ...] shaped tensor; x_t data point + t: [batch_dim,] time tensor + """ + t = expand_t_like_x(t, x) + alpha_t, d_alpha_t = self.compute_alpha_t(t) + sigma_t, d_sigma_t = self.compute_sigma_t(t) + mean = x + reverse_alpha_ratio = alpha_t / d_alpha_t + var = reverse_alpha_ratio * d_sigma_t - sigma_t + noise = (reverse_alpha_ratio * velocity - mean) / var + return noise + + def get_velocity_from_score(self, score, x, t): + """Wrapper function: transfrom score prediction model to velocity + Args: + score: [batch_dim, ...] shaped tensor; score model output + x: [batch_dim, ...] shaped tensor; x_t data point + t: [batch_dim,] time tensor + """ + t = expand_t_like_x(t, x) + drift, var = self.compute_drift(x, t) + velocity = var * score - drift + return velocity + + def compute_mu_t(self, t, x0, x1): + """Compute the mean of time-dependent density p_t""" + t = expand_t_like_x(t, x1) + alpha_t, _ = self.compute_alpha_t(t) + sigma_t, _ = self.compute_sigma_t(t) + return alpha_t * x1 + sigma_t * x0 + + def compute_xt(self, t, x0, x1): + """Sample xt from time-dependent density p_t; rng is required""" + xt = self.compute_mu_t(t, x0, x1) + return xt + + def compute_ut(self, t, x0, x1, xt): + """Compute the vector field corresponding to p_t""" + t = expand_t_like_x(t, x1) + _, d_alpha_t = self.compute_alpha_t(t) + _, d_sigma_t = self.compute_sigma_t(t) + return d_alpha_t * x1 + d_sigma_t * x0 + + def plan(self, t, x0, x1): + xt = self.compute_xt(t, x0, x1) + ut = self.compute_ut(t, x0, x1, xt) + return t, xt, ut + + +class VPCPlan(ICPlan): + """class for VP path flow matching""" + + def __init__(self, sigma_min=0.1, sigma_max=20.0): + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.log_mean_coeff = lambda t: -0.25 * ((1 - t) ** 2) * (self.sigma_max - self.sigma_min) - 0.5 * (1 - t) * self.sigma_min + self.d_log_mean_coeff = lambda t: 0.5 * (1 - t) * (self.sigma_max - self.sigma_min) + 0.5 * self.sigma_min + + + def compute_alpha_t(self, t): + """Compute coefficient of x1""" + alpha_t = self.log_mean_coeff(t) + alpha_t = th.exp(alpha_t) + d_alpha_t = alpha_t * self.d_log_mean_coeff(t) + return alpha_t, d_alpha_t + + def compute_sigma_t(self, t): + """Compute coefficient of x0""" + p_sigma_t = 2 * self.log_mean_coeff(t) + sigma_t = th.sqrt(1 - th.exp(p_sigma_t)) + d_sigma_t = th.exp(p_sigma_t) * (2 * self.d_log_mean_coeff(t)) / (-2 * sigma_t) + return sigma_t, d_sigma_t + + def compute_d_alpha_alpha_ratio_t(self, t): + """Special purposed function for computing numerical stabled d_alpha_t / alpha_t""" + return self.d_log_mean_coeff(t) + + def compute_drift(self, x, t): + """Compute the drift term of the SDE""" + t = expand_t_like_x(t, x) + beta_t = self.sigma_min + (1 - t) * (self.sigma_max - self.sigma_min) + return -0.5 * beta_t * x, beta_t / 2 + + +class GVPCPlan(ICPlan): + def __init__(self, sigma=0.0): + super().__init__(sigma) + + def compute_alpha_t(self, t): + """Compute coefficient of x1""" + alpha_t = th.sin(t * np.pi / 2) + d_alpha_t = np.pi / 2 * th.cos(t * np.pi / 2) + return alpha_t, d_alpha_t + + def compute_sigma_t(self, t): + """Compute coefficient of x0""" + sigma_t = th.cos(t * np.pi / 2) + d_sigma_t = -np.pi / 2 * th.sin(t * np.pi / 2) + return sigma_t, d_sigma_t + + def compute_d_alpha_alpha_ratio_t(self, t): + """Special purposed function for computing numerical stabled d_alpha_t / alpha_t""" + return np.pi / (2 * th.tan(t * np.pi / 2)) \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/transport/transport.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/transport/transport.py new file mode 100644 index 0000000000000000000000000000000000000000..396c516cfc64516a39212d95ff895c98135eef17 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/transport/transport.py @@ -0,0 +1,443 @@ +import torch as th +import numpy as np +import logging + +import enum + +from . import path +from .utils import EasyDict, log_state, mean_flat +from .integrators import ode, sde + +class ModelType(enum.Enum): + """ + Which type of output the model predicts. + """ + + NOISE = enum.auto() # the model predicts epsilon + SCORE = enum.auto() # the model predicts \nabla \log p(x) + VELOCITY = enum.auto() # the model predicts v(x) + +class PathType(enum.Enum): + """ + Which type of path to use. + """ + + LINEAR = enum.auto() + GVP = enum.auto() + VP = enum.auto() + +class WeightType(enum.Enum): + """ + Which type of weighting to use. + """ + + NONE = enum.auto() + VELOCITY = enum.auto() + LIKELIHOOD = enum.auto() + + +class Transport: + + def __init__( + self, + *, + model_type, + path_type, + loss_type, + train_eps, + sample_eps, + ): + path_options = { + PathType.LINEAR: path.ICPlan, + PathType.GVP: path.GVPCPlan, + PathType.VP: path.VPCPlan, + } + + self.loss_type = loss_type + self.model_type = model_type + self.path_sampler = path_options[path_type]() + self.train_eps = train_eps + self.sample_eps = sample_eps + + def prior_logp(self, z): + ''' + Standard multivariate normal prior + Assume z is batched + ''' + shape = th.tensor(z.size()) + N = th.prod(shape[1:]) + _fn = lambda x: -N / 2. * np.log(2 * np.pi) - th.sum(x ** 2) / 2. + return th.vmap(_fn)(z) + + + def check_interval( + self, + train_eps, + sample_eps, + *, + diffusion_form="SBDM", + sde=False, + reverse=False, + eval=False, + last_step_size=0.0, + ): + t0 = 0 + t1 = 1 + eps = train_eps if not eval else sample_eps + if (type(self.path_sampler) in [path.VPCPlan]): + + t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size + + elif (type(self.path_sampler) in [path.ICPlan, path.GVPCPlan]) \ + and (self.model_type != ModelType.VELOCITY or sde): # avoid numerical issue by taking a first semi-implicit step + + t0 = eps if (diffusion_form == "SBDM" and sde) or self.model_type != ModelType.VELOCITY else 0 + t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size + + if reverse: + t0, t1 = 1 - t0, 1 - t1 + + return t0, t1 + + + def sample(self, x1): + """Sampling x0 & t based on shape of x1 (if needed) + Args: + x1 - data point; [batch, *dim] + """ + + x0 = th.randn_like(x1) + t0, t1 = self.check_interval(self.train_eps, self.sample_eps) + t = th.rand((x1.shape[0],)) * (t1 - t0) + t0 + t = t.to(x1) + return t, x0, x1 + + + def training_losses( + self, + model, + x1, + model_kwargs=None + ): + """Loss for training the score model + Args: + - model: backbone model; could be score, noise, or velocity + - x1: datapoint + - model_kwargs: additional arguments for the model + """ + if model_kwargs == None: + model_kwargs = {} + + t, x0, x1 = self.sample(x1) + t, xt, ut = self.path_sampler.plan(t, x0, x1) + model_output = model(xt, t, **model_kwargs) + B, *_, C = xt.shape + assert model_output.size() == (B, *xt.size()[1:-1], C) + + terms = {} + terms['pred'] = model_output + if self.model_type == ModelType.VELOCITY: + terms['loss'] = mean_flat(((model_output - ut) ** 2)) + else: + _, drift_var = self.path_sampler.compute_drift(xt, t) + sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, xt)) + if self.loss_type in [WeightType.VELOCITY]: + weight = (drift_var / sigma_t) ** 2 + elif self.loss_type in [WeightType.LIKELIHOOD]: + weight = drift_var / (sigma_t ** 2) + elif self.loss_type in [WeightType.NONE]: + weight = 1 + else: + raise NotImplementedError() + + if self.model_type == ModelType.NOISE: + terms['loss'] = mean_flat(weight * ((model_output - x0) ** 2)) + else: + terms['loss'] = mean_flat(weight * ((model_output * sigma_t + x0) ** 2)) + + return terms + + + def get_drift( + self + ): + """member function for obtaining the drift of the probability flow ODE""" + def score_ode(x, t, model, **model_kwargs): + drift_mean, drift_var = self.path_sampler.compute_drift(x, t) + model_output = model(x, t, **model_kwargs) + return (-drift_mean + drift_var * model_output) # by change of variable + + def noise_ode(x, t, model, **model_kwargs): + drift_mean, drift_var = self.path_sampler.compute_drift(x, t) + sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x)) + model_output = model(x, t, **model_kwargs) + score = model_output / -sigma_t + return (-drift_mean + drift_var * score) + + def velocity_ode(x, t, model, **model_kwargs): + model_output = model(x, t, **model_kwargs) + return model_output + + if self.model_type == ModelType.NOISE: + drift_fn = noise_ode + elif self.model_type == ModelType.SCORE: + drift_fn = score_ode + else: + drift_fn = velocity_ode + + def body_fn(x, t, model, **model_kwargs): + model_output = drift_fn(x, t, model, **model_kwargs) + assert model_output.shape == x.shape, "Output shape from ODE solver must match input shape" + return model_output + + return body_fn + + + def get_score( + self, + ): + """member function for obtaining score of + x_t = alpha_t * x + sigma_t * eps""" + if self.model_type == ModelType.NOISE: + score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs) / -self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0] + elif self.model_type == ModelType.SCORE: + score_fn = lambda x, t, model, **kwagrs: model(x, t, **kwagrs) + elif self.model_type == ModelType.VELOCITY: + score_fn = lambda x, t, model, **kwargs: self.path_sampler.get_score_from_velocity(model(x, t, **kwargs), x, t) + else: + raise NotImplementedError() + + return score_fn + + +class Sampler: + """Sampler class for the transport model""" + def __init__( + self, + transport, + ): + """Constructor for a general sampler; supporting different sampling methods + Args: + - transport: an tranport object specify model prediction & interpolant type + """ + + self.transport = transport + self.drift = self.transport.get_drift() + self.score = self.transport.get_score() + + def __get_sde_diffusion_and_drift( + self, + *, + diffusion_form="SBDM", + diffusion_norm=1.0, + ): + + def diffusion_fn(x, t): + diffusion = self.transport.path_sampler.compute_diffusion(x, t, form=diffusion_form, norm=diffusion_norm) + return diffusion + + sde_drift = \ + lambda x, t, model, **kwargs: \ + self.drift(x, t, model, **kwargs) + diffusion_fn(x, t) * self.score(x, t, model, **kwargs) + + sde_diffusion = diffusion_fn + + return sde_drift, sde_diffusion + + def __get_last_step( + self, + sde_drift, + *, + last_step, + last_step_size, + ): + """Get the last step function of the SDE solver""" + + if last_step is None: + last_step_fn = \ + lambda x, t, model, **model_kwargs: \ + x + elif last_step == "Mean": + last_step_fn = \ + lambda x, t, model, **model_kwargs: \ + x + sde_drift(x, t, model, **model_kwargs) * last_step_size + elif last_step == "Tweedie": + alpha = self.transport.path_sampler.compute_alpha_t # simple aliasing; the original name was too long + sigma = self.transport.path_sampler.compute_sigma_t + last_step_fn = \ + lambda x, t, model, **model_kwargs: \ + x / alpha(t)[0][0] + (sigma(t)[0][0] ** 2) / alpha(t)[0][0] * self.score(x, t, model, **model_kwargs) + elif last_step == "Euler": + last_step_fn = \ + lambda x, t, model, **model_kwargs: \ + x + self.drift(x, t, model, **model_kwargs) * last_step_size + else: + raise NotImplementedError() + + return last_step_fn + + def sample_sde( + self, + *, + sampling_method="Euler", + diffusion_form="SBDM", + diffusion_norm=1.0, + last_step="Mean", + last_step_size=0.04, + num_steps=250, + ): + """returns a sampling function with given SDE settings + Args: + - sampling_method: type of sampler used in solving the SDE; default to be Euler-Maruyama + - diffusion_form: function form of diffusion coefficient; default to be matching SBDM + - diffusion_norm: function magnitude of diffusion coefficient; default to 1 + - last_step: type of the last step; default to identity + - last_step_size: size of the last step; default to match the stride of 250 steps over [0,1] + - num_steps: total integration step of SDE + """ + + if last_step is None: + last_step_size = 0.0 + + sde_drift, sde_diffusion = self.__get_sde_diffusion_and_drift( + diffusion_form=diffusion_form, + diffusion_norm=diffusion_norm, + ) + + t0, t1 = self.transport.check_interval( + self.transport.train_eps, + self.transport.sample_eps, + diffusion_form=diffusion_form, + sde=True, + eval=True, + reverse=False, + last_step_size=last_step_size, + ) + + _sde = sde( + sde_drift, + sde_diffusion, + t0=t0, + t1=t1, + num_steps=num_steps, + sampler_type=sampling_method + ) + + last_step_fn = self.__get_last_step(sde_drift, last_step=last_step, last_step_size=last_step_size) + + + def _sample(init, model, **model_kwargs): + xs = _sde.sample(init, model, **model_kwargs) + ts = th.ones(init.size(0), device=init.device) * t1 + x = last_step_fn(xs[-1], ts, model, **model_kwargs) + xs.append(x) + + assert len(xs) == num_steps, "Samples does not match the number of steps" + + return xs + + return _sample + + def sample_ode( + self, + *, + sampling_method="dopri5", + num_steps=50, + atol=1e-6, + rtol=1e-3, + reverse=False, + ): + """returns a sampling function with given ODE settings + Args: + - sampling_method: type of sampler used in solving the ODE; default to be Dopri5 + - num_steps: + - fixed solver (Euler, Heun): the actual number of integration steps performed + - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation + - atol: absolute error tolerance for the solver + - rtol: relative error tolerance for the solver + - reverse: whether solving the ODE in reverse (data to noise); default to False + """ + if reverse: + drift = lambda x, t, model, **kwargs: self.drift(x, th.ones_like(t) * (1 - t), model, **kwargs) + else: + drift = self.drift + + t0, t1 = self.transport.check_interval( + self.transport.train_eps, + self.transport.sample_eps, + sde=False, + eval=True, + reverse=reverse, + last_step_size=0.0, + ) + + _ode = ode( + drift=drift, + t0=t0, + t1=t1, + sampler_type=sampling_method, + num_steps=num_steps, + atol=atol, + rtol=rtol, + ) + + return _ode.sample + + def sample_ode_likelihood( + self, + *, + sampling_method="dopri5", + num_steps=50, + atol=1e-6, + rtol=1e-3, + ): + + """returns a sampling function for calculating likelihood with given ODE settings + Args: + - sampling_method: type of sampler used in solving the ODE; default to be Dopri5 + - num_steps: + - fixed solver (Euler, Heun): the actual number of integration steps performed + - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation + - atol: absolute error tolerance for the solver + - rtol: relative error tolerance for the solver + """ + def _likelihood_drift(x, t, model, **model_kwargs): + x, _ = x + eps = th.randint(2, x.size(), dtype=th.float, device=x.device) * 2 - 1 + t = th.ones_like(t) * (1 - t) + with th.enable_grad(): + x.requires_grad = True + grad = th.autograd.grad(th.sum(self.drift(x, t, model, **model_kwargs) * eps), x)[0] + logp_grad = th.sum(grad * eps, dim=tuple(range(1, len(x.size())))) + drift = self.drift(x, t, model, **model_kwargs) + return (-drift, logp_grad) + + t0, t1 = self.transport.check_interval( + self.transport.train_eps, + self.transport.sample_eps, + sde=False, + eval=True, + reverse=False, + last_step_size=0.0, + ) + + _ode = ode( + drift=_likelihood_drift, + t0=t0, + t1=t1, + sampler_type=sampling_method, + num_steps=num_steps, + atol=atol, + rtol=rtol, + ) + + def _sample_fn(x, model, **model_kwargs): + init_logp = th.zeros(x.size(0)).to(x) + input = (x, init_logp) + drift, delta_logp = _ode.sample(input, model, **model_kwargs) + drift, delta_logp = drift[-1], delta_logp[-1] + prior_logp = self.transport.prior_logp(drift) + logp = prior_logp - delta_logp + return logp, drift + + return _sample_fn \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/transport/utils.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/transport/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..44646035531326b81883727f973900edb4eac494 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/transport/utils.py @@ -0,0 +1,29 @@ +import torch as th + +class EasyDict: + + def __init__(self, sub_dict): + for k, v in sub_dict.items(): + setattr(self, k, v) + + def __getitem__(self, key): + return getattr(self, key) + +def mean_flat(x): + """ + Take the mean over all non-batch dimensions. + """ + return th.mean(x, dim=list(range(1, len(x.size())))) + +def log_state(state): + result = [] + + sorted_state = dict(sorted(state.items())) + for key, value in sorted_state.items(): + # Check if the value is an instance of a class + if " + +// forward declaration +void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd ); + +void rope_2d_cpu( torch::Tensor tokens, const torch::Tensor positions, const float base, const float fwd ) +{ + const int B = tokens.size(0); + const int N = tokens.size(1); + const int H = tokens.size(2); + const int D = tokens.size(3) / 4; + + auto tok = tokens.accessor(); + auto pos = positions.accessor(); + + for (int b = 0; b < B; b++) { + for (int x = 0; x < 2; x++) { // y and then x (2d) + for (int n = 0; n < N; n++) { + + // grab the token position + const int p = pos[b][n][x]; + + for (int h = 0; h < H; h++) { + for (int d = 0; d < D; d++) { + // grab the two values + float u = tok[b][n][h][d+0+x*2*D]; + float v = tok[b][n][h][d+D+x*2*D]; + + // grab the cos,sin + const float inv_freq = fwd * p / powf(base, d/float(D)); + float c = cosf(inv_freq); + float s = sinf(inv_freq); + + // write the result + tok[b][n][h][d+0+x*2*D] = u*c - v*s; + tok[b][n][h][d+D+x*2*D] = v*c + u*s; + } + } + } + } + } +} + +void rope_2d( torch::Tensor tokens, // B,N,H,D + const torch::Tensor positions, // B,N,2 + const float base, + const float fwd ) +{ + TORCH_CHECK(tokens.dim() == 4, "tokens must have 4 dimensions"); + TORCH_CHECK(positions.dim() == 3, "positions must have 3 dimensions"); + TORCH_CHECK(tokens.size(0) == positions.size(0), "batch size differs between tokens & positions"); + TORCH_CHECK(tokens.size(1) == positions.size(1), "seq_length differs between tokens & positions"); + TORCH_CHECK(positions.size(2) == 2, "positions.shape[2] must be equal to 2"); + TORCH_CHECK(tokens.is_cuda() == positions.is_cuda(), "tokens and positions are not on the same device" ); + + if (tokens.is_cuda()) + rope_2d_cuda( tokens, positions, base, fwd ); + else + rope_2d_cpu( tokens, positions, base, fwd ); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("rope_2d", &rope_2d, "RoPE 2d forward/backward"); +} diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/utils/curope/curope2d.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/utils/curope/curope2d.py new file mode 100644 index 0000000000000000000000000000000000000000..a49c12f8c529e9a889b5ac20c5767158f238e17d --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/utils/curope/curope2d.py @@ -0,0 +1,40 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +import torch + +try: + import curope as _kernels # run `python setup.py install` +except ModuleNotFoundError: + from . import curope as _kernels # run `python setup.py build_ext --inplace` + + +class cuRoPE2D_func (torch.autograd.Function): + + @staticmethod + def forward(ctx, tokens, positions, base, F0=1): + ctx.save_for_backward(positions) + ctx.saved_base = base + ctx.saved_F0 = F0 + # tokens = tokens.clone() # uncomment this if inplace doesn't work + _kernels.rope_2d( tokens, positions, base, F0 ) + ctx.mark_dirty(tokens) + return tokens + + @staticmethod + def backward(ctx, grad_res): + positions, base, F0 = ctx.saved_tensors[0], ctx.saved_base, ctx.saved_F0 + _kernels.rope_2d( grad_res, positions, base, -F0 ) + ctx.mark_dirty(grad_res) + return grad_res, None, None, None + + +class cuRoPE2D(torch.nn.Module): + def __init__(self, freq=100.0, F0=1.0): + super().__init__() + self.base = freq + self.F0 = F0 + + def forward(self, tokens, positions): + cuRoPE2D_func.apply( tokens.transpose(1,2), positions, self.base, self.F0 ) + return tokens \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/utils/curope/kernels.cu b/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/utils/curope/kernels.cu new file mode 100644 index 0000000000000000000000000000000000000000..7156cd1bb935cb1f0be45e58add53f9c21505c20 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/utils/curope/kernels.cu @@ -0,0 +1,108 @@ +/* + Copyright (C) 2022-present Naver Corporation. All rights reserved. + Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +*/ + +#include +#include +#include +#include + +#define CHECK_CUDA(tensor) {\ + TORCH_CHECK((tensor).is_cuda(), #tensor " is not in cuda memory"); \ + TORCH_CHECK((tensor).is_contiguous(), #tensor " is not contiguous"); } +void CHECK_KERNEL() {auto error = cudaGetLastError(); TORCH_CHECK( error == cudaSuccess, cudaGetErrorString(error));} + + +template < typename scalar_t > +__global__ void rope_2d_cuda_kernel( + //scalar_t* __restrict__ tokens, + torch::PackedTensorAccessor32 tokens, + const int64_t* __restrict__ pos, + const float base, + const float fwd ) + // const int N, const int H, const int D ) +{ + // tokens shape = (B, N, H, D) + const int N = tokens.size(1); + const int H = tokens.size(2); + const int D = tokens.size(3); + + // each block update a single token, for all heads + // each thread takes care of a single output + extern __shared__ float shared[]; + float* shared_inv_freq = shared + D; + + const int b = blockIdx.x / N; + const int n = blockIdx.x % N; + + const int Q = D / 4; + // one token = [0..Q : Q..2Q : 2Q..3Q : 3Q..D] + // u_Y v_Y u_X v_X + + // shared memory: first, compute inv_freq + if (threadIdx.x < Q) + shared_inv_freq[threadIdx.x] = fwd / powf(base, threadIdx.x/float(Q)); + __syncthreads(); + + // start of X or Y part + const int X = threadIdx.x < D/2 ? 0 : 1; + const int m = (X*D/2) + (threadIdx.x % Q); // index of u_Y or u_X + + // grab the cos,sin appropriate for me + const float freq = pos[blockIdx.x*2+X] * shared_inv_freq[threadIdx.x % Q]; + const float cos = cosf(freq); + const float sin = sinf(freq); + /* + float* shared_cos_sin = shared + D + D/4; + if ((threadIdx.x % (D/2)) < Q) + shared_cos_sin[m+0] = cosf(freq); + else + shared_cos_sin[m+Q] = sinf(freq); + __syncthreads(); + const float cos = shared_cos_sin[m+0]; + const float sin = shared_cos_sin[m+Q]; + */ + + for (int h = 0; h < H; h++) + { + // then, load all the token for this head in shared memory + shared[threadIdx.x] = tokens[b][n][h][threadIdx.x]; + __syncthreads(); + + const float u = shared[m]; + const float v = shared[m+Q]; + + // write output + if ((threadIdx.x % (D/2)) < Q) + tokens[b][n][h][threadIdx.x] = u*cos - v*sin; + else + tokens[b][n][h][threadIdx.x] = v*cos + u*sin; + } +} + +void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd ) +{ + const int B = tokens.size(0); // batch size + const int N = tokens.size(1); // sequence length + const int H = tokens.size(2); // number of heads + const int D = tokens.size(3); // dimension per head + + TORCH_CHECK(tokens.stride(3) == 1 && tokens.stride(2) == D, "tokens are not contiguous"); + TORCH_CHECK(pos.is_contiguous(), "positions are not contiguous"); + TORCH_CHECK(pos.size(0) == B && pos.size(1) == N && pos.size(2) == 2, "bad pos.shape"); + TORCH_CHECK(D % 4 == 0, "token dim must be multiple of 4"); + + // one block for each layer, one thread per local-max + const int THREADS_PER_BLOCK = D; + const int N_BLOCKS = B * N; // each block takes care of H*D values + const int SHARED_MEM = sizeof(float) * (D + D/4); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(tokens.type(), "rope_2d_cuda", ([&] { + rope_2d_cuda_kernel <<>> ( + //tokens.data_ptr(), + tokens.packed_accessor32(), + pos.data_ptr(), + base, fwd); //, N, H, D ); + })); +} diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/utils/curope/setup.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/utils/curope/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..230632ed05e309200e8f93a3a852072333975009 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/utils/curope/setup.py @@ -0,0 +1,34 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +from setuptools import setup +from torch import cuda +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +# compile for all possible CUDA architectures +all_cuda_archs = cuda.get_gencode_flags().replace('compute=','arch=').split() +# alternatively, you can list cuda archs that you want, eg: +# all_cuda_archs = [ + # '-gencode', 'arch=compute_70,code=sm_70', + # '-gencode', 'arch=compute_75,code=sm_75', + # '-gencode', 'arch=compute_80,code=sm_80', + # '-gencode', 'arch=compute_86,code=sm_86' +# ] + +setup( + name = 'curope', + ext_modules = [ + CUDAExtension( + name='curope', + sources=[ + "curope.cpp", + "kernels.cu", + ], + extra_compile_args = dict( + nvcc=['-O3','--ptxas-options=-v',"--use_fast_math"]+all_cuda_archs, + cxx=['-O3']) + ) + ], + cmdclass = { + 'build_ext': BuildExtension + }) diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/utils/pos_embed.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/utils/pos_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..61872bc97adfa6d3b490f98fd71c3cd15f5c8650 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/diffusion/utils/pos_embed.py @@ -0,0 +1,243 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +# croco: https://github.com/naver/croco +# diffusers: https://github.com/huggingface/diffusers +# -------------------------------------------------------- +# Position embedding utils +# -------------------------------------------------------- + +import numpy as np +import torch + + +def get_2d_sincos_pos_embed( + embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16 +): + """ + grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or + [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + if isinstance(grid_size, int): + grid_size = (grid_size, grid_size) + + grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale + grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed( + embed_dim, length, interpolation_scale=1.0, base_size=16 +): + pos = torch.arange(0, length).unsqueeze(1) / interpolation_scale + pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, pos) + return pos_embed + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) + """ + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000 ** omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +# ---------------------------------------------------------- +# RoPE2D: RoPE implementation in 2D +# ---------------------------------------------------------- + +try: + from .curope import cuRoPE2D + + RoPE2D = cuRoPE2D +except ImportError: + print('Warning, cannot find cuda-compiled version of RoPE2D, using a slow pytorch version instead') + + + class RoPE2D(torch.nn.Module): + + def __init__(self, freq=10000.0, F0=1.0, scaling_factor=1.0): + super().__init__() + self.base = freq + self.F0 = F0 + self.scaling_factor = scaling_factor + self.cache = {} + + def get_cos_sin(self, D, seq_len, device, dtype): + if (D, seq_len, device, dtype) not in self.cache: + inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D)) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype) + freqs = torch.cat((freqs, freqs), dim=-1) + cos = freqs.cos() # (Seq, Dim) + sin = freqs.sin() + self.cache[D, seq_len, device, dtype] = (cos, sin) + return self.cache[D, seq_len, device, dtype] + + @staticmethod + def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + def apply_rope1d(self, tokens, pos1d, cos, sin): + assert pos1d.ndim == 2 + + cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :] + sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :] + return (tokens * cos) + (self.rotate_half(tokens) * sin) + + def forward(self, tokens, positions): + """ + input: + * tokens: batch_size x nheads x ntokens x dim + * positions: batch_size x ntokens x 2 (y and x position of each token) + output: + * tokens after appplying RoPE2D (batch_size x nheads x ntokens x dim) + """ + assert tokens.size(3) % 2 == 0, "number of dimensions should be a multiple of two" + D = tokens.size(3) // 2 + assert positions.ndim == 3 and positions.shape[-1] == 2 # Batch, Seq, 2 + cos, sin = self.get_cos_sin(D, int(positions.max()) + 1, tokens.device, tokens.dtype) + # split features into two along the feature dimension, and apply rope1d on each half + y, x = tokens.chunk(2, dim=-1) + y = self.apply_rope1d(y, positions[:, :, 0], cos, sin) + x = self.apply_rope1d(x, positions[:, :, 1], cos, sin) + tokens = torch.cat((y, x), dim=-1) + return tokens + +class LinearScalingRoPE2D(RoPE2D): + """Code from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L148""" + + def forward(self, tokens, positions): + # difference to the original RoPE: a scaling factor is aplied to the position ids + dtype = positions.dtype + positions = positions.float() / self.scaling_factor + positions = positions.to(dtype) + tokens = super().forward(tokens, positions) + return tokens + + +try: + from .curope import cuRoPE1D + + RoPE1D = cuRoPE1D +except ImportError: + print('Warning, cannot find cuda-compiled version of RoPE2D, using a slow pytorch version instead') + + + class RoPE1D(torch.nn.Module): + + def __init__(self, freq=10000.0, F0=1.0, scaling_factor=1.0): + super().__init__() + self.base = freq + self.F0 = F0 + self.scaling_factor = scaling_factor + self.cache = {} + + def get_cos_sin(self, D, seq_len, device, dtype): + if (D, seq_len, device, dtype) not in self.cache: + inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D)) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype) + freqs = torch.cat((freqs, freqs), dim=-1) + cos = freqs.cos() # (Seq, Dim) + sin = freqs.sin() + self.cache[D, seq_len, device, dtype] = (cos, sin) + return self.cache[D, seq_len, device, dtype] + + @staticmethod + def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + def apply_rope1d(self, tokens, pos1d, cos, sin): + assert pos1d.ndim == 2 + cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :] + sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :] + return (tokens * cos) + (self.rotate_half(tokens) * sin) + + def forward(self, tokens, positions): + """ + input: + * tokens: batch_size x nheads x ntokens x dim + * positions: batch_size x ntokens (t position of each token) + output: + * tokens after appplying RoPE2D (batch_size x nheads x ntokens x dim) + """ + D = tokens.size(3) + assert positions.ndim == 2 # Batch, Seq + cos, sin = self.get_cos_sin(D, int(positions.max()) + 1, tokens.device, tokens.dtype) + tokens = self.apply_rope1d(tokens, positions, cos, sin) + return tokens + +class LinearScalingRoPE1D(RoPE1D): + """Code from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L148""" + + def forward(self, tokens, positions): + # difference to the original RoPE: a scaling factor is aplied to the position ids + dtype = positions.dtype + positions = positions.float() / self.scaling_factor + positions = positions.to(dtype) + tokens = super().forward(tokens, positions) + return tokens + + +class PositionGetter2D(object): + """ return positions of patches """ + + def __init__(self): + self.cache_positions = {} + + def __call__(self, b, h, w, device): + if not (h,w) in self.cache_positions: + x = torch.arange(w, device=device) + y = torch.arange(h, device=device) + self.cache_positions[h,w] = torch.cartesian_prod(y, x) # (h, w, 2) + pos = self.cache_positions[h,w].view(1, h*w, 2).expand(b, -1, 2).clone() + return pos + + + +class PositionGetter1D(object): + """ return positions of patches """ + + def __init__(self): + self.cache_positions = {} + + def __call__(self, b, l, device): + if not (l) in self.cache_positions: + x = torch.arange(l, device=device) + self.cache_positions[l] = x # (l, ) + pos = self.cache_positions[l].view(1, l).expand(b, -1).clone() + return pos \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/__init__.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/cfgs/AMT-G.yaml b/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/cfgs/AMT-G.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d259d4fe97336d2a3e9c9e34a0067ca5ddfae1f0 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/cfgs/AMT-G.yaml @@ -0,0 +1,9 @@ + +seed: 2023 + +network: + name: networks.AMT-G.Model + params: + corr_radius: 3 + corr_lvls: 4 + num_flows: 5 \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/cfgs/__init__.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/cfgs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/interpolation.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/interpolation.py new file mode 100644 index 0000000000000000000000000000000000000000..0c60bed8a4dc7747448156a18d93efacaaa837e7 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/interpolation.py @@ -0,0 +1,197 @@ +# this script is modified from https://github.com/MCG-NKU/AMT/blob/main/demos/demo_2x.py +from json import load +import os +import cv2 +import sys +import glob +import torch +import argparse +import numpy as np +import os.path as osp +from warnings import warn +from omegaconf import OmegaConf +from torchvision.utils import make_grid +sys.path.append('.') +from utils.utils import ( + read, write, + img2tensor, tensor2img, + check_dim_and_resize + ) +from utils.build_utils import build_from_cfg +from utils.utils import InputPadder + + +AMT_G = { + 'name': 'networks.AMT-G.Model', + 'params':{ + 'corr_radius': 3, + 'corr_lvls': 4, + 'num_flows': 5, + } +} + + + +def init(device="cuda"): + + ''' + initialize the device and the anchor resolution. + ''' + + if device == 'cuda': + anchor_resolution = 1024 * 512 + anchor_memory = 1500 * 1024**2 + anchor_memory_bias = 2500 * 1024**2 + vram_avail = torch.cuda.get_device_properties(device).total_memory + print("VRAM available: {:.1f} MB".format(vram_avail / 1024 ** 2)) + else: + # Do not resize in cpu mode + anchor_resolution = 8192*8192 + anchor_memory = 1 + anchor_memory_bias = 0 + vram_avail = 1 + + return anchor_resolution, anchor_memory, anchor_memory_bias, vram_avail + +def get_input_video_from_path(input_path, device="cuda"): + + ''' + Get the input video from the input_path. + + params: + input_path: str, the path of the input video. + devices: str, the device to run the model. + returns: + inputs: list, the list of the input frames. + scale: float, the scale of the input frames. + padder: InputPadder, the padder to pad the input frames. + ''' + + anchor_resolution, anchor_memory, anchor_memory_bias, vram_avail = init(device) + + if osp.splitext(input_path)[-1] in ['.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', + '.webm', '.MP4', '.AVI', '.MOV', '.MKV', '.FLV', + '.WMV', '.WEBM']: + + vcap = cv2.VideoCapture(input_path) + + inputs = [] + w = int(vcap.get(cv2.CAP_PROP_FRAME_WIDTH)) + h = int(vcap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + scale = anchor_resolution / (h * w) * np.sqrt((vram_avail - anchor_memory_bias) / anchor_memory) + scale = 1 if scale > 1 else scale + scale = 1 / np.floor(1 / np.sqrt(scale) * 16) * 16 + if scale < 1: + print(f"Due to the limited VRAM, the video will be scaled by {scale:.2f}") + padding = int(16 / scale) + padder = InputPadder((h, w), padding) + while True: + ret, frame = vcap.read() + if ret is False: + break + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frame_t = img2tensor(frame).to(device) + frame_t = padder.pad(frame_t) + inputs.append(frame_t) + print(f'Loading the [video] from {input_path}, the number of frames [{len(inputs)}]') + else: + raise TypeError("Input should be a video.") + + return inputs, scale, padder + + +def load_model(ckpt_path, device="cuda"): + + ''' + load the frame interpolation model. + ''' + network_cfg = AMT_G + network_name = network_cfg['name'] + print(f'Loading [{network_name}] from [{ckpt_path}]...') + model = build_from_cfg(network_cfg) + ckpt = torch.load(ckpt_path) + model.load_state_dict(ckpt['state_dict']) + model = model.to(device) + model.eval() + return model + +def interpolater(model, inputs, scale, padder, iters=1): + + ''' + interpolating with the interpolation model. + + params: + model: nn.Module, the frame interpolation model. + inputs: list, the list of the input frames. + scale: float, the scale of the input frames. + iters: int, the number of iterations of interpolation. The final frames model generating is 2 ** iters * (m - 1) + 1 and m is input frames. + returns: + outputs: list, the list of the output frames. + ''' + + print(f'Start frame interpolation:') + embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device) + + for i in range(iters): + print(f'Iter {i+1}. input_frames={len(inputs)} output_frames={2*len(inputs)-1}') + outputs = [inputs[0]] + for in_0, in_1 in zip(inputs[:-1], inputs[1:]): + in_0 = in_0.to(device) + in_1 = in_1.to(device) + with torch.no_grad(): + imgt_pred = model(in_0, in_1, embt, scale_factor=scale, eval=True)['imgt_pred'] + outputs += [imgt_pred.cpu(), in_1.cpu()] + inputs = outputs + + outputs = padder.unpad(*outputs) + + return outputs + +def write(outputs, input_path, output_path, frame_rate=30): + ''' + write results to the output_path. + ''' + + if osp.exists(output_path) is False: + os.makedirs(output_path) + + + size = outputs[0].shape[2:][::-1] + + _, file_name_with_extension = os.path.split(input_path) + file_name, _ = os.path.splitext(file_name_with_extension) + + save_video_path = f'{output_path}/output_{file_name}.mp4' + writer = cv2.VideoWriter(save_video_path, cv2.VideoWriter_fourcc(*"mp4v"), + frame_rate, size) + + for i, imgt_pred in enumerate(outputs): + imgt_pred = tensor2img(imgt_pred) + imgt_pred = cv2.cvtColor(imgt_pred, cv2.COLOR_RGB2BGR) + writer.write(imgt_pred) + print(f"Demo video is saved to [{save_video_path}]") + + writer.release() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--ckpt', type=str, default='amt-g.pth', help="The pretrained model.") + parser.add_argument('--niters', type=int, default=1, help="Iter of Interpolation. The number of frames will be double after per iter.") + parser.add_argument('--input', default="test.mp4", help="Input video.") + parser.add_argument('--output_path', type=str, default='results', help="Output path.") + parser.add_argument('--frame_rate', type=int, default=30, help="Frames rate of the output video.") + + args = parser.parse_args() + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + ckpt_path = args.ckpt + input_path = args.input + output_path = args.output_path + iters = int(args.niters) + frame_rate = int(args.frame_rate) + + inputs, scale, padder = get_input_video_from_path(input_path, device) + model = load_model(ckpt_path, device) + outputs = interpolater(model, inputs, scale, padder, iters) + write(outputs, input_path, output_path, frame_rate) diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/networks/AMT-G.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/networks/AMT-G.py new file mode 100644 index 0000000000000000000000000000000000000000..a24cb1a3704984418788bb1f8f0e9946c87886e3 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/networks/AMT-G.py @@ -0,0 +1,172 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from networks.blocks.raft import ( + coords_grid, + BasicUpdateBlock, BidirCorrBlock +) +from networks.blocks.feat_enc import ( + LargeEncoder +) +from networks.blocks.ifrnet import ( + resize, + Encoder, + InitDecoder, + IntermediateDecoder +) +from networks.blocks.multi_flow import ( + multi_flow_combine, + MultiFlowDecoder +) + + +class Model(nn.Module): + def __init__(self, + corr_radius=3, + corr_lvls=4, + num_flows=5, + channels=[84, 96, 112, 128], + skip_channels=84): + super(Model, self).__init__() + self.radius = corr_radius + self.corr_levels = corr_lvls + self.num_flows = num_flows + + self.feat_encoder = LargeEncoder(output_dim=128, norm_fn='instance', dropout=0.) + self.encoder = Encoder(channels, large=True) + self.decoder4 = InitDecoder(channels[3], channels[2], skip_channels) + self.decoder3 = IntermediateDecoder(channels[2], channels[1], skip_channels) + self.decoder2 = IntermediateDecoder(channels[1], channels[0], skip_channels) + self.decoder1 = MultiFlowDecoder(channels[0], skip_channels, num_flows) + + self.update4 = self._get_updateblock(112, None) + self.update3_low = self._get_updateblock(96, 2.0) + self.update2_low = self._get_updateblock(84, 4.0) + + self.update3_high = self._get_updateblock(96, None) + self.update2_high = self._get_updateblock(84, None) + + self.comb_block = nn.Sequential( + nn.Conv2d(3*self.num_flows, 6*self.num_flows, 7, 1, 3), + nn.PReLU(6*self.num_flows), + nn.Conv2d(6*self.num_flows, 3, 7, 1, 3), + ) + + def _get_updateblock(self, cdim, scale_factor=None): + return BasicUpdateBlock(cdim=cdim, hidden_dim=192, flow_dim=64, + corr_dim=256, corr_dim2=192, fc_dim=188, + scale_factor=scale_factor, corr_levels=self.corr_levels, + radius=self.radius) + + def _corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1): + # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0 + # based on linear assumption + t1_scale = 1. / embt + t0_scale = 1. / (1. - embt) + if downsample != 1: + inv = 1 / downsample + flow0 = inv * resize(flow0, scale_factor=inv) + flow1 = inv * resize(flow1, scale_factor=inv) + + corr0, corr1 = corr_fn(coord + flow1 * t1_scale, coord + flow0 * t0_scale) + corr = torch.cat([corr0, corr1], dim=1) + flow = torch.cat([flow0, flow1], dim=1) + return corr, flow + + def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs): + mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) + img0 = img0 - mean_ + img1 = img1 - mean_ + img0_ = resize(img0, scale_factor) if scale_factor != 1.0 else img0 + img1_ = resize(img1, scale_factor) if scale_factor != 1.0 else img1 + b, _, h, w = img0_.shape + coord = coords_grid(b, h // 8, w // 8, img0.device) + + fmap0, fmap1 = self.feat_encoder([img0_, img1_]) # [1, 128, H//8, W//8] + corr_fn = BidirCorrBlock(fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels) + + # f0_1: [1, c0, H//2, W//2] | f0_2: [1, c1, H//4, W//4] + # f0_3: [1, c2, H//8, W//8] | f0_4: [1, c3, H//16, W//16] + f0_1, f0_2, f0_3, f0_4 = self.encoder(img0_) + f1_1, f1_2, f1_3, f1_4 = self.encoder(img1_) + + ######################################### the 4th decoder ######################################### + up_flow0_4, up_flow1_4, ft_3_ = self.decoder4(f0_4, f1_4, embt) + corr_4, flow_4 = self._corr_scale_lookup(corr_fn, coord, + up_flow0_4, up_flow1_4, + embt, downsample=1) + + # residue update with lookup corr + delta_ft_3_, delta_flow_4 = self.update4(ft_3_, flow_4, corr_4) + delta_flow0_4, delta_flow1_4 = torch.chunk(delta_flow_4, 2, 1) + up_flow0_4 = up_flow0_4 + delta_flow0_4 + up_flow1_4 = up_flow1_4 + delta_flow1_4 + ft_3_ = ft_3_ + delta_ft_3_ + + ######################################### the 3rd decoder ######################################### + up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4) + corr_3, flow_3 = self._corr_scale_lookup(corr_fn, + coord, up_flow0_3, up_flow1_3, + embt, downsample=2) + + # residue update with lookup corr + delta_ft_2_, delta_flow_3 = self.update3_low(ft_2_, flow_3, corr_3) + delta_flow0_3, delta_flow1_3 = torch.chunk(delta_flow_3, 2, 1) + up_flow0_3 = up_flow0_3 + delta_flow0_3 + up_flow1_3 = up_flow1_3 + delta_flow1_3 + ft_2_ = ft_2_ + delta_ft_2_ + + # residue update with lookup corr (hr) + corr_3 = resize(corr_3, scale_factor=2.0) + up_flow_3 = torch.cat([up_flow0_3, up_flow1_3], dim=1) + delta_ft_2_, delta_up_flow_3 = self.update3_high(ft_2_, up_flow_3, corr_3) + ft_2_ += delta_ft_2_ + up_flow0_3 += delta_up_flow_3[:, 0:2] + up_flow1_3 += delta_up_flow_3[:, 2:4] + + ######################################### the 2nd decoder ######################################### + up_flow0_2, up_flow1_2, ft_1_ = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3) + corr_2, flow_2 = self._corr_scale_lookup(corr_fn, + coord, up_flow0_2, up_flow1_2, + embt, downsample=4) + + # residue update with lookup corr + delta_ft_1_, delta_flow_2 = self.update2_low(ft_1_, flow_2, corr_2) + delta_flow0_2, delta_flow1_2 = torch.chunk(delta_flow_2, 2, 1) + up_flow0_2 = up_flow0_2 + delta_flow0_2 + up_flow1_2 = up_flow1_2 + delta_flow1_2 + ft_1_ = ft_1_ + delta_ft_1_ + + # residue update with lookup corr (hr) + corr_2 = resize(corr_2, scale_factor=4.0) + up_flow_2 = torch.cat([up_flow0_2, up_flow1_2], dim=1) + delta_ft_1_, delta_up_flow_2 = self.update2_high(ft_1_, up_flow_2, corr_2) + ft_1_ += delta_ft_1_ + up_flow0_2 += delta_up_flow_2[:, 0:2] + up_flow1_2 += delta_up_flow_2[:, 2:4] + + ######################################### the 1st decoder ######################################### + up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2) + + if scale_factor != 1.0: + up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor) + up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor) + mask = resize(mask, scale_factor=(1.0/scale_factor)) + img_res = resize(img_res, scale_factor=(1.0/scale_factor)) + + # Merge multiple predictions + imgt_pred = multi_flow_combine(self.comb_block, img0, img1, up_flow0_1, up_flow1_1, + mask, img_res, mean_) + imgt_pred = torch.clamp(imgt_pred, 0, 1) + + if eval: + return { 'imgt_pred': imgt_pred, } + else: + up_flow0_1 = up_flow0_1.reshape(b, self.num_flows, 2, h, w) + up_flow1_1 = up_flow1_1.reshape(b, self.num_flows, 2, h, w) + return { + 'imgt_pred': imgt_pred, + 'flow0_pred': [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4], + 'flow1_pred': [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4], + 'ft_pred': [ft_1_, ft_2_, ft_3_], + } diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/networks/__init__.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/networks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/networks/blocks/__init__.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/networks/blocks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/networks/blocks/feat_enc.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/networks/blocks/feat_enc.py new file mode 100644 index 0000000000000000000000000000000000000000..3805bd315422703c19bf6a4d0962ee75002d92aa --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/networks/blocks/feat_enc.py @@ -0,0 +1,343 @@ +import torch +import torch.nn as nn + + +class BottleneckBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(BottleneckBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) + self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) + self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes//4) + self.norm2 = nn.BatchNorm2d(planes//4) + self.norm3 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm4 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes//4) + self.norm2 = nn.InstanceNorm2d(planes//4) + self.norm3 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm4 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + self.norm3 = nn.Sequential() + if not stride == 1: + self.norm4 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) + + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + y = self.relu(self.norm3(self.conv3(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + + +class SmallEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(SmallEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(32) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(32) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 32 + self.layer1 = self._make_layer(32, stride=1) + self.layer2 = self._make_layer(64, stride=2) + self.layer3 = self._make_layer(96, stride=2) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + +class BasicEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(BasicEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(64) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(64) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(72, stride=2) + self.layer3 = self._make_layer(128, stride=2) + + # output convolution + self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + +class LargeEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(LargeEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(64) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(64) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(112, stride=2) + self.layer3 = self._make_layer(160, stride=2) + self.layer3_2 = self._make_layer(160, stride=1) + + # output convolution + self.conv2 = nn.Conv2d(self.in_planes, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer3_2(x) + + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/networks/blocks/ifrnet.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/networks/blocks/ifrnet.py new file mode 100644 index 0000000000000000000000000000000000000000..586ae61036191a52337a791f3e7442899fdf5fc9 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/networks/blocks/ifrnet.py @@ -0,0 +1,111 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from utils.flow_utils import warp + + +def resize(x, scale_factor): + return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False) + +def convrelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True): + return nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=bias), + nn.PReLU(out_channels) + ) + +class ResBlock(nn.Module): + def __init__(self, in_channels, side_channels, bias=True): + super(ResBlock, self).__init__() + self.side_channels = side_channels + self.conv1 = nn.Sequential( + nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(in_channels) + ) + self.conv2 = nn.Sequential( + nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(side_channels) + ) + self.conv3 = nn.Sequential( + nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(in_channels) + ) + self.conv4 = nn.Sequential( + nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(side_channels) + ) + self.conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias) + self.prelu = nn.PReLU(in_channels) + + def forward(self, x): + out = self.conv1(x) + + res_feat = out[:, :-self.side_channels, ...] + side_feat = out[:, -self.side_channels:, :, :] + side_feat = self.conv2(side_feat) + out = self.conv3(torch.cat([res_feat, side_feat], 1)) + + res_feat = out[:, :-self.side_channels, ...] + side_feat = out[:, -self.side_channels:, :, :] + side_feat = self.conv4(side_feat) + out = self.conv5(torch.cat([res_feat, side_feat], 1)) + + out = self.prelu(x + out) + return out + +class Encoder(nn.Module): + def __init__(self, channels, large=False): + super(Encoder, self).__init__() + self.channels = channels + prev_ch = 3 + for idx, ch in enumerate(channels, 1): + k = 7 if large and idx == 1 else 3 + p = 3 if k ==7 else 1 + self.register_module(f'pyramid{idx}', + nn.Sequential( + convrelu(prev_ch, ch, k, 2, p), + convrelu(ch, ch, 3, 1, 1) + )) + prev_ch = ch + + def forward(self, in_x): + fs = [] + for idx in range(len(self.channels)): + out_x = getattr(self, f'pyramid{idx+1}')(in_x) + fs.append(out_x) + in_x = out_x + return fs + +class InitDecoder(nn.Module): + def __init__(self, in_ch, out_ch, skip_ch) -> None: + super().__init__() + self.convblock = nn.Sequential( + convrelu(in_ch*2+1, in_ch*2), + ResBlock(in_ch*2, skip_ch), + nn.ConvTranspose2d(in_ch*2, out_ch+4, 4, 2, 1, bias=True) + ) + def forward(self, f0, f1, embt): + h, w = f0.shape[2:] + embt = embt.repeat(1, 1, h, w) + out = self.convblock(torch.cat([f0, f1, embt], 1)) + flow0, flow1 = torch.chunk(out[:, :4, ...], 2, 1) + ft_ = out[:, 4:, ...] + return flow0, flow1, ft_ + +class IntermediateDecoder(nn.Module): + def __init__(self, in_ch, out_ch, skip_ch) -> None: + super().__init__() + self.convblock = nn.Sequential( + convrelu(in_ch*3+4, in_ch*3), + ResBlock(in_ch*3, skip_ch), + nn.ConvTranspose2d(in_ch*3, out_ch+4, 4, 2, 1, bias=True) + ) + def forward(self, ft_, f0, f1, flow0_in, flow1_in): + f0_warp = warp(f0, flow0_in) + f1_warp = warp(f1, flow1_in) + f_in = torch.cat([ft_, f0_warp, f1_warp, flow0_in, flow1_in], 1) + out = self.convblock(f_in) + flow0, flow1 = torch.chunk(out[:, :4, ...], 2, 1) + ft_ = out[:, 4:, ...] + flow0 = flow0 + 2.0 * resize(flow0_in, scale_factor=2.0) + flow1 = flow1 + 2.0 * resize(flow1_in, scale_factor=2.0) + return flow0, flow1, ft_ \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/networks/blocks/multi_flow.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/networks/blocks/multi_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..4563c3262b980ec4489ac96177dea522caa84f21 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/networks/blocks/multi_flow.py @@ -0,0 +1,69 @@ +import torch +import torch.nn as nn +from utils.flow_utils import warp +from networks.blocks.ifrnet import ( + convrelu, resize, + ResBlock, +) + + +def multi_flow_combine(comb_block, img0, img1, flow0, flow1, + mask=None, img_res=None, mean=None): + ''' + A parallel implementation of multiple flow field warping + comb_block: An nn.Seqential object. + img shape: [b, c, h, w] + flow shape: [b, 2*num_flows, h, w] + mask (opt): + If 'mask' is None, the function conduct a simple average. + img_res (opt): + If 'img_res' is None, the function adds zero instead. + mean (opt): + If 'mean' is None, the function adds zero instead. + ''' + b, c, h, w = flow0.shape + num_flows = c // 2 + flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w) + flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w) + + mask = mask.reshape(b, num_flows, 1, h, w + ).reshape(-1, 1, h, w) if mask is not None else None + img_res = img_res.reshape(b, num_flows, 3, h, w + ).reshape(-1, 3, h, w) if img_res is not None else 0 + img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w) + img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w) + mean = torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1 + ) if mean is not None else 0 + + img0_warp = warp(img0, flow0) + img1_warp = warp(img1, flow1) + img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res + img_warps = img_warps.reshape(b, num_flows, 3, h, w) + imgt_pred = img_warps.mean(1) + comb_block(img_warps.view(b, -1, h, w)) + return imgt_pred + + +class MultiFlowDecoder(nn.Module): + def __init__(self, in_ch, skip_ch, num_flows=3): + super(MultiFlowDecoder, self).__init__() + self.num_flows = num_flows + self.convblock = nn.Sequential( + convrelu(in_ch*3+4, in_ch*3), + ResBlock(in_ch*3, skip_ch), + nn.ConvTranspose2d(in_ch*3, 8*num_flows, 4, 2, 1, bias=True) + ) + + def forward(self, ft_, f0, f1, flow0, flow1): + n = self.num_flows + f0_warp = warp(f0, flow0) + f1_warp = warp(f1, flow1) + out = self.convblock(torch.cat([ft_, f0_warp, f1_warp, flow0, flow1], 1)) + delta_flow0, delta_flow1, mask, img_res = torch.split(out, [2*n, 2*n, n, 3*n], 1) + mask = torch.sigmoid(mask) + + flow0 = delta_flow0 + 2.0 * resize(flow0, scale_factor=2.0 + ).repeat(1, self.num_flows, 1, 1) + flow1 = delta_flow1 + 2.0 * resize(flow1, scale_factor=2.0 + ).repeat(1, self.num_flows, 1, 1) + + return flow0, flow1, mask, img_res \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/networks/blocks/raft.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/networks/blocks/raft.py new file mode 100644 index 0000000000000000000000000000000000000000..9fb85ad6556a28f5b80034c595be539fd700ad48 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/networks/blocks/raft.py @@ -0,0 +1,207 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def resize(x, scale_factor): + return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False) + + +def bilinear_sampler(img, coords, mask=False): + """ Wrapper for grid_sample, uses pixel coordinates """ + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1,1], dim=-1) + xgrid = 2*xgrid/(W-1) - 1 + ygrid = 2*ygrid/(H-1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True) + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + + +def coords_grid(batch, ht, wd, device): + coords = torch.meshgrid(torch.arange(ht, device=device), + torch.arange(wd, device=device), + indexing='ij') + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + + +class SmallUpdateBlock(nn.Module): + def __init__(self, cdim, hidden_dim, flow_dim, corr_dim, fc_dim, + corr_levels=4, radius=3, scale_factor=None): + super(SmallUpdateBlock, self).__init__() + cor_planes = corr_levels * (2 * radius + 1) **2 + self.scale_factor = scale_factor + + self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0) + self.convf1 = nn.Conv2d(4, flow_dim*2, 7, padding=3) + self.convf2 = nn.Conv2d(flow_dim*2, flow_dim, 3, padding=1) + self.conv = nn.Conv2d(corr_dim+flow_dim, fc_dim, 3, padding=1) + + self.gru = nn.Sequential( + nn.Conv2d(fc_dim+4+cdim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + ) + + self.feat_head = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, cdim, 3, padding=1), + ) + + self.flow_head = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, 4, 3, padding=1), + ) + + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + def forward(self, net, flow, corr): + net = resize(net, 1 / self.scale_factor + ) if self.scale_factor is not None else net + cor = self.lrelu(self.convc1(corr)) + flo = self.lrelu(self.convf1(flow)) + flo = self.lrelu(self.convf2(flo)) + cor_flo = torch.cat([cor, flo], dim=1) + inp = self.lrelu(self.conv(cor_flo)) + inp = torch.cat([inp, flow, net], dim=1) + + out = self.gru(inp) + delta_net = self.feat_head(out) + delta_flow = self.flow_head(out) + + if self.scale_factor is not None: + delta_net = resize(delta_net, scale_factor=self.scale_factor) + delta_flow = self.scale_factor * resize(delta_flow, scale_factor=self.scale_factor) + + return delta_net, delta_flow + + +class BasicUpdateBlock(nn.Module): + def __init__(self, cdim, hidden_dim, flow_dim, corr_dim, corr_dim2, + fc_dim, corr_levels=4, radius=3, scale_factor=None, out_num=1): + super(BasicUpdateBlock, self).__init__() + cor_planes = corr_levels * (2 * radius + 1) **2 + + self.scale_factor = scale_factor + self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0) + self.convc2 = nn.Conv2d(corr_dim, corr_dim2, 3, padding=1) + self.convf1 = nn.Conv2d(4, flow_dim*2, 7, padding=3) + self.convf2 = nn.Conv2d(flow_dim*2, flow_dim, 3, padding=1) + self.conv = nn.Conv2d(flow_dim+corr_dim2, fc_dim, 3, padding=1) + + self.gru = nn.Sequential( + nn.Conv2d(fc_dim+4+cdim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + ) + + self.feat_head = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, cdim, 3, padding=1), + ) + + self.flow_head = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, 4*out_num, 3, padding=1), + ) + + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + def forward(self, net, flow, corr): + net = resize(net, 1 / self.scale_factor + ) if self.scale_factor is not None else net + cor = self.lrelu(self.convc1(corr)) + cor = self.lrelu(self.convc2(cor)) + flo = self.lrelu(self.convf1(flow)) + flo = self.lrelu(self.convf2(flo)) + cor_flo = torch.cat([cor, flo], dim=1) + inp = self.lrelu(self.conv(cor_flo)) + inp = torch.cat([inp, flow, net], dim=1) + + out = self.gru(inp) + delta_net = self.feat_head(out) + delta_flow = self.flow_head(out) + + if self.scale_factor is not None: + delta_net = resize(delta_net, scale_factor=self.scale_factor) + delta_flow = self.scale_factor * resize(delta_flow, scale_factor=self.scale_factor) + return delta_net, delta_flow + + +class BidirCorrBlock: + def __init__(self, fmap1, fmap2, num_levels=4, radius=4): + self.num_levels = num_levels + self.radius = radius + self.corr_pyramid = [] + self.corr_pyramid_T = [] + + corr = BidirCorrBlock.corr(fmap1, fmap2) + batch, h1, w1, dim, h2, w2 = corr.shape + corr_T = corr.clone().permute(0, 4, 5, 3, 1, 2) + + corr = corr.reshape(batch*h1*w1, dim, h2, w2) + corr_T = corr_T.reshape(batch*h2*w2, dim, h1, w1) + + self.corr_pyramid.append(corr) + self.corr_pyramid_T.append(corr_T) + + for _ in range(self.num_levels-1): + corr = F.avg_pool2d(corr, 2, stride=2) + corr_T = F.avg_pool2d(corr_T, 2, stride=2) + self.corr_pyramid.append(corr) + self.corr_pyramid_T.append(corr_T) + + def __call__(self, coords0, coords1): + r = self.radius + coords0 = coords0.permute(0, 2, 3, 1) + coords1 = coords1.permute(0, 2, 3, 1) + assert coords0.shape == coords1.shape, f"coords0 shape: [{coords0.shape}] is not equal to [{coords1.shape}]" + batch, h1, w1, _ = coords0.shape + + out_pyramid = [] + out_pyramid_T = [] + for i in range(self.num_levels): + corr = self.corr_pyramid[i] + corr_T = self.corr_pyramid_T[i] + + dx = torch.linspace(-r, r, 2*r+1, device=coords0.device) + dy = torch.linspace(-r, r, 2*r+1, device=coords0.device) + delta = torch.stack(torch.meshgrid(dy, dx, indexing='ij'), axis=-1) + delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) + + centroid_lvl_0 = coords0.reshape(batch*h1*w1, 1, 1, 2) / 2**i + centroid_lvl_1 = coords1.reshape(batch*h1*w1, 1, 1, 2) / 2**i + coords_lvl_0 = centroid_lvl_0 + delta_lvl + coords_lvl_1 = centroid_lvl_1 + delta_lvl + + corr = bilinear_sampler(corr, coords_lvl_0) + corr_T = bilinear_sampler(corr_T, coords_lvl_1) + corr = corr.view(batch, h1, w1, -1) + corr_T = corr_T.view(batch, h1, w1, -1) + out_pyramid.append(corr) + out_pyramid_T.append(corr_T) + + out = torch.cat(out_pyramid, dim=-1) + out_T = torch.cat(out_pyramid_T, dim=-1) + return out.permute(0, 3, 1, 2).contiguous().float(), out_T.permute(0, 3, 1, 2).contiguous().float() + + @staticmethod + def corr(fmap1, fmap2): + batch, dim, ht, wd = fmap1.shape + fmap1 = fmap1.view(batch, dim, ht*wd) + fmap2 = fmap2.view(batch, dim, ht*wd) + + corr = torch.matmul(fmap1.transpose(1,2), fmap2) + corr = corr.view(batch, ht, wd, 1, ht, wd) + return corr / torch.sqrt(torch.tensor(dim).float()) \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/readme.md b/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..1f35e92a600f33d52858c6ab9d0b4a7d650908cd --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/readme.md @@ -0,0 +1,17 @@ +#### Frame Interpolation + +We use AMT as our frame interpolation model. (Thanks [AMT](https://github.com/MCG-NKU/AMT)) After sampling, you can use frame interpolation model to interpolate your video smoothly. + +1. Download the pretrained weights from [AMT](https://github.com/MCG-NKU/AMT), we recommend using the largest model AMT-G to achieve the best performance. +2. Run the script of frame interpolation. +``` +python opensora/models/frame_interpolation/interpolation.py --ckpt /path/to/ckpt --niters 1 --input /path/to/input/video.mp4 --output_path /path/to/output/floder --frame_rate 30 +``` +3. The output video will be stored at output_path and its duration time is equal `the total number of frames after frame interpolation / the frame rate` +##### Frame Interpolation Specific Settings + +* `--ckpt`: Pretrained model of [AMT](https://github.com/MCG-NKU/AMT). We use AMT-G as our frame interpolation model. +* `--niter`: Iterations of interpolation. With $m$ input frames, `[N_ITER]` $=n$ corresponds to $2^n\times (m-1)+1$ output frames. +* `--input`: Path of the input video. +* `--output_path`: Folder Path of the output video. +* `--frame_rate"`: Frame rate of the output video. diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/utils/__init__.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/utils/build_utils.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/utils/build_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2d34c3b8d45d97961a175784b1c0a362bed3a508 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/utils/build_utils.py @@ -0,0 +1,12 @@ +import importlib + + +def base_build_fn(module, cls, params): + return getattr(importlib.import_module( + module, package=None), cls)(**params) + + +def build_from_cfg(config): + module, cls = config['name'].rsplit(".", 1) + params = config.get('params', {}) + return base_build_fn(module, cls, params) diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/utils/dist_utils.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/utils/dist_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6337f9991fc510cfb6cbc7da18574eb443ec1dac --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/utils/dist_utils.py @@ -0,0 +1,48 @@ +import os +import torch + + +def get_world_size(): + """Find OMPI world size without calling mpi functions + :rtype: int + """ + if os.environ.get('PMI_SIZE') is not None: + return int(os.environ.get('PMI_SIZE') or 1) + elif os.environ.get('OMPI_COMM_WORLD_SIZE') is not None: + return int(os.environ.get('OMPI_COMM_WORLD_SIZE') or 1) + else: + return torch.cuda.device_count() + + +def get_global_rank(): + """Find OMPI world rank without calling mpi functions + :rtype: int + """ + if os.environ.get('PMI_RANK') is not None: + return int(os.environ.get('PMI_RANK') or 0) + elif os.environ.get('OMPI_COMM_WORLD_RANK') is not None: + return int(os.environ.get('OMPI_COMM_WORLD_RANK') or 0) + else: + return 0 + + +def get_local_rank(): + """Find OMPI local rank without calling mpi functions + :rtype: int + """ + if os.environ.get('MPI_LOCALRANKID') is not None: + return int(os.environ.get('MPI_LOCALRANKID') or 0) + elif os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') is not None: + return int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') or 0) + else: + return 0 + + +def get_master_ip(): + if os.environ.get('AZ_BATCH_MASTER_NODE') is not None: + return os.environ.get('AZ_BATCH_MASTER_NODE').split(':')[0] + elif os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') is not None: + return os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') + else: + return "127.0.0.1" + diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/utils/flow_utils.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/utils/flow_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..84fca2049783b22175e0d1e024a19a5f9a79906e --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/utils/flow_utils.py @@ -0,0 +1,122 @@ +import numpy as np +import torch +from PIL import ImageFile +import torch.nn.functional as F +ImageFile.LOAD_TRUNCATED_IMAGES = True + + +def warp(img, flow): + B, _, H, W = flow.shape + xx = torch.linspace(-1.0, 1.0, W).view(1, 1, 1, W).expand(B, -1, H, -1) + yy = torch.linspace(-1.0, 1.0, H).view(1, 1, H, 1).expand(B, -1, -1, W) + grid = torch.cat([xx, yy], 1).to(img) + flow_ = torch.cat([flow[:, 0:1, :, :] / ((W - 1.0) / 2.0), flow[:, 1:2, :, :] / ((H - 1.0) / 2.0)], 1) + grid_ = (grid + flow_).permute(0, 2, 3, 1) + output = F.grid_sample(input=img, grid=grid_, mode='bilinear', padding_mode='border', align_corners=True) + return output + + +def make_colorwheel(): + """ + Generates a color wheel for optical flow visualization as presented in: + Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) + URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf + Code follows the original C++ source code of Daniel Scharstein. + Code follows the the Matlab source code of Deqing Sun. + Returns: + np.ndarray: Color wheel + """ + + RY = 15 + YG = 6 + GC = 4 + CB = 11 + BM = 13 + MR = 6 + + ncols = RY + YG + GC + CB + BM + MR + colorwheel = np.zeros((ncols, 3)) + col = 0 + + # RY + colorwheel[0:RY, 0] = 255 + colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) + col = col+RY + # YG + colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) + colorwheel[col:col+YG, 1] = 255 + col = col+YG + # GC + colorwheel[col:col+GC, 1] = 255 + colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) + col = col+GC + # CB + colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) + colorwheel[col:col+CB, 2] = 255 + col = col+CB + # BM + colorwheel[col:col+BM, 2] = 255 + colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) + col = col+BM + # MR + colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) + colorwheel[col:col+MR, 0] = 255 + return colorwheel + +def flow_uv_to_colors(u, v, convert_to_bgr=False): + """ + Applies the flow color wheel to (possibly clipped) flow components u and v. + According to the C++ source code of Daniel Scharstein + According to the Matlab source code of Deqing Sun + Args: + u (np.ndarray): Input horizontal flow of shape [H,W] + v (np.ndarray): Input vertical flow of shape [H,W] + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) + colorwheel = make_colorwheel() # shape [55x3] + ncols = colorwheel.shape[0] + rad = np.sqrt(np.square(u) + np.square(v)) + a = np.arctan2(-v, -u)/np.pi + fk = (a+1) / 2*(ncols-1) + k0 = np.floor(fk).astype(np.int32) + k1 = k0 + 1 + k1[k1 == ncols] = 0 + f = fk - k0 + for i in range(colorwheel.shape[1]): + tmp = colorwheel[:,i] + col0 = tmp[k0] / 255.0 + col1 = tmp[k1] / 255.0 + col = (1-f)*col0 + f*col1 + idx = (rad <= 1) + col[idx] = 1 - rad[idx] * (1-col[idx]) + col[~idx] = col[~idx] * 0.75 # out of range + # Note the 2-i => BGR instead of RGB + ch_idx = 2-i if convert_to_bgr else i + flow_image[:,:,ch_idx] = np.floor(255 * col) + return flow_image + +def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): + """ + Expects a two dimensional flow image of shape. + Args: + flow_uv (np.ndarray): Flow UV image of shape [H,W,2] + clip_flow (float, optional): Clip maximum of flow values. Defaults to None. + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + assert flow_uv.ndim == 3, 'input flow must have three dimensions' + assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' + if clip_flow is not None: + flow_uv = np.clip(flow_uv, 0, clip_flow) + u = flow_uv[:,:,0] + v = flow_uv[:,:,1] + rad = np.sqrt(np.square(u) + np.square(v)) + rad_max = np.max(rad) + epsilon = 1e-5 + u = u / (rad_max + epsilon) + v = v / (rad_max + epsilon) + return flow_uv_to_colors(u, v, convert_to_bgr) \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/utils/utils.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0473226d4eaf98e41e7ae3ee81b722308765e96c --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/frame_interpolation/utils/utils.py @@ -0,0 +1,297 @@ +import re +import sys +import torch +import random +import numpy as np +from PIL import ImageFile +import torch.nn.functional as F +from imageio import imread, imwrite +ImageFile.LOAD_TRUNCATED_IMAGES = True + + +class AverageMeter(): + def __init__(self): + self.reset() + + def reset(self): + self.val = 0. + self.avg = 0. + self.sum = 0. + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +class AverageMeterGroups: + def __init__(self) -> None: + self.meter_dict = dict() + + def update(self, dict, n=1): + for name, val in dict.items(): + if self.meter_dict.get(name) is None: + self.meter_dict[name] = AverageMeter() + self.meter_dict[name].update(val, n) + + def reset(self, name=None): + if name is None: + for v in self.meter_dict.values(): + v.reset() + else: + meter = self.meter_dict.get(name) + if meter is not None: + meter.reset() + + def avg(self, name): + meter = self.meter_dict.get(name) + if meter is not None: + return meter.avg + + +class InputPadder: + """ Pads images such that dimensions are divisible by divisor """ + def __init__(self, dims, divisor=16): + self.ht, self.wd = dims[-2:] + pad_ht = (((self.ht // divisor) + 1) * divisor - self.ht) % divisor + pad_wd = (((self.wd // divisor) + 1) * divisor - self.wd) % divisor + self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] + + def pad(self, *inputs): + if len(inputs) == 1: + return F.pad(inputs[0], self._pad, mode='replicate') + else: + return [F.pad(x, self._pad, mode='replicate') for x in inputs] + + def unpad(self, *inputs): + if len(inputs) == 1: + return self._unpad(inputs[0]) + else: + return [self._unpad(x) for x in inputs] + + def _unpad(self, x): + ht, wd = x.shape[-2:] + c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] + return x[..., c[0]:c[1], c[2]:c[3]] + + +def img2tensor(img): + if img.shape[-1] > 3: + img = img[:,:,:3] + return torch.tensor(img).permute(2, 0, 1).unsqueeze(0) / 255.0 + + +def tensor2img(img_t): + return (img_t * 255.).detach( + ).squeeze(0).permute(1, 2, 0).cpu().numpy( + ).clip(0, 255).astype(np.uint8) + +def seed_all(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def read(file): + if file.endswith('.float3'): return readFloat(file) + elif file.endswith('.flo'): return readFlow(file) + elif file.endswith('.ppm'): return readImage(file) + elif file.endswith('.pgm'): return readImage(file) + elif file.endswith('.png'): return readImage(file) + elif file.endswith('.jpg'): return readImage(file) + elif file.endswith('.pfm'): return readPFM(file)[0] + else: raise Exception('don\'t know how to read %s' % file) + + +def write(file, data): + if file.endswith('.float3'): return writeFloat(file, data) + elif file.endswith('.flo'): return writeFlow(file, data) + elif file.endswith('.ppm'): return writeImage(file, data) + elif file.endswith('.pgm'): return writeImage(file, data) + elif file.endswith('.png'): return writeImage(file, data) + elif file.endswith('.jpg'): return writeImage(file, data) + elif file.endswith('.pfm'): return writePFM(file, data) + else: raise Exception('don\'t know how to write %s' % file) + + +def readPFM(file): + file = open(file, 'rb') + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header.decode("ascii") == 'PF': + color = True + elif header.decode("ascii") == 'Pf': + color = False + else: + raise Exception('Not a PFM file.') + + dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode("ascii")) + if dim_match: + width, height = list(map(int, dim_match.groups())) + else: + raise Exception('Malformed PFM header.') + + scale = float(file.readline().decode("ascii").rstrip()) + if scale < 0: + endian = '<' + scale = -scale + else: + endian = '>' + + data = np.fromfile(file, endian + 'f') + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + return data, scale + + +def writePFM(file, image, scale=1): + file = open(file, 'wb') + + color = None + + if image.dtype.name != 'float32': + raise Exception('Image dtype must be float32.') + + image = np.flipud(image) + + if len(image.shape) == 3 and image.shape[2] == 3: + color = True + elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: + color = False + else: + raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.') + + file.write('PF\n' if color else 'Pf\n'.encode()) + file.write('%d %d\n'.encode() % (image.shape[1], image.shape[0])) + + endian = image.dtype.byteorder + + if endian == '<' or endian == '=' and sys.byteorder == 'little': + scale = -scale + + file.write('%f\n'.encode() % scale) + + image.tofile(file) + + +def readFlow(name): + if name.endswith('.pfm') or name.endswith('.PFM'): + return readPFM(name)[0][:,:,0:2] + + f = open(name, 'rb') + + header = f.read(4) + if header.decode("utf-8") != 'PIEH': + raise Exception('Flow file header does not contain PIEH') + + width = np.fromfile(f, np.int32, 1).squeeze() + height = np.fromfile(f, np.int32, 1).squeeze() + + flow = np.fromfile(f, np.float32, width * height * 2).reshape((height, width, 2)) + + return flow.astype(np.float32) + + +def readImage(name): + if name.endswith('.pfm') or name.endswith('.PFM'): + data = readPFM(name)[0] + if len(data.shape)==3: + return data[:,:,0:3] + else: + return data + return imread(name) + + +def writeImage(name, data): + if name.endswith('.pfm') or name.endswith('.PFM'): + return writePFM(name, data, 1) + return imwrite(name, data) + + +def writeFlow(name, flow): + f = open(name, 'wb') + f.write('PIEH'.encode('utf-8')) + np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f) + flow = flow.astype(np.float32) + flow.tofile(f) + + +def readFloat(name): + f = open(name, 'rb') + + if(f.readline().decode("utf-8")) != 'float\n': + raise Exception('float file %s did not contain keyword' % name) + + dim = int(f.readline()) + + dims = [] + count = 1 + for i in range(0, dim): + d = int(f.readline()) + dims.append(d) + count *= d + + dims = list(reversed(dims)) + + data = np.fromfile(f, np.float32, count).reshape(dims) + if dim > 2: + data = np.transpose(data, (2, 1, 0)) + data = np.transpose(data, (1, 0, 2)) + + return data + + +def writeFloat(name, data): + f = open(name, 'wb') + + dim=len(data.shape) + if dim>3: + raise Exception('bad float file dimension: %d' % dim) + + f.write(('float\n').encode('ascii')) + f.write(('%d\n' % dim).encode('ascii')) + + if dim == 1: + f.write(('%d\n' % data.shape[0]).encode('ascii')) + else: + f.write(('%d\n' % data.shape[1]).encode('ascii')) + f.write(('%d\n' % data.shape[0]).encode('ascii')) + for i in range(2, dim): + f.write(('%d\n' % data.shape[i]).encode('ascii')) + + data = data.astype(np.float32) + if dim==2: + data.tofile(f) + + else: + np.transpose(data, (2, 0, 1)).tofile(f) + + +def check_dim_and_resize(tensor_list): + shape_list = [] + for t in tensor_list: + shape_list.append(t.shape[2:]) + + if len(set(shape_list)) > 1: + desired_shape = shape_list[0] + print(f'Inconsistent size of input video frames. All frames will be resized to {desired_shape}') + + resize_tensor_list = [] + for t in tensor_list: + resize_tensor_list.append(torch.nn.functional.interpolate(t, size=tuple(desired_shape), mode='bilinear')) + + tensor_list = resize_tensor_list + + return tensor_list + diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/README.md b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/README.md new file mode 100644 index 0000000000000000000000000000000000000000..869cee78d438905cce91d852d20b8a858a2cf25b --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/README.md @@ -0,0 +1,25 @@ + +## Environment Preparation + +For video super resolution, please prepare your own python envirment from [RGT](https://github.com/zhengchen1999/RGT) and down the [ckpt](https://drive.google.com/drive/folders/1zxrr31Kp2D_N9a-OUAPaJEn_yTaSXTfZ) into the folder like +```bash +./experiments/pretrained_models/RGT_x2.pth +``` + +## Video Super Resolution +The inferencing instruction is in [run.py](run.py). +```bash +python run.py --SR x4 --root_path /path_to_root --input_dir /path_to_input_dir --output_dir /path_to_video_output +``` +You can configure some more detailed parameters in [run.py](run.py) such as . +```bash +--mul_numwork 16 --use_chop False +``` +We recommend using `` --use_chop = False `` when memory allows. +Note that in our tests. + +A single frame of 256x256 requires about 3G RAM-Usage, and a single 4090 card can process about one frame per second. + +A single frame of 512x512 takes about 19G RAM-Usage, and a single 4090 takes about 5 seconds to process a frame. + + diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/__init__.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/__init__.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3360401f72406374b62ec8da2625d2f293e37687 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/__init__.py @@ -0,0 +1,6 @@ +from .archs import * +from .data import * +from .metrics import * +from .models import * +from .test import * +from .utils import * diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/archs/__init__.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/archs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cfb1e4d7bb221c429082bd389d9140e5b1cc07b0 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/archs/__init__.py @@ -0,0 +1,25 @@ +import importlib +from copy import deepcopy +from os import path as osp + +from basicsr.utils import get_root_logger, scandir +from basicsr.utils.registry import ARCH_REGISTRY + +__all__ = ['build_network'] + +# automatically scan and import arch modules for registry +# scan all the files under the 'archs' folder and collect files ending with +# '_arch.py' +arch_folder = osp.dirname(osp.abspath(__file__)) +arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] +# import all the arch modules +_arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames] + + +def build_network(opt): + opt = deepcopy(opt) + network_type = opt.pop('type') + net = ARCH_REGISTRY.get(network_type)(**opt) + logger = get_root_logger() + logger.info(f'Network [{net.__class__.__name__}] is created.') + return net diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/archs/arch_util.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/archs/arch_util.py new file mode 100644 index 0000000000000000000000000000000000000000..1719e8e9fe66cd0adc667a76646bcd9dfe588d5e --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/archs/arch_util.py @@ -0,0 +1,318 @@ +import collections.abc +import math +import torch +import torchvision +import warnings +from distutils.version import LooseVersion +from itertools import repeat +from torch import nn as nn +from torch.nn import functional as F +from torch.nn import init as init +from torch.nn.modules.batchnorm import _BatchNorm + +# from basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv +from basicsr.utils import get_root_logger + + +@torch.no_grad() +def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs): + """Initialize network weights. + + Args: + module_list (list[nn.Module] | nn.Module): Modules to be initialized. + scale (float): Scale initialized weights, especially for residual + blocks. Default: 1. + bias_fill (float): The value to fill bias. Default: 0 + kwargs (dict): Other arguments for initialization function. + """ + if not isinstance(module_list, list): + module_list = [module_list] + for module in module_list: + for m in module.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, nn.Linear): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, _BatchNorm): + init.constant_(m.weight, 1) + if m.bias is not None: + m.bias.data.fill_(bias_fill) + + +def make_layer(basic_block, num_basic_block, **kwarg): + """Make layers by stacking the same blocks. + + Args: + basic_block (nn.module): nn.module class for basic block. + num_basic_block (int): number of blocks. + + Returns: + nn.Sequential: Stacked blocks in nn.Sequential. + """ + layers = [] + for _ in range(num_basic_block): + layers.append(basic_block(**kwarg)) + return nn.Sequential(*layers) + + +class ResidualBlockNoBN(nn.Module): + """Residual block without BN. + + It has a style of: + ---Conv-ReLU-Conv-+- + |________________| + + Args: + num_feat (int): Channel number of intermediate features. + Default: 64. + res_scale (float): Residual scale. Default: 1. + pytorch_init (bool): If set to True, use pytorch default init, + otherwise, use default_init_weights. Default: False. + """ + + def __init__(self, num_feat=64, res_scale=1, pytorch_init=False): + super(ResidualBlockNoBN, self).__init__() + self.res_scale = res_scale + self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + self.relu = nn.ReLU(inplace=True) + + if not pytorch_init: + default_init_weights([self.conv1, self.conv2], 0.1) + + def forward(self, x): + identity = x + out = self.conv2(self.relu(self.conv1(x))) + return identity + out * self.res_scale + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True): + """Warp an image or feature map with optical flow. + + Args: + x (Tensor): Tensor with size (n, c, h, w). + flow (Tensor): Tensor with size (n, h, w, 2), normal value. + interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'. + padding_mode (str): 'zeros' or 'border' or 'reflection'. + Default: 'zeros'. + align_corners (bool): Before pytorch 1.3, the default value is + align_corners=True. After pytorch 1.3, the default value is + align_corners=False. Here, we use the True as default. + + Returns: + Tensor: Warped image or feature map. + """ + assert x.size()[-2:] == flow.size()[1:3] + _, _, h, w = x.size() + # create mesh grid + grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x)) + grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 + grid.requires_grad = False + + vgrid = grid + flow + # scale grid to [-1,1] + vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0 + vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0 + vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) + output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners) + + # TODO, what if align_corners=False + return output + + +def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False): + """Resize a flow according to ratio or shape. + + Args: + flow (Tensor): Precomputed flow. shape [N, 2, H, W]. + size_type (str): 'ratio' or 'shape'. + sizes (list[int | float]): the ratio for resizing or the final output + shape. + 1) The order of ratio should be [ratio_h, ratio_w]. For + downsampling, the ratio should be smaller than 1.0 (i.e., ratio + < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e., + ratio > 1.0). + 2) The order of output_size should be [out_h, out_w]. + interp_mode (str): The mode of interpolation for resizing. + Default: 'bilinear'. + align_corners (bool): Whether align corners. Default: False. + + Returns: + Tensor: Resized flow. + """ + _, _, flow_h, flow_w = flow.size() + if size_type == 'ratio': + output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1]) + elif size_type == 'shape': + output_h, output_w = sizes[0], sizes[1] + else: + raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.') + + input_flow = flow.clone() + ratio_h = output_h / flow_h + ratio_w = output_w / flow_w + input_flow[:, 0, :, :] *= ratio_w + input_flow[:, 1, :, :] *= ratio_h + resized_flow = F.interpolate( + input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners) + return resized_flow + + +# TODO: may write a cpp file +def pixel_unshuffle(x, scale): + """ Pixel unshuffle. + + Args: + x (Tensor): Input feature with shape (b, c, hh, hw). + scale (int): Downsample ratio. + + Returns: + Tensor: the pixel unshuffled feature. + """ + b, c, hh, hw = x.size() + out_channel = c * (scale**2) + assert hh % scale == 0 and hw % scale == 0 + h = hh // scale + w = hw // scale + x_view = x.view(b, c, h, scale, w, scale) + return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w) + + +# class DCNv2Pack(ModulatedDeformConvPack): +# """Modulated deformable conv for deformable alignment. +# +# Different from the official DCNv2Pack, which generates offsets and masks +# from the preceding features, this DCNv2Pack takes another different +# features to generate offsets and masks. +# +# Ref: +# Delving Deep into Deformable Alignment in Video Super-Resolution. +# """ +# +# def forward(self, x, feat): +# out = self.conv_offset(feat) +# o1, o2, mask = torch.chunk(out, 3, dim=1) +# offset = torch.cat((o1, o2), dim=1) +# mask = torch.sigmoid(mask) +# +# offset_absmean = torch.mean(torch.abs(offset)) +# if offset_absmean > 50: +# logger = get_root_logger() +# logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.') +# +# if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'): +# return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding, +# self.dilation, mask) +# else: +# return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, +# self.dilation, self.groups, self.deformable_groups) + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. ' + 'The distribution of values may be incorrect.', + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + low = norm_cdf((a - mean) / std) + up = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [low, up], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * low - 1, 2 * up - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. + + From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py + + The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +# From PyTorch +def _ntuple(n): + + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/archs/rgt_arch.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/archs/rgt_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..810fd765a8a59bd1ddf8ff010441b0cd7ee6126e --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/archs/rgt_arch.py @@ -0,0 +1,757 @@ +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint +from torch import Tensor +from torch.nn import functional as F + +from timm.models.layers import DropPath, trunc_normal_ +from einops.layers.torch import Rearrange +from einops import rearrange, repeat + +import math +import numpy as np + +import random + +from basicsr.utils.registry import ARCH_REGISTRY + + +def img2windows(img, H_sp, W_sp): + """ + Input: Image (B, C, H, W) + Output: Window Partition (B', N, C) + """ + B, C, H, W = img.shape + img_reshape = img.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp) + img_perm = img_reshape.permute(0, 2, 4, 3, 5, 1).contiguous().reshape(-1, H_sp* W_sp, C) + return img_perm + + +def windows2img(img_splits_hw, H_sp, W_sp, H, W): + """ + Input: Window Partition (B', N, C) + Output: Image (B, H, W, C) + """ + B = int(img_splits_hw.shape[0] / (H * W / H_sp / W_sp)) + + img = img_splits_hw.view(B, H // H_sp, W // W_sp, H_sp, W_sp, -1) + img = img.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return img + + +class Gate(nn.Module): + def __init__(self, dim): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.conv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim) # DW Conv + + def forward(self, x, H, W): + # Split + x1, x2 = x.chunk(2, dim = -1) + B, N, C = x.shape + x2 = self.conv(self.norm(x2).transpose(1, 2).contiguous().view(B, C//2, H, W)).flatten(2).transpose(-1, -2).contiguous() + + return x1 * x2 + + +class MLP(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.sg = Gate(hidden_features//2) + self.fc2 = nn.Linear(hidden_features//2, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x, H, W): + """ + Input: x: (B, H*W, C), H, W + Output: x: (B, H*W, C) + """ + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + + x = self.sg(x, H, W) + x = self.drop(x) + + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + # The implementation builds on Crossformer code https://github.com/cheerss/CrossFormer/blob/main/models/crossformer.py + """ Dynamic Relative Position Bias. + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + residual (bool): If True, use residual strage to connect conv. + """ + def __init__(self, dim, num_heads, residual): + super().__init__() + self.residual = residual + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + def forward(self, biases): + if self.residual: + pos = self.pos_proj(biases) # 2Gh-1 * 2Gw-1, heads + pos = pos + self.pos1(pos) + pos = pos + self.pos2(pos) + pos = self.pos3(pos) + else: + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + +class WindowAttention(nn.Module): + def __init__(self, dim, idx, split_size=[8,8], dim_out=None, num_heads=6, attn_drop=0., proj_drop=0., qk_scale=None, position_bias=True): + super().__init__() + self.dim = dim + self.dim_out = dim_out or dim + self.split_size = split_size + self.num_heads = num_heads + self.idx = idx + self.position_bias = position_bias + + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + if idx == 0: + H_sp, W_sp = self.split_size[0], self.split_size[1] + elif idx == 1: + W_sp, H_sp = self.split_size[0], self.split_size[1] + else: + print ("ERROR MODE", idx) + exit(0) + self.H_sp = H_sp + self.W_sp = W_sp + + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads, residual=False) + # generate mother-set + position_bias_h = torch.arange(1 - self.H_sp, self.H_sp) + position_bias_w = torch.arange(1 - self.W_sp, self.W_sp) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) + biases = biases.flatten(1).transpose(0, 1).contiguous().float() + self.register_buffer('rpe_biases', biases) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.H_sp) + coords_w = torch.arange(self.W_sp) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.H_sp - 1 + relative_coords[:, :, 1] += self.W_sp - 1 + relative_coords[:, :, 0] *= 2 * self.W_sp - 1 + relative_position_index = relative_coords.sum(-1) + self.register_buffer('relative_position_index', relative_position_index) + + self.attn_drop = nn.Dropout(attn_drop) + + def im2win(self, x, H, W): + B, N, C = x.shape + x = x.transpose(-2,-1).contiguous().view(B, C, H, W) + x = img2windows(x, self.H_sp, self.W_sp) + x = x.reshape(-1, self.H_sp* self.W_sp, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3).contiguous() + return x + + def forward(self, qkv, H, W, mask=None): + """ + Input: qkv: (B, 3*L, C), H, W, mask: (B, N, N), N is the window size + Output: x (B, H, W, C) + """ + q,k,v = qkv[0], qkv[1], qkv[2] + + B, L, C = q.shape + assert L == H * W, "flatten img_tokens has wrong size" + + # partition the q,k,v, image to window + q = self.im2win(q, H, W) + k = self.im2win(k, H, W) + v = self.im2win(v, H, W) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) # B head N C @ B head C N --> B head N N + + # calculate drpe + if self.position_bias: + pos = self.pos(self.rpe_biases) + # select position bias + relative_position_bias = pos[self.relative_position_index.view(-1)].view( + self.H_sp * self.W_sp, self.H_sp * self.W_sp, -1) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() + attn = attn + relative_position_bias.unsqueeze(0) + + N = attn.shape[3] + + # use mask for shift window + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + + attn = nn.functional.softmax(attn, dim=-1, dtype=attn.dtype) + attn = self.attn_drop(attn) + + x = (attn @ v) + x = x.transpose(1, 2).reshape(-1, self.H_sp* self.W_sp, C) # B head N N @ B head N C + + # merge the window, window to image + x = windows2img(x, self.H_sp, self.W_sp, H, W) # B H' W' C + + return x + + +class L_SA(nn.Module): + # The implementation builds on CAT code https://github.com/zhengchen1999/CAT/blob/main/basicsr/archs/cat_arch.py + def __init__(self, dim, num_heads, + split_size=[2,4], shift_size=[1,2], qkv_bias=False, qk_scale=None, + drop=0., attn_drop=0., idx=0, reso=64, rs_id=0): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.split_size = split_size + self.shift_size = shift_size + self.idx = idx + self.rs_id = rs_id + self.patches_resolution = reso + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + + assert 0 <= self.shift_size[0] < self.split_size[0], "shift_size must in 0-split_size0" + assert 0 <= self.shift_size[1] < self.split_size[1], "shift_size must in 0-split_size1" + + self.branch_num = 2 + + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(drop) + + self.attns = nn.ModuleList([ + WindowAttention( + dim//2, idx = i, + split_size=split_size, num_heads=num_heads//2, dim_out=dim//2, + qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, position_bias=True) + for i in range(self.branch_num)]) + + if (self.rs_id % 2 == 0 and self.idx > 0 and (self.idx - 2) % 4 == 0) or (self.rs_id % 2 != 0 and self.idx % 4 == 0): + attn_mask = self.calculate_mask(self.patches_resolution, self.patches_resolution) + + self.register_buffer("attn_mask_0", attn_mask[0]) + self.register_buffer("attn_mask_1", attn_mask[1]) + else: + attn_mask = None + + self.register_buffer("attn_mask_0", None) + self.register_buffer("attn_mask_1", None) + + self.get_v = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1,groups=dim) # DW Conv + + def calculate_mask(self, H, W): + # The implementation builds on Swin Transformer code https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py + # calculate attention mask for Rwin + img_mask_0 = torch.zeros((1, H, W, 1)) # 1 H W 1 idx=0 + img_mask_1 = torch.zeros((1, H, W, 1)) # 1 H W 1 idx=1 + h_slices_0 = (slice(0, -self.split_size[0]), + slice(-self.split_size[0], -self.shift_size[0]), + slice(-self.shift_size[0], None)) + w_slices_0 = (slice(0, -self.split_size[1]), + slice(-self.split_size[1], -self.shift_size[1]), + slice(-self.shift_size[1], None)) + + h_slices_1 = (slice(0, -self.split_size[1]), + slice(-self.split_size[1], -self.shift_size[1]), + slice(-self.shift_size[1], None)) + w_slices_1 = (slice(0, -self.split_size[0]), + slice(-self.split_size[0], -self.shift_size[0]), + slice(-self.shift_size[0], None)) + cnt = 0 + for h in h_slices_0: + for w in w_slices_0: + img_mask_0[:, h, w, :] = cnt + cnt += 1 + cnt = 0 + for h in h_slices_1: + for w in w_slices_1: + img_mask_1[:, h, w, :] = cnt + cnt += 1 + + # calculate mask for H-Shift + img_mask_0 = img_mask_0.view(1, H // self.split_size[0], self.split_size[0], W // self.split_size[1], self.split_size[1], 1) + img_mask_0 = img_mask_0.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, self.split_size[0], self.split_size[1], 1) # nW, sw[0], sw[1], 1 + mask_windows_0 = img_mask_0.view(-1, self.split_size[0] * self.split_size[1]) + attn_mask_0 = mask_windows_0.unsqueeze(1) - mask_windows_0.unsqueeze(2) + attn_mask_0 = attn_mask_0.masked_fill(attn_mask_0 != 0, float(-100.0)).masked_fill(attn_mask_0 == 0, float(0.0)) + + # calculate mask for V-Shift + img_mask_1 = img_mask_1.view(1, H // self.split_size[1], self.split_size[1], W // self.split_size[0], self.split_size[0], 1) + img_mask_1 = img_mask_1.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, self.split_size[1], self.split_size[0], 1) # nW, sw[1], sw[0], 1 + mask_windows_1 = img_mask_1.view(-1, self.split_size[1] * self.split_size[0]) + attn_mask_1 = mask_windows_1.unsqueeze(1) - mask_windows_1.unsqueeze(2) + attn_mask_1 = attn_mask_1.masked_fill(attn_mask_1 != 0, float(-100.0)).masked_fill(attn_mask_1 == 0, float(0.0)) + + return attn_mask_0, attn_mask_1 + + def forward(self, x, H, W): + """ + Input: x: (B, H*W, C), x_size: (H, W) + Output: x: (B, H*W, C) + """ + + B, L, C = x.shape + assert L == H * W, "flatten img_tokens has wrong size" + + qkv = self.qkv(x).reshape(B, -1, 3, C).permute(2, 0, 1, 3) # 3, B, HW, C + # v without partition + v = qkv[2].transpose(-2,-1).contiguous().view(B, C, H, W) + + + max_split_size = max(self.split_size[0], self.split_size[1]) + pad_l = pad_t = 0 + pad_r = (max_split_size - W % max_split_size) % max_split_size + pad_b = (max_split_size - H % max_split_size) % max_split_size + + qkv = qkv.reshape(3*B, H, W, C).permute(0, 3, 1, 2) # 3B C H W + qkv = F.pad(qkv, (pad_l, pad_r, pad_t, pad_b)).reshape(3, B, C, -1).transpose(-2, -1) # l r t b + _H = pad_b + H + _W = pad_r + W + _L = _H * _W + + if (self.rs_id % 2 == 0 and self.idx > 0 and (self.idx - 2) % 4 == 0) or (self.rs_id % 2 != 0 and self.idx % 4 == 0): + qkv = qkv.view(3, B, _H, _W, C) + # H-Shift + qkv_0 = torch.roll(qkv[:,:,:,:,:C//2], shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(2, 3)) + qkv_0 = qkv_0.view(3, B, _L, C//2) + # V-Shift + qkv_1 = torch.roll(qkv[:,:,:,:,C//2:], shifts=(-self.shift_size[1], -self.shift_size[0]), dims=(2, 3)) + qkv_1 = qkv_1.view(3, B, _L, C//2) + + if self.patches_resolution != _H or self.patches_resolution != _W: + mask_tmp = self.calculate_mask(_H, _W) + # H-Rwin + x1_shift = self.attns[0](qkv_0, _H, _W, mask=mask_tmp[0].to(x.device)) + # V-Rwin + x2_shift = self.attns[1](qkv_1, _H, _W, mask=mask_tmp[1].to(x.device)) + + else: + # H-Rwin + x1_shift = self.attns[0](qkv_0, _H, _W, mask=self.attn_mask_0) + # V-Rwin + x2_shift = self.attns[1](qkv_1, _H, _W, mask=self.attn_mask_1) + + x1 = torch.roll(x1_shift, shifts=(self.shift_size[0], self.shift_size[1]), dims=(1, 2)) + x2 = torch.roll(x2_shift, shifts=(self.shift_size[1], self.shift_size[0]), dims=(1, 2)) + x1 = x1[:, :H, :W, :].reshape(B, L, C//2) + x2 = x2[:, :H, :W, :].reshape(B, L, C//2) + # Concat + attened_x = torch.cat([x1,x2], dim=2) + else: + # V-Rwin + x1 = self.attns[0](qkv[:,:,:,:C//2], _H, _W)[:, :H, :W, :].reshape(B, L, C//2) + # H-Rwin + x2 = self.attns[1](qkv[:,:,:,C//2:], _H, _W)[:, :H, :W, :].reshape(B, L, C//2) + # Concat + attened_x = torch.cat([x1,x2], dim=2) + + # mix + lcm = self.get_v(v) + lcm = lcm.permute(0, 2, 3, 1).contiguous().view(B, L, C) + + x = attened_x + lcm + + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class RG_SA(nn.Module): + """ + Recursive-Generalization Self-Attention (RG-SA). + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + c_ratio (float): channel adjustment factor. + """ + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., c_ratio=0.5): + super(RG_SA, self).__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + self.num_heads = num_heads + head_dim = dim // num_heads + + self.cr = int(dim * c_ratio) # scaled channel dimension + + # self.scale = qk_scale or head_dim ** -0.5 + self.scale = qk_scale or (head_dim * c_ratio) ** -0.5 + + # RGM + self.reduction1 = nn.Conv2d(dim, dim, kernel_size=4, stride=4, groups=dim) + self.dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1,groups=dim) + self.conv = nn.Conv2d(dim, self.cr, kernel_size=1, stride=1) + self.norm_act = nn.Sequential( + nn.LayerNorm(self.cr), + nn.GELU()) + # CA + self.q = nn.Linear(dim, self.cr, bias=qkv_bias) + self.k = nn.Linear(self.cr, self.cr, bias=qkv_bias) + self.v = nn.Linear(self.cr, dim, bias=qkv_bias) + + # CPE + self.cpe = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim) + + self.proj = nn.Linear(dim, dim) + self.attn_drop = nn.Dropout(attn_drop) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, H, W): + B, N, C = x.shape + + _scale = 1 + + # reduction + _x = x.permute(0, 2, 1).reshape(B, C, H, W).contiguous() + + if self.training: + _time = max(int(math.log(H//4, 4)), int(math.log(W//4, 4))) + else: + _time = max(int(math.log(H//16, 4)), int(math.log(W//16, 4))) + if _time < 2: _time = 2 # testing _time must equal or larger than training _time (2) + + _scale = 4 ** _time + + # Recursion xT + for _ in range(_time): + _x = self.reduction1(_x) + + _x = self.conv(self.dwconv(_x)).reshape(B, self.cr, -1).permute(0, 2, 1).contiguous() # shape=(B, N', C') + _x = self.norm_act(_x) + + # q, k, v, where q_shape=(B, N, C'), k_shape=(B, N', C'), v_shape=(B, N', C) + q = self.q(x).reshape(B, N, self.num_heads, int(self.cr / self.num_heads)).permute(0, 2, 1, 3) + k = self.k(_x).reshape(B, -1, self.num_heads, int(self.cr / self.num_heads)).permute(0, 2, 1, 3) + v = self.v(_x).reshape(B, -1, self.num_heads, int(C / self.num_heads)).permute(0, 2, 1, 3) + + # corss-attention + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + # CPE + # v_shape=(B, H, N', C//H) + v = v + self.cpe(v.transpose(1, 2).reshape(B, -1, C).transpose(1, 2).contiguous().view(B, C, H // _scale, W // _scale)).view(B, C, -1).view(B, self.num_heads, int(C / self.num_heads), -1).transpose(-1, -2) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class Block(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., + attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, idx=0, + rs_id=0, split_size=[2,4], shift_size=[1,2], reso=64, c_ratio=0.5, layerscale_value=1e-4): + super().__init__() + self.norm1 = norm_layer(dim) + if idx % 2 == 0: + self.attn = L_SA( + dim, split_size=split_size, shift_size=shift_size, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, + drop=drop, idx=idx, reso=reso, rs_id=rs_id + ) + else: + self.attn = RG_SA( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, + proj_drop=drop, c_ratio=c_ratio + ) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, out_features=dim, act_layer=act_layer) + self.norm2 = norm_layer(dim) + + # HAI + self.gamma = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True) + + def forward(self, x, x_size): + H , W = x_size + + res = x + + x = x + self.drop_path(self.attn(self.norm1(x), H, W)) + x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) + + # HAI + x = x + (res * self.gamma) + + return x + + +class ResidualGroup(nn.Module): + + def __init__( self, + dim, + reso, + num_heads, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + drop_paths=None, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + depth=2, + use_chk=False, + resi_connection='1conv', + rs_id=0, + split_size=[8,8], + c_ratio = 0.5): + super().__init__() + self.use_chk = use_chk + self.reso = reso + + self.blocks = nn.ModuleList([ + Block( + dim=dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_paths[i], + act_layer=act_layer, + norm_layer=norm_layer, + idx = i, + rs_id = rs_id, + split_size = split_size, + shift_size = [split_size[0]//2, split_size[1]//2], + c_ratio = c_ratio, + )for i in range(depth)]) + + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + self.conv = nn.Sequential( + nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + def forward(self, x, x_size): + """ + Input: x: (B, H*W, C), x_size: (H, W) + Output: x: (B, H*W, C) + """ + H, W = x_size + res = x + for blk in self.blocks: + if self.use_chk: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W) + x = self.conv(x) + x = rearrange(x, "b c h w -> b (h w) c") + x = res + x + + return x + + +class Upsample(nn.Sequential): + """Upsample module. + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +@ARCH_REGISTRY.register() +class RGT(nn.Module): + + def __init__(self, + img_size=64, + in_chans=3, + embed_dim=180, + depth=[2,2,2,2], + num_heads=[2,2,2,2], + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.1, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + use_chk=False, + upscale=2, + img_range=1., + resi_connection='1conv', + split_size=[8,8], + c_ratio=0.5, + **kwargs): + super().__init__() + + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + + # ------------------------- 1, Shallow Feature Extraction ------------------------- # + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + # ------------------------- 2, Deep Feature Extraction ------------------------- # + self.num_layers = len(depth) + self.use_chk = use_chk + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + heads=num_heads + + self.before_RG = nn.Sequential( + Rearrange('b c h w -> b (h w) c'), + nn.LayerNorm(embed_dim) + ) + + curr_dim = embed_dim + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, np.sum(depth))] # stochastic depth decay rule + + self.layers = nn.ModuleList() + for i in range(self.num_layers): + layer = ResidualGroup( + dim=embed_dim, + num_heads=heads[i], + reso=img_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_paths=dpr[sum(depth[:i]):sum(depth[:i + 1])], + act_layer=act_layer, + norm_layer=norm_layer, + depth=depth[i], + use_chk=use_chk, + resi_connection=resi_connection, + rs_id=i, + split_size = split_size, + c_ratio = c_ratio + ) + self.layers.append(layer) + + self.norm = norm_layer(curr_dim) + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential( + nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + # ------------------------- 3, Reconstruction ------------------------- # + self.conv_before_upsample = nn.Sequential( + nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm, nn.InstanceNorm2d)): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward_features(self, x): + _, _, H, W = x.shape + x_size = [H, W] + x = self.before_RG(x) + for layer in self.layers: + x = layer(x, x_size) + x = self.norm(x) + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W) + + return x + + def forward(self, x): + """ + Input: x: (B, C, H, W) + """ + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + + x = x / self.img_range + self.mean + return x + + +if __name__ == '__main__': + upscale = 1 + height = 62 + width = 66 + model = RGT( + upscale=2, + in_chans=3, + img_size=64, + img_range=1., + depth=[6,6,6,6,6,6], + embed_dim=180, + num_heads=[6,6,6,6,6,6], + mlp_ratio=2, + resi_connection='1conv', + split_size=[8, 8], + upsampler='pixelshuffle').cuda() + # print(model) + print(height, width) + + x = torch.randn((1, 3, height, width)).cuda() + x = model(x) + print(x.shape) \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/archs/vgg_arch.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/archs/vgg_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..1e44c0a11c355fde5847ed9b42dae2ad7e42ea9d --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/archs/vgg_arch.py @@ -0,0 +1,161 @@ +import os +import torch +from collections import OrderedDict +from torch import nn as nn +from torchvision.models import vgg as vgg + +from basicsr.utils.registry import ARCH_REGISTRY + +VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth' +NAMES = { + 'vgg11': [ + 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', + 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', + 'pool5' + ], + 'vgg13': [ + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', + 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', + 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5' + ], + 'vgg16': [ + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', + 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', + 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', + 'pool5' + ], + 'vgg19': [ + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', + 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1', + 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1', + 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5' + ] +} + + +def insert_bn(names): + """Insert bn layer after each conv. + + Args: + names (list): The list of layer names. + + Returns: + list: The list of layer names with bn layers. + """ + names_bn = [] + for name in names: + names_bn.append(name) + if 'conv' in name: + position = name.replace('conv', '') + names_bn.append('bn' + position) + return names_bn + + +@ARCH_REGISTRY.register() +class VGGFeatureExtractor(nn.Module): + """VGG network for feature extraction. + + In this implementation, we allow users to choose whether use normalization + in the input feature and the type of vgg network. Note that the pretrained + path must fit the vgg type. + + Args: + layer_name_list (list[str]): Forward function returns the corresponding + features according to the layer_name_list. + Example: {'relu1_1', 'relu2_1', 'relu3_1'}. + vgg_type (str): Set the type of vgg network. Default: 'vgg19'. + use_input_norm (bool): If True, normalize the input image. Importantly, + the input feature must in the range [0, 1]. Default: True. + range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. + Default: False. + requires_grad (bool): If true, the parameters of VGG network will be + optimized. Default: False. + remove_pooling (bool): If true, the max pooling operations in VGG net + will be removed. Default: False. + pooling_stride (int): The stride of max pooling operation. Default: 2. + """ + + def __init__(self, + layer_name_list, + vgg_type='vgg19', + use_input_norm=True, + range_norm=False, + requires_grad=False, + remove_pooling=False, + pooling_stride=2): + super(VGGFeatureExtractor, self).__init__() + + self.layer_name_list = layer_name_list + self.use_input_norm = use_input_norm + self.range_norm = range_norm + + self.names = NAMES[vgg_type.replace('_bn', '')] + if 'bn' in vgg_type: + self.names = insert_bn(self.names) + + # only borrow layers that will be used to avoid unused params + max_idx = 0 + for v in layer_name_list: + idx = self.names.index(v) + if idx > max_idx: + max_idx = idx + + if os.path.exists(VGG_PRETRAIN_PATH): + vgg_net = getattr(vgg, vgg_type)(pretrained=False) + state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage) + vgg_net.load_state_dict(state_dict) + else: + vgg_net = getattr(vgg, vgg_type)(pretrained=True) + + features = vgg_net.features[:max_idx + 1] + + modified_net = OrderedDict() + for k, v in zip(self.names, features): + if 'pool' in k: + # if remove_pooling is true, pooling operation will be removed + if remove_pooling: + continue + else: + # in some cases, we may want to change the default stride + modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride) + else: + modified_net[k] = v + + self.vgg_net = nn.Sequential(modified_net) + + if not requires_grad: + self.vgg_net.eval() + for param in self.parameters(): + param.requires_grad = False + else: + self.vgg_net.train() + for param in self.parameters(): + param.requires_grad = True + + if self.use_input_norm: + # the mean is for image with range [0, 1] + self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) + # the std is for image with range [0, 1] + self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) + + def forward(self, x): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results. + """ + if self.range_norm: + x = (x + 1) / 2 + if self.use_input_norm: + x = (x - self.mean) / self.std + + output = {} + for key, layer in self.vgg_net._modules.items(): + x = layer(x) + if key in self.layer_name_list: + output[key] = x.clone() + + return output \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/data/__init__.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..12a7ea0538bb3ff33755a7440bb7aa963c67efdf --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/data/__init__.py @@ -0,0 +1,101 @@ +import importlib +import numpy as np +import random +import torch +import torch.utils.data +from copy import deepcopy +from functools import partial +from os import path as osp + +from basicsr.data.prefetch_dataloader import PrefetchDataLoader +from basicsr.utils import get_root_logger, scandir +from basicsr.utils.dist_util import get_dist_info +from basicsr.utils.registry import DATASET_REGISTRY + +__all__ = ['build_dataset', 'build_dataloader'] + +# automatically scan and import dataset modules for registry +# scan all the files under the data folder with '_dataset' in file names +data_folder = osp.dirname(osp.abspath(__file__)) +dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] +# import all the dataset modules +_dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames] + + +def build_dataset(dataset_opt): + """Build dataset from options. + + Args: + dataset_opt (dict): Configuration for dataset. It must contain: + name (str): Dataset name. + type (str): Dataset type. + """ + dataset_opt = deepcopy(dataset_opt) + dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt) + # logger = get_root_logger() + # logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.') + return dataset + + +def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None): + """Build dataloader. + + Args: + dataset (torch.utils.data.Dataset): Dataset. + dataset_opt (dict): Dataset options. It contains the following keys: + phase (str): 'train' or 'val'. + num_worker_per_gpu (int): Number of workers for each GPU. + batch_size_per_gpu (int): Training batch size for each GPU. + num_gpu (int): Number of GPUs. Used only in the train phase. + Default: 1. + dist (bool): Whether in distributed training. Used only in the train + phase. Default: False. + sampler (torch.utils.data.sampler): Data sampler. Default: None. + seed (int | None): Seed. Default: None + """ + phase = dataset_opt['phase'] + rank, _ = get_dist_info() + if phase == 'train': + if dist: # distributed training + batch_size = dataset_opt['batch_size_per_gpu'] + num_workers = dataset_opt['num_worker_per_gpu'] + else: # non-distributed training + multiplier = 1 if num_gpu == 0 else num_gpu + batch_size = dataset_opt['batch_size_per_gpu'] * multiplier + num_workers = dataset_opt['num_worker_per_gpu'] * multiplier + dataloader_args = dict( + dataset=dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + sampler=sampler, + drop_last=True) + if sampler is None: + dataloader_args['shuffle'] = True + dataloader_args['worker_init_fn'] = partial( + worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None + elif phase in ['val', 'test']: # validation + dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0) + else: + raise ValueError(f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'.") + + dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False) + dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False) + + prefetch_mode = dataset_opt.get('prefetch_mode') + if prefetch_mode == 'cpu': # CPUPrefetcher + num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1) + logger = get_root_logger() + logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}') + return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args) + else: + # prefetch_mode=None: Normal dataloader + # prefetch_mode='cuda': dataloader for CUDAPrefetcher + return torch.utils.data.DataLoader(**dataloader_args) + + +def worker_init_fn(worker_id, num_workers, rank, seed): + # Set the worker seed to num_workers * rank + worker_id + seed + worker_seed = num_workers * rank + worker_id + seed + np.random.seed(worker_seed) + random.seed(worker_seed) diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/data/data_sampler.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/data/data_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..575452d9f844a928f7f42296c81635cfbadec7c2 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/data/data_sampler.py @@ -0,0 +1,48 @@ +import math +import torch +from torch.utils.data.sampler import Sampler + + +class EnlargedSampler(Sampler): + """Sampler that restricts data loading to a subset of the dataset. + + Modified from torch.utils.data.distributed.DistributedSampler + Support enlarging the dataset for iteration-based training, for saving + time when restart the dataloader after each epoch + + Args: + dataset (torch.utils.data.Dataset): Dataset used for sampling. + num_replicas (int | None): Number of processes participating in + the training. It is usually the world_size. + rank (int | None): Rank of the current process within num_replicas. + ratio (int): Enlarging ratio. Default: 1. + """ + + def __init__(self, dataset, num_replicas, rank, ratio=1): + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas) + self.total_size = self.num_samples * self.num_replicas + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + indices = torch.randperm(self.total_size, generator=g).tolist() + + dataset_size = len(self.dataset) + indices = [v % dataset_size for v in indices] + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/data/data_util.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/data/data_util.py new file mode 100644 index 0000000000000000000000000000000000000000..de4811dad4ad295be6c992c997412c46920d6583 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/data/data_util.py @@ -0,0 +1,280 @@ +import cv2 +import numpy as np +import torch +from os import path as osp +from torch.nn import functional as F + +from basicsr.utils import img2tensor, scandir + + +def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'): + """Generate an index list for reading `num_frames` frames from a sequence + of images. + + Args: + crt_idx (int): Current center index. + max_frame_num (int): Max number of the sequence of images (from 1). + num_frames (int): Reading num_frames frames. + padding (str): Padding mode, one of + 'replicate' | 'reflection' | 'reflection_circle' | 'circle' + Examples: current_idx = 0, num_frames = 5 + The generated frame indices under different padding mode: + replicate: [0, 0, 0, 1, 2] + reflection: [2, 1, 0, 1, 2] + reflection_circle: [4, 3, 0, 1, 2] + circle: [3, 4, 0, 1, 2] + + Returns: + list[int]: A list of indices. + """ + assert num_frames % 2 == 1, 'num_frames should be an odd number.' + assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.' + + max_frame_num = max_frame_num - 1 # start from 0 + num_pad = num_frames // 2 + + indices = [] + for i in range(crt_idx - num_pad, crt_idx + num_pad + 1): + if i < 0: + if padding == 'replicate': + pad_idx = 0 + elif padding == 'reflection': + pad_idx = -i + elif padding == 'reflection_circle': + pad_idx = crt_idx + num_pad - i + else: + pad_idx = num_frames + i + elif i > max_frame_num: + if padding == 'replicate': + pad_idx = max_frame_num + elif padding == 'reflection': + pad_idx = max_frame_num * 2 - i + elif padding == 'reflection_circle': + pad_idx = (crt_idx - num_pad) - (i - max_frame_num) + else: + pad_idx = i - num_frames + else: + pad_idx = i + indices.append(pad_idx) + return indices + + +def paired_paths_from_lmdb(folders, keys): + """Generate paired paths from lmdb files. + + Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is: + + lq.lmdb + ├── data.mdb + ├── lock.mdb + ├── meta_info.txt + + The data.mdb and lock.mdb are standard lmdb files and you can refer to + https://lmdb.readthedocs.io/en/release/ for more details. + + The meta_info.txt is a specified txt file to record the meta information + of our datasets. It will be automatically created when preparing + datasets by our provided dataset tools. + Each line in the txt file records + 1)image name (with extension), + 2)image shape, + 3)compression level, separated by a white space. + Example: `baboon.png (120,125,3) 1` + + We use the image name without extension as the lmdb key. + Note that we use the same key for the corresponding lq and gt images. + + Args: + folders (list[str]): A list of folder path. The order of list should + be [input_folder, gt_folder]. + keys (list[str]): A list of keys identifying folders. The order should + be in consistent with folders, e.g., ['lq', 'gt']. + Note that this key is different from lmdb keys. + + Returns: + list[str]: Returned path list. + """ + assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. ' + f'But got {len(folders)}') + assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}' + input_folder, gt_folder = folders + input_key, gt_key = keys + + if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')): + raise ValueError(f'{input_key} folder and {gt_key} folder should both in lmdb ' + f'formats. But received {input_key}: {input_folder}; ' + f'{gt_key}: {gt_folder}') + # ensure that the two meta_info files are the same + with open(osp.join(input_folder, 'meta_info.txt')) as fin: + input_lmdb_keys = [line.split('.')[0] for line in fin] + with open(osp.join(gt_folder, 'meta_info.txt')) as fin: + gt_lmdb_keys = [line.split('.')[0] for line in fin] + if set(input_lmdb_keys) != set(gt_lmdb_keys): + raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.') + else: + paths = [] + for lmdb_key in sorted(input_lmdb_keys): + paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)])) + return paths + + +def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl): + """Generate paired paths from an meta information file. + + Each line in the meta information file contains the image names and + image shape (usually for gt), separated by a white space. + + Example of an meta information file: + ``` + 0001_s001.png (480,480,3) + 0001_s002.png (480,480,3) + ``` + + Args: + folders (list[str]): A list of folder path. The order of list should + be [input_folder, gt_folder]. + keys (list[str]): A list of keys identifying folders. The order should + be in consistent with folders, e.g., ['lq', 'gt']. + meta_info_file (str): Path to the meta information file. + filename_tmpl (str): Template for each filename. Note that the + template excludes the file extension. Usually the filename_tmpl is + for files in the input folder. + + Returns: + list[str]: Returned path list. + """ + assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. ' + f'But got {len(folders)}') + assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}' + input_folder, gt_folder = folders + input_key, gt_key = keys + + with open(meta_info_file, 'r') as fin: + gt_names = [line.strip().split(' ')[0] for line in fin] + + paths = [] + for gt_name in gt_names: + basename, ext = osp.splitext(osp.basename(gt_name)) + input_name = f'{filename_tmpl.format(basename)}{ext}' + input_path = osp.join(input_folder, input_name) + gt_path = osp.join(gt_folder, gt_name) + paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)])) + return paths + + +def paired_paths_from_folder(folders, keys, filename_tmpl): + """Generate paired paths from folders. + + Args: + folders (list[str]): A list of folder path. The order of list should + be [input_folder, gt_folder]. + keys (list[str]): A list of keys identifying folders. The order should + be in consistent with folders, e.g., ['lq', 'gt']. + filename_tmpl (str): Template for each filename. Note that the + template excludes the file extension. Usually the filename_tmpl is + for files in the input folder. + + Returns: + list[str]: Returned path list. + """ + assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. ' + f'But got {len(folders)}') + assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}' + input_folder, gt_folder = folders + input_key, gt_key = keys + + input_paths = list(scandir(input_folder)) + gt_paths = list(scandir(gt_folder)) + assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: ' + f'{len(input_paths)}, {len(gt_paths)}.') + paths = [] + for gt_path in gt_paths: + basename, ext = osp.splitext(osp.basename(gt_path)) + input_name = f'{filename_tmpl.format(basename)}{ext}' + input_path = osp.join(input_folder, input_name) + assert input_name in input_paths, f'{input_name} is not in {input_key}_paths.' + gt_path = osp.join(gt_folder, gt_path) + paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)])) + return paths + + +def paths_from_folder(folder): + """Generate paths from folder. + + Args: + folder (str): Folder path. + + Returns: + list[str]: Returned path list. + """ + + paths = list(scandir(folder)) + paths = [osp.join(folder, path) for path in paths] + return paths + + +def paths_from_lmdb(folder): + """Generate paths from lmdb. + + Args: + folder (str): Folder path. + + Returns: + list[str]: Returned path list. + """ + if not folder.endswith('.lmdb'): + raise ValueError(f'Folder {folder}folder should in lmdb format.') + with open(osp.join(folder, 'meta_info.txt')) as fin: + paths = [line.split('.')[0] for line in fin] + return paths + + +def generate_gaussian_kernel(kernel_size=13, sigma=1.6): + """Generate Gaussian kernel used in `duf_downsample`. + + Args: + kernel_size (int): Kernel size. Default: 13. + sigma (float): Sigma of the Gaussian kernel. Default: 1.6. + + Returns: + np.array: The Gaussian kernel. + """ + from scipy.ndimage import filters as filters + kernel = np.zeros((kernel_size, kernel_size)) + # set element at the middle to one, a dirac delta + kernel[kernel_size // 2, kernel_size // 2] = 1 + # gaussian-smooth the dirac, resulting in a gaussian filter + return filters.gaussian_filter(kernel, sigma) + + +def duf_downsample(x, kernel_size=13, scale=4): + """Downsamping with Gaussian kernel used in the DUF official code. + + Args: + x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w). + kernel_size (int): Kernel size. Default: 13. + scale (int): Downsampling factor. Supported scale: (2, 3, 4). + Default: 4. + + Returns: + Tensor: DUF downsampled frames. + """ + assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.' + + squeeze_flag = False + if x.ndim == 4: + squeeze_flag = True + x = x.unsqueeze(0) + b, t, c, h, w = x.size() + x = x.view(-1, 1, h, w) + pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2 + x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect') + + gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale) + gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0) + x = F.conv2d(x, gaussian_filter, stride=scale) + x = x[:, :, 2:-2, 2:-2] + x = x.view(b, t, c, x.size(2), x.size(3)) + if squeeze_flag: + x = x.squeeze(0) + return x diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/data/paired_image_dataset.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/data/paired_image_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..23749f8d5d314b90482903cd6430658356b6f1fa --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/data/paired_image_dataset.py @@ -0,0 +1,113 @@ +from torch.utils import data as data +from torchvision.transforms.functional import normalize + +from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file +from basicsr.data.transforms import augment, paired_random_crop +from basicsr.utils import FileClient, imfrombytes, img2tensor +from basicsr.utils.matlab_functions import bgr2ycbcr +from basicsr.utils.registry import DATASET_REGISTRY + +import numpy as np + +@DATASET_REGISTRY.register() +class PairedImageDataset(data.Dataset): + """Paired image dataset for image restoration. + + Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs. + + There are three modes: + 1. 'lmdb': Use lmdb files. + If opt['io_backend'] == lmdb. + 2. 'meta_info_file': Use meta information file to generate paths. + If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None. + 3. 'folder': Scan folders to generate paths. + The rest. + + Args: + opt (dict): Config for train datasets. It contains the following keys: + dataroot_gt (str): Data root path for gt. + dataroot_lq (str): Data root path for lq. + meta_info_file (str): Path for meta information file. + io_backend (dict): IO backend type and other kwarg. + filename_tmpl (str): Template for each filename. Note that the template excludes the file extension. + Default: '{}'. + gt_size (int): Cropped patched size for gt patches. + use_hflip (bool): Use horizontal flips. + use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation). + + scale (bool): Scale, which will be added automatically. + phase (str): 'train' or 'val'. + """ + + def __init__(self, opt): + super(PairedImageDataset, self).__init__() + self.opt = opt + # file client (io backend) + self.file_client = None + self.io_backend_opt = opt['io_backend'] + self.mean = opt['mean'] if 'mean' in opt else None + self.std = opt['std'] if 'std' in opt else None + + self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq'] + if 'filename_tmpl' in opt: + self.filename_tmpl = opt['filename_tmpl'] + else: + self.filename_tmpl = '{}' + + if self.io_backend_opt['type'] == 'lmdb': + self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder] + self.io_backend_opt['client_keys'] = ['lq', 'gt'] + self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt']) + elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None: + self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'], + self.opt['meta_info_file'], self.filename_tmpl) + else: + self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl) + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + scale = self.opt['scale'] + + # Load gt and lq images. Dimension order: HWC; channel order: BGR; + + # image range: [0, 1], float32., H W 3 + gt_path = self.paths[index]['gt_path'] + img_bytes = self.file_client.get(gt_path, 'gt') + img_gt = imfrombytes(img_bytes, float32=True) + lq_path = self.paths[index]['lq_path'] + img_bytes = self.file_client.get(lq_path, 'lq') + img_lq = imfrombytes(img_bytes, float32=True) + + # augmentation for training + if self.opt['phase'] == 'train': + gt_size = self.opt['gt_size'] + # random crop + img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path) + # flip, rotation + img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot']) + + # color space transform + if 'color' in self.opt and self.opt['color'] == 'y': + img_gt = bgr2ycbcr(img_gt, y_only=True)[..., None] + img_lq = bgr2ycbcr(img_lq, y_only=True)[..., None] + + # crop the unmatched GT images during validation or testing, especially for SR benchmark datasets + # TODO: It is better to update the datasets, rather than force to crop + if self.opt['phase'] != 'train': + img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :] + + # BGR to RGB, HWC to CHW, numpy to tensor + img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) + # normalize + if self.mean is not None or self.std is not None: + normalize(img_lq, self.mean, self.std, inplace=True) + normalize(img_gt, self.mean, self.std, inplace=True) + + # print(img_lq.shape,img_gt.shape,img_lq.min(),img_gt.min(),img_lq.max(),img_gt.max(),lq_path,gt_path) + + return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path} + + def __len__(self): + return len(self.paths) diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/data/prefetch_dataloader.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/data/prefetch_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..5088425050d4cc98114a9b93eb50ea60273f35a0 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/data/prefetch_dataloader.py @@ -0,0 +1,125 @@ +import queue as Queue +import threading +import torch +from torch.utils.data import DataLoader + + +class PrefetchGenerator(threading.Thread): + """A general prefetch generator. + + Ref: + https://stackoverflow.com/questions/7323664/python-generator-pre-fetch + + Args: + generator: Python generator. + num_prefetch_queue (int): Number of prefetch queue. + """ + + def __init__(self, generator, num_prefetch_queue): + threading.Thread.__init__(self) + self.queue = Queue.Queue(num_prefetch_queue) + self.generator = generator + self.daemon = True + self.start() + + def run(self): + for item in self.generator: + self.queue.put(item) + self.queue.put(None) + + def __next__(self): + next_item = self.queue.get() + if next_item is None: + raise StopIteration + return next_item + + def __iter__(self): + return self + + +class PrefetchDataLoader(DataLoader): + """Prefetch version of dataloader. + + Ref: + https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# + + TODO: + Need to test on single gpu and ddp (multi-gpu). There is a known issue in + ddp. + + Args: + num_prefetch_queue (int): Number of prefetch queue. + kwargs (dict): Other arguments for dataloader. + """ + + def __init__(self, num_prefetch_queue, **kwargs): + self.num_prefetch_queue = num_prefetch_queue + super(PrefetchDataLoader, self).__init__(**kwargs) + + def __iter__(self): + return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) + + +class CPUPrefetcher(): + """CPU prefetcher. + + Args: + loader: Dataloader. + """ + + def __init__(self, loader): + self.ori_loader = loader + self.loader = iter(loader) + + def next(self): + try: + return next(self.loader) + except StopIteration: + return None + + def reset(self): + self.loader = iter(self.ori_loader) + + +class CUDAPrefetcher(): + """CUDA prefetcher. + + Ref: + https://github.com/NVIDIA/apex/issues/304# + + It may consums more GPU memory. + + Args: + loader: Dataloader. + opt (dict): Options. + """ + + def __init__(self, loader, opt): + self.ori_loader = loader + self.loader = iter(loader) + self.opt = opt + self.stream = torch.cuda.Stream() + self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') + self.preload() + + def preload(self): + try: + self.batch = next(self.loader) # self.batch is a dict + except StopIteration: + self.batch = None + return None + # put tensors to gpu + with torch.cuda.stream(self.stream): + for k, v in self.batch.items(): + if torch.is_tensor(v): + self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True) + + def next(self): + torch.cuda.current_stream().wait_stream(self.stream) + batch = self.batch + self.preload() + return batch + + def reset(self): + self.loader = iter(self.ori_loader) + self.preload() diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/data/single_image_dataset.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/data/single_image_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..795803a10f02c649834c1daed7a87804a8426305 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/data/single_image_dataset.py @@ -0,0 +1,69 @@ +from os import path as osp +from torch.utils import data as data +from torchvision.transforms.functional import normalize + +from basicsr.data.data_util import paths_from_lmdb +from basicsr.utils import FileClient, imfrombytes, img2tensor, scandir +from basicsr.utils.matlab_functions import rgb2ycbcr +from basicsr.utils.registry import DATASET_REGISTRY + + +@DATASET_REGISTRY.register() +class SingleImageDataset(data.Dataset): + """Read only lq images in the test phase. + + Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc). + + There are two modes: + 1. 'meta_info_file': Use meta information file to generate paths. + 2. 'folder': Scan folders to generate paths. + + Args: + opt (dict): Config for train datasets. It contains the following keys: + dataroot_lq (str): Data root path for lq. + meta_info_file (str): Path for meta information file. + io_backend (dict): IO backend type and other kwarg. + """ + + def __init__(self, opt): + super(SingleImageDataset, self).__init__() + self.opt = opt + # file client (io backend) + self.file_client = None + self.io_backend_opt = opt['io_backend'] + self.mean = opt['mean'] if 'mean' in opt else None + self.std = opt['std'] if 'std' in opt else None + self.lq_folder = opt['dataroot_lq'] + + if self.io_backend_opt['type'] == 'lmdb': + self.io_backend_opt['db_paths'] = [self.lq_folder] + self.io_backend_opt['client_keys'] = ['lq'] + self.paths = paths_from_lmdb(self.lq_folder) + elif 'meta_info_file' in self.opt: + with open(self.opt['meta_info_file'], 'r') as fin: + self.paths = [osp.join(self.lq_folder, line.rstrip().split(' ')[0]) for line in fin] + else: + self.paths = sorted(list(scandir(self.lq_folder, full_path=True))) + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + # load lq image + lq_path = self.paths[index] + img_bytes = self.file_client.get(lq_path, 'lq') + img_lq = imfrombytes(img_bytes, float32=True) + + # color space transform + if 'color' in self.opt and self.opt['color'] == 'y': + img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None] + + # BGR to RGB, HWC to CHW, numpy to tensor + img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True) + # normalize + if self.mean is not None or self.std is not None: + normalize(img_lq, self.mean, self.std, inplace=True) + return {'lq': img_lq, 'lq_path': lq_path} + + def __len__(self): + return len(self.paths) diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/data/transforms.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/data/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..d9bbb5fb7daef5edfb425fafb4d67d471b3001e6 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/data/transforms.py @@ -0,0 +1,179 @@ +import cv2 +import random +import torch + + +def mod_crop(img, scale): + """Mod crop images, used during testing. + + Args: + img (ndarray): Input image. + scale (int): Scale factor. + + Returns: + ndarray: Result image. + """ + img = img.copy() + if img.ndim in (2, 3): + h, w = img.shape[0], img.shape[1] + h_remainder, w_remainder = h % scale, w % scale + img = img[:h - h_remainder, :w - w_remainder, ...] + else: + raise ValueError(f'Wrong img ndim: {img.ndim}.') + return img + + +def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None): + """Paired random crop. Support Numpy array and Tensor inputs. + + It crops lists of lq and gt images with corresponding locations. + + Args: + img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images + should have the same shape. If the input is an ndarray, it will + be transformed to a list containing itself. + img_lqs (list[ndarray] | ndarray): LQ images. Note that all images + should have the same shape. If the input is an ndarray, it will + be transformed to a list containing itself. + gt_patch_size (int): GT patch size. + scale (int): Scale factor. + gt_path (str): Path to ground-truth. Default: None. + + Returns: + list[ndarray] | ndarray: GT images and LQ images. If returned results + only have one element, just return ndarray. + """ + + if not isinstance(img_gts, list): + img_gts = [img_gts] + if not isinstance(img_lqs, list): + img_lqs = [img_lqs] + + # determine input type: Numpy array or Tensor + input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy' + + if input_type == 'Tensor': + h_lq, w_lq = img_lqs[0].size()[-2:] + h_gt, w_gt = img_gts[0].size()[-2:] + else: + h_lq, w_lq = img_lqs[0].shape[0:2] + h_gt, w_gt = img_gts[0].shape[0:2] + lq_patch_size = gt_patch_size // scale + + if h_gt != h_lq * scale or w_gt != w_lq * scale: + raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ', + f'multiplication of LQ ({h_lq}, {w_lq}).') + if h_lq < lq_patch_size or w_lq < lq_patch_size: + raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size ' + f'({lq_patch_size}, {lq_patch_size}). ' + f'Please remove {gt_path}.') + + # randomly choose top and left coordinates for lq patch + top = random.randint(0, h_lq - lq_patch_size) + left = random.randint(0, w_lq - lq_patch_size) + + # crop lq patch + if input_type == 'Tensor': + img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs] + else: + img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs] + + # crop corresponding gt patch + top_gt, left_gt = int(top * scale), int(left * scale) + if input_type == 'Tensor': + img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts] + else: + img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts] + if len(img_gts) == 1: + img_gts = img_gts[0] + if len(img_lqs) == 1: + img_lqs = img_lqs[0] + return img_gts, img_lqs + + +def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False): + """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees). + + We use vertical flip and transpose for rotation implementation. + All the images in the list use the same augmentation. + + Args: + imgs (list[ndarray] | ndarray): Images to be augmented. If the input + is an ndarray, it will be transformed to a list. + hflip (bool): Horizontal flip. Default: True. + rotation (bool): Ratotation. Default: True. + flows (list[ndarray]: Flows to be augmented. If the input is an + ndarray, it will be transformed to a list. + Dimension is (h, w, 2). Default: None. + return_status (bool): Return the status of flip and rotation. + Default: False. + + Returns: + list[ndarray] | ndarray: Augmented images and flows. If returned + results only have one element, just return ndarray. + + """ + hflip = hflip and random.random() < 0.5 + vflip = rotation and random.random() < 0.5 + rot90 = rotation and random.random() < 0.5 + + def _augment(img): + if hflip: # horizontal + cv2.flip(img, 1, img) + if vflip: # vertical + cv2.flip(img, 0, img) + if rot90: + img = img.transpose(1, 0, 2) + return img + + def _augment_flow(flow): + if hflip: # horizontal + cv2.flip(flow, 1, flow) + flow[:, :, 0] *= -1 + if vflip: # vertical + cv2.flip(flow, 0, flow) + flow[:, :, 1] *= -1 + if rot90: + flow = flow.transpose(1, 0, 2) + flow = flow[:, :, [1, 0]] + return flow + + if not isinstance(imgs, list): + imgs = [imgs] + imgs = [_augment(img) for img in imgs] + if len(imgs) == 1: + imgs = imgs[0] + + if flows is not None: + if not isinstance(flows, list): + flows = [flows] + flows = [_augment_flow(flow) for flow in flows] + if len(flows) == 1: + flows = flows[0] + return imgs, flows + else: + if return_status: + return imgs, (hflip, vflip, rot90) + else: + return imgs + + +def img_rotate(img, angle, center=None, scale=1.0): + """Rotate image. + + Args: + img (ndarray): Image to be rotated. + angle (float): Rotation angle in degrees. Positive values mean + counter-clockwise rotation. + center (tuple[int]): Rotation center. If the center is None, + initialize it as the center of the image. Default: None. + scale (float): Isotropic scale factor. Default: 1.0. + """ + (h, w) = img.shape[:2] + + if center is None: + center = (w // 2, h // 2) + + matrix = cv2.getRotationMatrix2D(center, angle, scale) + rotated_img = cv2.warpAffine(img, matrix, (w, h)) + return rotated_img diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/losses/__init__.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..14942900c1abc657d6ea649446ec13c2fbb39387 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/losses/__init__.py @@ -0,0 +1,26 @@ +from copy import deepcopy + +from basicsr.utils import get_root_logger +from basicsr.utils.registry import LOSS_REGISTRY +from .losses import (CharbonnierLoss, GANLoss, L1Loss, MSELoss, WeightedTVLoss, g_path_regularize, + gradient_penalty_loss, r1_penalty) + +__all__ = [ + 'L1Loss', 'MSELoss', 'CharbonnierLoss', 'WeightedTVLoss', 'GANLoss', 'gradient_penalty_loss', + 'r1_penalty', 'g_path_regularize' +] + + +def build_loss(opt): + """Build loss from options. + + Args: + opt (dict): Configuration. It must contain: + type (str): Model type. + """ + opt = deepcopy(opt) + loss_type = opt.pop('type') + loss = LOSS_REGISTRY.get(loss_type)(**opt) + logger = get_root_logger() + logger.info(f'Loss [{loss.__class__.__name__}] is created.') + return loss diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/losses/loss_util.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/losses/loss_util.py new file mode 100644 index 0000000000000000000000000000000000000000..744eeb46d1f3b5a7b4553ca23237ddd9c899a698 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/losses/loss_util.py @@ -0,0 +1,95 @@ +import functools +from torch.nn import functional as F + + +def reduce_loss(loss, reduction): + """Reduce loss as specified. + + Args: + loss (Tensor): Elementwise loss tensor. + reduction (str): Options are 'none', 'mean' and 'sum'. + + Returns: + Tensor: Reduced loss tensor. + """ + reduction_enum = F._Reduction.get_enum(reduction) + # none: 0, elementwise_mean:1, sum: 2 + if reduction_enum == 0: + return loss + elif reduction_enum == 1: + return loss.mean() + else: + return loss.sum() + + +def weight_reduce_loss(loss, weight=None, reduction='mean'): + """Apply element-wise weight and reduce loss. + + Args: + loss (Tensor): Element-wise loss. + weight (Tensor): Element-wise weights. Default: None. + reduction (str): Same as built-in losses of PyTorch. Options are + 'none', 'mean' and 'sum'. Default: 'mean'. + + Returns: + Tensor: Loss values. + """ + # if weight is specified, apply element-wise weight + if weight is not None: + assert weight.dim() == loss.dim() + assert weight.size(1) == 1 or weight.size(1) == loss.size(1) + loss = loss * weight + + # if weight is not specified or reduction is sum, just reduce the loss + if weight is None or reduction == 'sum': + loss = reduce_loss(loss, reduction) + # if reduction is mean, then compute mean over weight region + elif reduction == 'mean': + if weight.size(1) > 1: + weight = weight.sum() + else: + weight = weight.sum() * loss.size(1) + loss = loss.sum() / weight + + return loss + + +def weighted_loss(loss_func): + """Create a weighted version of a given loss function. + + To use this decorator, the loss function must have the signature like + `loss_func(pred, target, **kwargs)`. The function only needs to compute + element-wise loss without any reduction. This decorator will add weight + and reduction arguments to the function. The decorated function will have + the signature like `loss_func(pred, target, weight=None, reduction='mean', + **kwargs)`. + + :Example: + + >>> import torch + >>> @weighted_loss + >>> def l1_loss(pred, target): + >>> return (pred - target).abs() + + >>> pred = torch.Tensor([0, 2, 3]) + >>> target = torch.Tensor([1, 1, 1]) + >>> weight = torch.Tensor([1, 0, 1]) + + >>> l1_loss(pred, target) + tensor(1.3333) + >>> l1_loss(pred, target, weight) + tensor(1.5000) + >>> l1_loss(pred, target, reduction='none') + tensor([1., 1., 2.]) + >>> l1_loss(pred, target, weight, reduction='sum') + tensor(3.) + """ + + @functools.wraps(loss_func) + def wrapper(pred, target, weight=None, reduction='mean', **kwargs): + # get element-wise loss + loss = loss_func(pred, target, **kwargs) + loss = weight_reduce_loss(loss, weight, reduction) + return loss + + return wrapper diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/losses/losses.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/losses/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..55436902eb67b456a39742c7eda75b73471e5f5b --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/losses/losses.py @@ -0,0 +1,492 @@ +import math +import torch +from torch import autograd as autograd +from torch import nn as nn +from torch.nn import functional as F + +from basicsr.archs.vgg_arch import VGGFeatureExtractor +from basicsr.utils.registry import LOSS_REGISTRY +from .loss_util import weighted_loss + +_reduction_modes = ['none', 'mean', 'sum'] + + +@weighted_loss +def l1_loss(pred, target): + return F.l1_loss(pred, target, reduction='none') + + +@weighted_loss +def mse_loss(pred, target): + return F.mse_loss(pred, target, reduction='none') + + +@weighted_loss +def charbonnier_loss(pred, target, eps=1e-12): + return torch.sqrt((pred - target)**2 + eps) + + +@LOSS_REGISTRY.register() +class L1Loss(nn.Module): + """L1 (mean absolute error, MAE) loss. + + Args: + loss_weight (float): Loss weight for L1 loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + """ + + def __init__(self, loss_weight=1.0, reduction='mean'): + super(L1Loss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}') + + self.loss_weight = loss_weight + self.reduction = reduction + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None. + """ + return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction) + + +@LOSS_REGISTRY.register() +class MSELoss(nn.Module): + """MSE (L2) loss. + + Args: + loss_weight (float): Loss weight for MSE loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + """ + + def __init__(self, loss_weight=1.0, reduction='mean'): + super(MSELoss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}') + + self.loss_weight = loss_weight + self.reduction = reduction + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None. + """ + return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction) + + +@LOSS_REGISTRY.register() +class CharbonnierLoss(nn.Module): + """Charbonnier loss (one variant of Robust L1Loss, a differentiable + variant of L1Loss). + + Described in "Deep Laplacian Pyramid Networks for Fast and Accurate + Super-Resolution". + + Args: + loss_weight (float): Loss weight for L1 loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + eps (float): A value used to control the curvature near zero. Default: 1e-12. + """ + + def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12): + super(CharbonnierLoss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}') + + self.loss_weight = loss_weight + self.reduction = reduction + self.eps = eps + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None. + """ + return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction) + + +@LOSS_REGISTRY.register() +class WeightedTVLoss(L1Loss): + """Weighted TV loss. + + Args: + loss_weight (float): Loss weight. Default: 1.0. + """ + + def __init__(self, loss_weight=1.0, reduction='mean'): + if reduction not in ['mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: mean | sum') + super(WeightedTVLoss, self).__init__(loss_weight=loss_weight, reduction=reduction) + + def forward(self, pred, weight=None): + if weight is None: + y_weight = None + x_weight = None + else: + y_weight = weight[:, :, :-1, :] + x_weight = weight[:, :, :, :-1] + + y_diff = super().forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight) + x_diff = super().forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight) + + loss = x_diff + y_diff + + return loss + + +@LOSS_REGISTRY.register() +class PerceptualLoss(nn.Module): + """Perceptual loss with commonly used style loss. + + Args: + layer_weights (dict): The weight for each layer of vgg feature. + Here is an example: {'conv5_4': 1.}, which means the conv5_4 + feature layer (before relu5_4) will be extracted with weight + 1.0 in calculating losses. + vgg_type (str): The type of vgg network used as feature extractor. + Default: 'vgg19'. + use_input_norm (bool): If True, normalize the input image in vgg. + Default: True. + range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. + Default: False. + perceptual_weight (float): If `perceptual_weight > 0`, the perceptual + loss will be calculated and the loss will multiplied by the + weight. Default: 1.0. + style_weight (float): If `style_weight > 0`, the style loss will be + calculated and the loss will multiplied by the weight. + Default: 0. + criterion (str): Criterion used for perceptual loss. Default: 'l1'. + """ + + def __init__(self, + layer_weights, + vgg_type='vgg19', + use_input_norm=True, + range_norm=False, + perceptual_weight=1.0, + style_weight=0., + criterion='l1'): + super(PerceptualLoss, self).__init__() + self.perceptual_weight = perceptual_weight + self.style_weight = style_weight + self.layer_weights = layer_weights + self.vgg = VGGFeatureExtractor( + layer_name_list=list(layer_weights.keys()), + vgg_type=vgg_type, + use_input_norm=use_input_norm, + range_norm=range_norm) + + self.criterion_type = criterion + if self.criterion_type == 'l1': + self.criterion = torch.nn.L1Loss() + elif self.criterion_type == 'l2': + self.criterion = torch.nn.L2loss() + elif self.criterion_type == 'fro': + self.criterion = None + else: + raise NotImplementedError(f'{criterion} criterion has not been supported.') + + def forward(self, x, gt): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + gt (Tensor): Ground-truth tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results. + """ + # extract vgg features + x_features = self.vgg(x) + gt_features = self.vgg(gt.detach()) + + # calculate perceptual loss + if self.perceptual_weight > 0: + percep_loss = 0 + for k in x_features.keys(): + if self.criterion_type == 'fro': + percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k] + else: + percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k] + percep_loss *= self.perceptual_weight + else: + percep_loss = None + + # calculate style loss + if self.style_weight > 0: + style_loss = 0 + for k in x_features.keys(): + if self.criterion_type == 'fro': + style_loss += torch.norm( + self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k] + else: + style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat( + gt_features[k])) * self.layer_weights[k] + style_loss *= self.style_weight + else: + style_loss = None + + return percep_loss, style_loss + + def _gram_mat(self, x): + """Calculate Gram matrix. + + Args: + x (torch.Tensor): Tensor with shape of (n, c, h, w). + + Returns: + torch.Tensor: Gram matrix. + """ + n, c, h, w = x.size() + features = x.view(n, c, w * h) + features_t = features.transpose(1, 2) + gram = features.bmm(features_t) / (c * h * w) + return gram + + +@LOSS_REGISTRY.register() +class GANLoss(nn.Module): + """Define GAN loss. + + Args: + gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'. + real_label_val (float): The value for real label. Default: 1.0. + fake_label_val (float): The value for fake label. Default: 0.0. + loss_weight (float): Loss weight. Default: 1.0. + Note that loss_weight is only for generators; and it is always 1.0 + for discriminators. + """ + + def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0): + super(GANLoss, self).__init__() + self.gan_type = gan_type + self.loss_weight = loss_weight + self.real_label_val = real_label_val + self.fake_label_val = fake_label_val + + if self.gan_type == 'vanilla': + self.loss = nn.BCEWithLogitsLoss() + elif self.gan_type == 'lsgan': + self.loss = nn.MSELoss() + elif self.gan_type == 'wgan': + self.loss = self._wgan_loss + elif self.gan_type == 'wgan_softplus': + self.loss = self._wgan_softplus_loss + elif self.gan_type == 'hinge': + self.loss = nn.ReLU() + else: + raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.') + + def _wgan_loss(self, input, target): + """wgan loss. + + Args: + input (Tensor): Input tensor. + target (bool): Target label. + + Returns: + Tensor: wgan loss. + """ + return -input.mean() if target else input.mean() + + def _wgan_softplus_loss(self, input, target): + """wgan loss with soft plus. softplus is a smooth approximation to the + ReLU function. + + In StyleGAN2, it is called: + Logistic loss for discriminator; + Non-saturating loss for generator. + + Args: + input (Tensor): Input tensor. + target (bool): Target label. + + Returns: + Tensor: wgan loss. + """ + return F.softplus(-input).mean() if target else F.softplus(input).mean() + + def get_target_label(self, input, target_is_real): + """Get target label. + + Args: + input (Tensor): Input tensor. + target_is_real (bool): Whether the target is real or fake. + + Returns: + (bool | Tensor): Target tensor. Return bool for wgan, otherwise, + return Tensor. + """ + + if self.gan_type in ['wgan', 'wgan_softplus']: + return target_is_real + target_val = (self.real_label_val if target_is_real else self.fake_label_val) + return input.new_ones(input.size()) * target_val + + def forward(self, input, target_is_real, is_disc=False): + """ + Args: + input (Tensor): The input for the loss module, i.e., the network + prediction. + target_is_real (bool): Whether the targe is real or fake. + is_disc (bool): Whether the loss for discriminators or not. + Default: False. + + Returns: + Tensor: GAN loss value. + """ + target_label = self.get_target_label(input, target_is_real) + if self.gan_type == 'hinge': + if is_disc: # for discriminators in hinge-gan + input = -input if target_is_real else input + loss = self.loss(1 + input).mean() + else: # for generators in hinge-gan + loss = -input.mean() + else: # other gan types + loss = self.loss(input, target_label) + + # loss_weight is always 1.0 for discriminators + return loss if is_disc else loss * self.loss_weight + + +@LOSS_REGISTRY.register() +class MultiScaleGANLoss(GANLoss): + """ + MultiScaleGANLoss accepts a list of predictions + """ + + def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0): + super(MultiScaleGANLoss, self).__init__(gan_type, real_label_val, fake_label_val, loss_weight) + + def forward(self, input, target_is_real, is_disc=False): + """ + The input is a list of tensors, or a list of (a list of tensors) + """ + if isinstance(input, list): + loss = 0 + for pred_i in input: + if isinstance(pred_i, list): + # Only compute GAN loss for the last layer + # in case of multiscale feature matching + pred_i = pred_i[-1] + # Safe operation: 0-dim tensor calling self.mean() does nothing + loss_tensor = super().forward(pred_i, target_is_real, is_disc).mean() + loss += loss_tensor + return loss / len(input) + else: + return super().forward(input, target_is_real, is_disc) + + +def r1_penalty(real_pred, real_img): + """R1 regularization for discriminator. The core idea is to + penalize the gradient on real data alone: when the + generator distribution produces the true data distribution + and the discriminator is equal to 0 on the data manifold, the + gradient penalty ensures that the discriminator cannot create + a non-zero gradient orthogonal to the data manifold without + suffering a loss in the GAN game. + + Ref: + Eq. 9 in Which training methods for GANs do actually converge. + """ + grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0] + grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean() + return grad_penalty + + +def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01): + noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3]) + grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0] + path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1)) + + path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length) + + path_penalty = (path_lengths - path_mean).pow(2).mean() + + return path_penalty, path_lengths.detach().mean(), path_mean.detach() + + +def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None): + """Calculate gradient penalty for wgan-gp. + + Args: + discriminator (nn.Module): Network for the discriminator. + real_data (Tensor): Real input data. + fake_data (Tensor): Fake input data. + weight (Tensor): Weight tensor. Default: None. + + Returns: + Tensor: A tensor for gradient penalty. + """ + + batch_size = real_data.size(0) + alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1)) + + # interpolate between real_data and fake_data + interpolates = alpha * real_data + (1. - alpha) * fake_data + interpolates = autograd.Variable(interpolates, requires_grad=True) + + disc_interpolates = discriminator(interpolates) + gradients = autograd.grad( + outputs=disc_interpolates, + inputs=interpolates, + grad_outputs=torch.ones_like(disc_interpolates), + create_graph=True, + retain_graph=True, + only_inputs=True)[0] + + if weight is not None: + gradients = gradients * weight + + gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean() + if weight is not None: + gradients_penalty /= torch.mean(weight) + + return gradients_penalty + + +@LOSS_REGISTRY.register() +class GANFeatLoss(nn.Module): + """Define feature matching loss for gans + + Args: + criterion (str): Support 'l1', 'l2', 'charbonnier'. + loss_weight (float): Loss weight. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + """ + + def __init__(self, criterion='l1', loss_weight=1.0, reduction='mean'): + super(GANFeatLoss, self).__init__() + if criterion == 'l1': + self.loss_op = L1Loss(loss_weight, reduction) + elif criterion == 'l2': + self.loss_op = MSELoss(loss_weight, reduction) + elif criterion == 'charbonnier': + self.loss_op = CharbonnierLoss(loss_weight, reduction) + else: + raise ValueError(f'Unsupported loss mode: {criterion}. Supported ones are: l1|l2|charbonnier') + + self.loss_weight = loss_weight + + def forward(self, pred_fake, pred_real): + num_d = len(pred_fake) + loss = 0 + for i in range(num_d): # for each discriminator + # last output is the final prediction, exclude it + num_intermediate_outputs = len(pred_fake[i]) - 1 + for j in range(num_intermediate_outputs): # for each layer output + unweighted_loss = self.loss_op(pred_fake[i][j], pred_real[i][j].detach()) + loss += unweighted_loss / num_d + return loss * self.loss_weight diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/metrics/__init__.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c65580aec1f868da07a9b2c0237214e4899a2736 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/metrics/__init__.py @@ -0,0 +1,19 @@ +from copy import deepcopy + +from basicsr.utils.registry import METRIC_REGISTRY +from .psnr_ssim import calculate_psnr, calculate_ssim + +__all__ = ['calculate_psnr', 'calculate_ssim'] + + +def calculate_metric(data, opt): + """Calculate metric from data and options. + + Args: + opt (dict): Configuration. It must contain: + type (str): Model type. + """ + opt = deepcopy(opt) + metric_type = opt.pop('type') + metric = METRIC_REGISTRY.get(metric_type)(**data, **opt) + return metric diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/metrics/metric_util.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/metrics/metric_util.py new file mode 100644 index 0000000000000000000000000000000000000000..0b21777874f18a7e87c67153ee92dec4d7b599e8 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/metrics/metric_util.py @@ -0,0 +1,45 @@ +import numpy as np + +from basicsr.utils.matlab_functions import bgr2ycbcr + + +def reorder_image(img, input_order='HWC'): + """Reorder images to 'HWC' order. + + If the input_order is (h, w), return (h, w, 1); + If the input_order is (c, h, w), return (h, w, c); + If the input_order is (h, w, c), return as it is. + + Args: + img (ndarray): Input image. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + If the input image shape is (h, w), input_order will not have + effects. Default: 'HWC'. + + Returns: + ndarray: reordered image. + """ + + if input_order not in ['HWC', 'CHW']: + raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'") + if len(img.shape) == 2: + img = img[..., None] + if input_order == 'CHW': + img = img.transpose(1, 2, 0) + return img + + +def to_y_channel(img): + """Change to Y channel of YCbCr. + + Args: + img (ndarray): Images with range [0, 255]. + + Returns: + (ndarray): Images with range [0, 255] (float type) without round. + """ + img = img.astype(np.float32) / 255. + if img.ndim == 3 and img.shape[2] == 3: + img = bgr2ycbcr(img, y_only=True) + img = img[..., None] + return img * 255. diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/metrics/psnr_ssim.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/metrics/psnr_ssim.py new file mode 100644 index 0000000000000000000000000000000000000000..cb00426b91e200f9458eb9863ed7594d430002ae --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/metrics/psnr_ssim.py @@ -0,0 +1,128 @@ +import cv2 +import numpy as np + +from basicsr.metrics.metric_util import reorder_image, to_y_channel +from basicsr.utils.registry import METRIC_REGISTRY + + +@METRIC_REGISTRY.register() +def calculate_psnr(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs): + """Calculate PSNR (Peak Signal-to-Noise Ratio). + + Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio + + Args: + img (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + crop_border (int): Cropped pixels in each edge of an image. These + pixels are not involved in the PSNR calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: psnr result. + """ + + assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') + if input_order not in ['HWC', 'CHW']: + raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"') + img = reorder_image(img, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + img = img.astype(np.float64) + img2 = img2.astype(np.float64) + + if crop_border != 0: + img = img[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] + + if test_y_channel: + img = to_y_channel(img) + img2 = to_y_channel(img2) + + mse = np.mean((img - img2)**2) + if mse == 0: + return float('inf') + return 20. * np.log10(255. / np.sqrt(mse)) + + +def _ssim(img, img2): + """Calculate SSIM (structural similarity) for one channel images. + + It is called by func:`calculate_ssim`. + + Args: + img (ndarray): Images with range [0, 255] with order 'HWC'. + img2 (ndarray): Images with range [0, 255] with order 'HWC'. + + Returns: + float: ssim result. + """ + + c1 = (0.01 * 255)**2 + c2 = (0.03 * 255)**2 + + img = img.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img, -1, window)[5:-5, 5:-5] + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + c1) * (2 * sigma12 + c2)) / ((mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)) + return ssim_map.mean() + + +@METRIC_REGISTRY.register() +def calculate_ssim(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs): + """Calculate SSIM (structural similarity). + + Ref: + Image quality assessment: From error visibility to structural similarity + + The results are the same as that of the official released MATLAB code in + https://ece.uwaterloo.ca/~z70wang/research/ssim/. + + For three-channel images, SSIM is calculated for each channel and then + averaged. + + Args: + img (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + crop_border (int): Cropped pixels in each edge of an image. These + pixels are not involved in the SSIM calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: ssim result. + """ + + assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') + if input_order not in ['HWC', 'CHW']: + raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"') + img = reorder_image(img, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + img = img.astype(np.float64) + img2 = img2.astype(np.float64) + + if crop_border != 0: + img = img[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] + + if test_y_channel: + img = to_y_channel(img) + img2 = to_y_channel(img2) + + ssims = [] + for i in range(img.shape[2]): + ssims.append(_ssim(img[..., i], img2[..., i])) + return np.array(ssims).mean() diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/models/__init__.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..285ce3ef90550f5cd6cb61467388f8ae4b73f14a --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/models/__init__.py @@ -0,0 +1,30 @@ +import importlib +from copy import deepcopy +from os import path as osp + +from basicsr.utils import get_root_logger, scandir +from basicsr.utils.registry import MODEL_REGISTRY + +__all__ = ['build_model'] + +# automatically scan and import model modules for registry +# scan all the files under the 'models' folder and collect files ending with +# '_model.py' +model_folder = osp.dirname(osp.abspath(__file__)) +model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] +# import all the model modules +_model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames] + + +def build_model(opt): + """Build model from options. + + Args: + opt (dict): Configuration. It must contain: + model_type (str): Model type. + """ + opt = deepcopy(opt) + model = MODEL_REGISTRY.get(opt['model_type'])(opt) + logger = get_root_logger() + logger.info(f'Model [{model.__class__.__name__}] is created.') + return model diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/models/base_model.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/models/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..f06f9ca2ca213f1a7c400355e9c66eaa12b1b1c4 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/models/base_model.py @@ -0,0 +1,380 @@ +import os +import time +import torch +from collections import OrderedDict +from copy import deepcopy +from torch.nn.parallel import DataParallel, DistributedDataParallel + +from basicsr.models import lr_scheduler as lr_scheduler +from basicsr.utils import get_root_logger +from basicsr.utils.dist_util import master_only + + +class BaseModel(): + """Base model.""" + + def __init__(self, opt): + self.opt = opt + self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') + self.is_train = opt['is_train'] + self.schedulers = [] + self.optimizers = [] + + def feed_data(self, data): + pass + + def optimize_parameters(self): + pass + + def get_current_visuals(self): + pass + + def save(self, epoch, current_iter): + """Save networks and training state.""" + pass + + def validation(self, dataloader, current_iter, tb_logger, save_img=False): + """Validation function. + + Args: + dataloader (torch.utils.data.DataLoader): Validation dataloader. + current_iter (int): Current iteration. + tb_logger (tensorboard logger): Tensorboard logger. + save_img (bool): Whether to save images. Default: False. + """ + if self.opt['dist']: + self.dist_validation(dataloader, current_iter, tb_logger, save_img) + else: + self.nondist_validation(dataloader, current_iter, tb_logger, save_img) + + def _initialize_best_metric_results(self, dataset_name): + """Initialize the best metric results dict for recording the best metric value and iteration.""" + if hasattr(self, 'best_metric_results') and dataset_name in self.best_metric_results: + return + elif not hasattr(self, 'best_metric_results'): + self.best_metric_results = dict() + + # add a dataset record + record = dict() + for metric, content in self.opt['val']['metrics'].items(): + better = content.get('better', 'higher') + init_val = float('-inf') if better == 'higher' else float('inf') + record[metric] = dict(better=better, val=init_val, iter=-1) + self.best_metric_results[dataset_name] = record + + def _update_best_metric_result(self, dataset_name, metric, val, current_iter): + if self.best_metric_results[dataset_name][metric]['better'] == 'higher': + if val >= self.best_metric_results[dataset_name][metric]['val']: + self.best_metric_results[dataset_name][metric]['val'] = val + self.best_metric_results[dataset_name][metric]['iter'] = current_iter + else: + if val <= self.best_metric_results[dataset_name][metric]['val']: + self.best_metric_results[dataset_name][metric]['val'] = val + self.best_metric_results[dataset_name][metric]['iter'] = current_iter + + def model_ema(self, decay=0.999): + net_g = self.get_bare_model(self.net_g) + + net_g_params = dict(net_g.named_parameters()) + net_g_ema_params = dict(self.net_g_ema.named_parameters()) + + for k in net_g_ema_params.keys(): + net_g_ema_params[k].data.mul_(decay).add_(net_g_params[k].data, alpha=1 - decay) + + def get_current_log(self): + return self.log_dict + + def model_to_device(self, net): + """Model to device. It also warps models with DistributedDataParallel + or DataParallel. + + Args: + net (nn.Module) + """ + net = net.to(self.device) + if self.opt['dist']: + find_unused_parameters = self.opt.get('find_unused_parameters', False) + net = DistributedDataParallel( + net, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters) + elif self.opt['num_gpu'] > 1: + net = DataParallel(net) + return net + + def get_optimizer(self, optim_type, params, lr, **kwargs): + if optim_type == 'Adam': + optimizer = torch.optim.Adam(params, lr, **kwargs) + else: + raise NotImplementedError(f'optimizer {optim_type} is not supperted yet.') + return optimizer + + def setup_schedulers(self): + """Set up schedulers.""" + train_opt = self.opt['train'] + scheduler_type = train_opt['scheduler'].pop('type') + if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']: + for optimizer in self.optimizers: + self.schedulers.append(lr_scheduler.MultiStepRestartLR(optimizer, **train_opt['scheduler'])) + elif scheduler_type == 'CosineAnnealingRestartLR': + for optimizer in self.optimizers: + self.schedulers.append(lr_scheduler.CosineAnnealingRestartLR(optimizer, **train_opt['scheduler'])) + else: + raise NotImplementedError(f'Scheduler {scheduler_type} is not implemented yet.') + + def get_bare_model(self, net): + """Get bare model, especially under wrapping with + DistributedDataParallel or DataParallel. + """ + if isinstance(net, (DataParallel, DistributedDataParallel)): + net = net.module + return net + + @master_only + def print_network(self, net): + """Print the str and parameter number of a network. + + Args: + net (nn.Module) + """ + if isinstance(net, (DataParallel, DistributedDataParallel)): + net_cls_str = f'{net.__class__.__name__} - {net.module.__class__.__name__}' + else: + net_cls_str = f'{net.__class__.__name__}' + + net = self.get_bare_model(net) + net_str = str(net) + net_params = sum(map(lambda x: x.numel(), net.parameters())) + + logger = get_root_logger() + logger.info(f'Network: {net_cls_str}, with parameters: {net_params:,d}') + logger.info(net_str) + + def _set_lr(self, lr_groups_l): + """Set learning rate for warmup. + + Args: + lr_groups_l (list): List for lr_groups, each for an optimizer. + """ + for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): + for param_group, lr in zip(optimizer.param_groups, lr_groups): + param_group['lr'] = lr + + def _get_init_lr(self): + """Get the initial lr, which is set by the scheduler. + """ + init_lr_groups_l = [] + for optimizer in self.optimizers: + init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups]) + return init_lr_groups_l + + def update_learning_rate(self, current_iter, warmup_iter=-1): + """Update learning rate. + + Args: + current_iter (int): Current iteration. + warmup_iter (int): Warmup iter numbers. -1 for no warmup. + Default: -1. + """ + if current_iter > 1: + for scheduler in self.schedulers: + scheduler.step() + # set up warm-up learning rate + if current_iter < warmup_iter: + # get initial lr for each group + init_lr_g_l = self._get_init_lr() + # modify warming-up learning rates + # currently only support linearly warm up + warm_up_lr_l = [] + for init_lr_g in init_lr_g_l: + warm_up_lr_l.append([v / warmup_iter * current_iter for v in init_lr_g]) + # set learning rate + self._set_lr(warm_up_lr_l) + + def get_current_learning_rate(self): + return [param_group['lr'] for param_group in self.optimizers[0].param_groups] + + @master_only + def save_network(self, net, net_label, current_iter, param_key='params'): + """Save networks. + + Args: + net (nn.Module | list[nn.Module]): Network(s) to be saved. + net_label (str): Network label. + current_iter (int): Current iter number. + param_key (str | list[str]): The parameter key(s) to save network. + Default: 'params'. + """ + if current_iter == -1: + current_iter = 'latest' + save_filename = f'{net_label}_{current_iter}.pth' + save_path = os.path.join(self.opt['path']['models'], save_filename) + + net = net if isinstance(net, list) else [net] + param_key = param_key if isinstance(param_key, list) else [param_key] + assert len(net) == len(param_key), 'The lengths of net and param_key should be the same.' + + save_dict = {} + for net_, param_key_ in zip(net, param_key): + net_ = self.get_bare_model(net_) + state_dict = net_.state_dict() + for key, param in state_dict.items(): + if key.startswith('module.'): # remove unnecessary 'module.' + key = key[7:] + state_dict[key] = param.cpu() + save_dict[param_key_] = state_dict + + # avoid occasional writing errors + retry = 3 + while retry > 0: + try: + torch.save(save_dict, save_path) + except Exception as e: + logger = get_root_logger() + logger.warning(f'Save model error: {e}, remaining retry times: {retry - 1}') + time.sleep(1) + else: + break + finally: + retry -= 1 + if retry == 0: + logger.warning(f'Still cannot save {save_path}. Just ignore it.') + # raise IOError(f'Cannot save {save_path}.') + + def _print_different_keys_loading(self, crt_net, load_net, strict=True): + """Print keys with different name or different size when loading models. + + 1. Print keys with different names. + 2. If strict=False, print the same key but with different tensor size. + It also ignore these keys with different sizes (not load). + + Args: + crt_net (torch model): Current network. + load_net (dict): Loaded network. + strict (bool): Whether strictly loaded. Default: True. + """ + crt_net = self.get_bare_model(crt_net) + crt_net = crt_net.state_dict() + crt_net_keys = set(crt_net.keys()) + load_net_keys = set(load_net.keys()) + + logger = get_root_logger() + if crt_net_keys != load_net_keys: + logger.warning('Current net - loaded net:') + for v in sorted(list(crt_net_keys - load_net_keys)): + logger.warning(f' {v}') + logger.warning('Loaded net - current net:') + for v in sorted(list(load_net_keys - crt_net_keys)): + logger.warning(f' {v}') + + # check the size for the same keys + if not strict: + common_keys = crt_net_keys & load_net_keys + for k in common_keys: + if crt_net[k].size() != load_net[k].size(): + logger.warning(f'Size different, ignore [{k}]: crt_net: ' + f'{crt_net[k].shape}; load_net: {load_net[k].shape}') + load_net[k + '.ignore'] = load_net.pop(k) + + def load_network(self, net, load_path, strict=True, param_key='params'): + """Load network. + + Args: + load_path (str): The path of networks to be loaded. + net (nn.Module): Network. + strict (bool): Whether strictly loaded. + param_key (str): The parameter key of loaded network. If set to + None, use the root 'path'. + Default: 'params'. + """ + logger = get_root_logger() + net = self.get_bare_model(net) + load_net = torch.load(load_path, map_location=lambda storage, loc: storage) + if param_key is not None: + if param_key not in load_net and 'params' in load_net: + param_key = 'params' + logger.info('Loading: params_ema does not exist, use params.') + load_net = load_net[param_key] + logger.info(f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].') + # remove unnecessary 'module.' + for k, v in deepcopy(load_net).items(): + if k.startswith('module.'): + load_net[k[7:]] = v + load_net.pop(k) + self._print_different_keys_loading(net, load_net, strict) + net.load_state_dict(load_net, strict=strict) + + @master_only + def save_training_state(self, epoch, current_iter): + """Save training states during training, which will be used for + resuming. + + Args: + epoch (int): Current epoch. + current_iter (int): Current iteration. + """ + if current_iter != -1: + state = {'epoch': epoch, 'iter': current_iter, 'optimizers': [], 'schedulers': []} + for o in self.optimizers: + state['optimizers'].append(o.state_dict()) + for s in self.schedulers: + state['schedulers'].append(s.state_dict()) + save_filename = f'{current_iter}.state' + save_path = os.path.join(self.opt['path']['training_states'], save_filename) + + # avoid occasional writing errors + retry = 3 + while retry > 0: + try: + torch.save(state, save_path) + except Exception as e: + logger = get_root_logger() + logger.warning(f'Save training state error: {e}, remaining retry times: {retry - 1}') + time.sleep(1) + else: + break + finally: + retry -= 1 + if retry == 0: + logger.warning(f'Still cannot save {save_path}. Just ignore it.') + # raise IOError(f'Cannot save {save_path}.') + + def resume_training(self, resume_state): + """Reload the optimizers and schedulers for resumed training. + + Args: + resume_state (dict): Resume state. + """ + resume_optimizers = resume_state['optimizers'] + resume_schedulers = resume_state['schedulers'] + assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers' + assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers' + for i, o in enumerate(resume_optimizers): + self.optimizers[i].load_state_dict(o) + for i, s in enumerate(resume_schedulers): + self.schedulers[i].load_state_dict(s) + + def reduce_loss_dict(self, loss_dict): + """reduce loss dict. + + In distributed training, it averages the losses among different GPUs . + + Args: + loss_dict (OrderedDict): Loss dict. + """ + with torch.no_grad(): + if self.opt['dist']: + keys = [] + losses = [] + for name, value in loss_dict.items(): + keys.append(name) + losses.append(value) + losses = torch.stack(losses, 0) + torch.distributed.reduce(losses, dst=0) + if self.opt['rank'] == 0: + losses /= self.opt['world_size'] + loss_dict = {key: loss for key, loss in zip(keys, losses)} + + log_dict = OrderedDict() + for name, value in loss_dict.items(): + log_dict[name] = value.mean().item() + + return log_dict diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/models/lr_scheduler.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/models/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..11e1c6c7a74f5233accda52370f92681d3d3cecf --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/models/lr_scheduler.py @@ -0,0 +1,96 @@ +import math +from collections import Counter +from torch.optim.lr_scheduler import _LRScheduler + + +class MultiStepRestartLR(_LRScheduler): + """ MultiStep with restarts learning rate scheme. + + Args: + optimizer (torch.nn.optimizer): Torch optimizer. + milestones (list): Iterations that will decrease learning rate. + gamma (float): Decrease ratio. Default: 0.1. + restarts (list): Restart iterations. Default: [0]. + restart_weights (list): Restart weights at each restart iteration. + Default: [1]. + last_epoch (int): Used in _LRScheduler. Default: -1. + """ + + def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1): + self.milestones = Counter(milestones) + self.gamma = gamma + self.restarts = restarts + self.restart_weights = restart_weights + assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.' + super(MultiStepRestartLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + if self.last_epoch in self.restarts: + weight = self.restart_weights[self.restarts.index(self.last_epoch)] + return [group['initial_lr'] * weight for group in self.optimizer.param_groups] + if self.last_epoch not in self.milestones: + return [group['lr'] for group in self.optimizer.param_groups] + return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups] + + +def get_position_from_periods(iteration, cumulative_period): + """Get the position from a period list. + + It will return the index of the right-closest number in the period list. + For example, the cumulative_period = [100, 200, 300, 400], + if iteration == 50, return 0; + if iteration == 210, return 2; + if iteration == 300, return 2. + + Args: + iteration (int): Current iteration. + cumulative_period (list[int]): Cumulative period list. + + Returns: + int: The position of the right-closest number in the period list. + """ + for i, period in enumerate(cumulative_period): + if iteration <= period: + return i + + +class CosineAnnealingRestartLR(_LRScheduler): + """ Cosine annealing with restarts learning rate scheme. + + An example of config: + periods = [10, 10, 10, 10] + restart_weights = [1, 0.5, 0.5, 0.5] + eta_min=1e-7 + + It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the + scheduler will restart with the weights in restart_weights. + + Args: + optimizer (torch.nn.optimizer): Torch optimizer. + periods (list): Period for each cosine anneling cycle. + restart_weights (list): Restart weights at each restart iteration. + Default: [1]. + eta_min (float): The minimum lr. Default: 0. + last_epoch (int): Used in _LRScheduler. Default: -1. + """ + + def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1): + self.periods = periods + self.restart_weights = restart_weights + self.eta_min = eta_min + assert (len(self.periods) == len( + self.restart_weights)), 'periods and restart_weights should have the same length.' + self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))] + super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + idx = get_position_from_periods(self.last_epoch, self.cumulative_period) + current_weight = self.restart_weights[idx] + nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1] + current_period = self.periods[idx] + + return [ + self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) * + (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period))) + for base_lr in self.base_lrs + ] diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/models/rgt_model.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/models/rgt_model.py new file mode 100644 index 0000000000000000000000000000000000000000..f63b2351380c679b42cf21ffb0092566c5bd28cd --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/models/rgt_model.py @@ -0,0 +1,127 @@ +import torch +from torch.nn import functional as F + +from basicsr.utils.registry import MODEL_REGISTRY +from basicsr.models.sr_model import SRModel + + +@MODEL_REGISTRY.register() +class RGTModel(SRModel): + + def test(self): + self.use_chop = self.opt['val']['use_chop'] if 'use_chop' in self.opt['val'] else False + if not self.use_chop: + if hasattr(self, 'net_g_ema'): + self.net_g_ema.eval() + with torch.no_grad(): + self.output = self.net_g_ema(self.lq) + else: + self.net_g.eval() + with torch.no_grad(): + self.output = self.net_g(self.lq) + self.net_g.train() + + # test by partitioning + else: + _, C, h, w = self.lq.size() + split_token_h = h // 200 + 1 # number of horizontal cut sections + split_token_w = w // 200 + 1 # number of vertical cut sections + + patch_size_tmp_h = split_token_h + patch_size_tmp_w = split_token_w + + # padding + mod_pad_h, mod_pad_w = 0, 0 + if h % patch_size_tmp_h != 0: + mod_pad_h = patch_size_tmp_h - h % patch_size_tmp_h + if w % patch_size_tmp_w != 0: + mod_pad_w = patch_size_tmp_w - w % patch_size_tmp_w + + img = self.lq + img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, :h+mod_pad_h, :] + img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, :w+mod_pad_w] + + _, _, H, W = img.size() + split_h = H // split_token_h # height of each partition + split_w = W // split_token_w # width of each partition + + # overlapping + shave_h = 16 + shave_w = 16 + scale = self.opt.get('scale', 1) + ral = H // split_h + row = W // split_w + slices = [] # list of partition borders + for i in range(ral): + for j in range(row): + if i == 0 and i == ral - 1: + top = slice(i * split_h, (i + 1) * split_h) + elif i == 0: + top = slice(i*split_h, (i+1)*split_h+shave_h) + elif i == ral - 1: + top = slice(i*split_h-shave_h, (i+1)*split_h) + else: + top = slice(i*split_h-shave_h, (i+1)*split_h+shave_h) + if j == 0 and j == row - 1: + left = slice(j*split_w, (j+1)*split_w) + elif j == 0: + left = slice(j*split_w, (j+1)*split_w+shave_w) + elif j == row - 1: + left = slice(j*split_w-shave_w, (j+1)*split_w) + else: + left = slice(j*split_w-shave_w, (j+1)*split_w+shave_w) + temp = (top, left) + slices.append(temp) + img_chops = [] # list of partitions + for temp in slices: + top, left = temp + img_chops.append(img[..., top, left]) + if hasattr(self, 'net_g_ema'): + self.net_g_ema.eval() + with torch.no_grad(): + outputs = [] + for chop in img_chops: + out = self.net_g_ema(chop) # image processing of each partition + outputs.append(out) + _img = torch.zeros(1, C, H * scale, W * scale) + # merge + for i in range(ral): + for j in range(row): + top = slice(i * split_h * scale, (i + 1) * split_h * scale) + left = slice(j * split_w * scale, (j + 1) * split_w * scale) + if i == 0: + _top = slice(0, split_h * scale) + else: + _top = slice(shave_h*scale, (shave_h+split_h)*scale) + if j == 0: + _left = slice(0, split_w*scale) + else: + _left = slice(shave_w*scale, (shave_w+split_w)*scale) + _img[..., top, left] = outputs[i * row + j][..., _top, _left] + self.output = _img + else: + self.net_g.eval() + with torch.no_grad(): + outputs = [] + for chop in img_chops: + out = self.net_g(chop) # image processing of each partition + outputs.append(out) + _img = torch.zeros(1, C, H * scale, W * scale) + # merge + for i in range(ral): + for j in range(row): + top = slice(i * split_h * scale, (i + 1) * split_h * scale) + left = slice(j * split_w * scale, (j + 1) * split_w * scale) + if i == 0: + _top = slice(0, split_h * scale) + else: + _top = slice(shave_h * scale, (shave_h + split_h) * scale) + if j == 0: + _left = slice(0, split_w * scale) + else: + _left = slice(shave_w * scale, (shave_w + split_w) * scale) + _img[..., top, left] = outputs[i * row + j][..., _top, _left] + self.output = _img + self.net_g.train() + _, _, h, w = self.output.size() + self.output = self.output[:, :, 0:h - mod_pad_h * scale, 0:w - mod_pad_w * scale] diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/models/sr_model.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/models/sr_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e076553c25cb21ffc9cb0786cb89a5e9348576ff --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/models/sr_model.py @@ -0,0 +1,235 @@ +import torch +from collections import OrderedDict +from os import path as osp +from tqdm import tqdm + +from basicsr.archs import build_network +from basicsr.losses import build_loss +from basicsr.metrics import calculate_metric +from basicsr.utils import get_root_logger, imwrite, tensor2img +from basicsr.utils.registry import MODEL_REGISTRY +from .base_model import BaseModel + + +@MODEL_REGISTRY.register() +class SRModel(BaseModel): + """Base SR model for single image super-resolution.""" + + def __init__(self, opt): + super(SRModel, self).__init__(opt) + + # define network + self.net_g = build_network(opt['network_g']) + self.net_g = self.model_to_device(self.net_g) + self.print_network(self.net_g) + + # load pretrained models + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + param_key = self.opt['path'].get('param_key_g', 'params') + self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key) + + if self.is_train: + self.init_training_settings() + + def init_training_settings(self): + self.net_g.train() + train_opt = self.opt['train'] + + self.ema_decay = train_opt.get('ema_decay', 0) + if self.ema_decay > 0: + logger = get_root_logger() + logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}') + # define network net_g with Exponential Moving Average (EMA) + # net_g_ema is used only for testing on one GPU and saving + # There is no need to wrap with DistributedDataParallel + self.net_g_ema = build_network(self.opt['network_g']).to(self.device) + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') + else: + self.model_ema(0) # copy net_g weight + self.net_g_ema.eval() + + # define losses + if train_opt.get('pixel_opt'): + self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) + else: + self.cri_pix = None + + if train_opt.get('perceptual_opt'): + self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) + else: + self.cri_perceptual = None + + if self.cri_pix is None and self.cri_perceptual is None: + raise ValueError('Both pixel and perceptual losses are None.') + + # set up optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + + def setup_optimizers(self): + train_opt = self.opt['train'] + optim_params = [] + for k, v in self.net_g.named_parameters(): + if v.requires_grad: + optim_params.append(v) + else: + logger = get_root_logger() + logger.warning(f'Params {k} will not be optimized.') + + optim_type = train_opt['optim_g'].pop('type') + self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g']) + self.optimizers.append(self.optimizer_g) + + def feed_data(self, data): + self.lq = data['lq'].to(self.device) + if 'gt' in data: + self.gt = data['gt'].to(self.device) + + def optimize_parameters(self, current_iter): + self.optimizer_g.zero_grad() + self.output = self.net_g(self.lq) + + l_total = 0 + loss_dict = OrderedDict() + # pixel loss + if self.cri_pix: + l_pix = self.cri_pix(self.output, self.gt) + l_total += l_pix + loss_dict['l_pix'] = l_pix + # perceptual loss + if self.cri_perceptual: + l_percep, l_style = self.cri_perceptual(self.output, self.gt) + if l_percep is not None: + l_total += l_percep + loss_dict['l_percep'] = l_percep + if l_style is not None: + l_total += l_style + loss_dict['l_style'] = l_style + + l_total.backward() + self.optimizer_g.step() + + self.log_dict = self.reduce_loss_dict(loss_dict) + + if self.ema_decay > 0: + self.model_ema(decay=self.ema_decay) + + def test(self): + if hasattr(self, 'net_g_ema'): + self.net_g_ema.eval() + with torch.no_grad(): + self.output = self.net_g_ema(self.lq) + else: + self.net_g.eval() + with torch.no_grad(): + self.output = self.net_g(self.lq) + self.net_g.train() + + def dist_validation(self, dataloader, current_iter, tb_logger, save_img): + if self.opt['rank'] == 0: + self.nondist_validation(dataloader, current_iter, tb_logger, save_img) + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + dataset_name = dataloader.dataset.opt['name'] + with_metrics = self.opt['val'].get('metrics') is not None + use_pbar = self.opt['val'].get('pbar', False) + # print(with_metrics,use_pbar) + + if with_metrics: + if not hasattr(self, 'metric_results'): # only execute in the first run + self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()} + # initialize the best metric results for each dataset_name (supporting multiple validation datasets) + self._initialize_best_metric_results(dataset_name) + # zero self.metric_results + if with_metrics: + self.metric_results = {metric: 0 for metric in self.metric_results} + + metric_data = dict() + if use_pbar: + pbar = tqdm(total=len(dataloader), unit='image') + + for idx, val_data in enumerate(dataloader): + img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0] + self.feed_data(val_data) + self.test() + + # this is img data + visuals = self.get_current_visuals() + sr_img = tensor2img([visuals['result']]) + + metric_data['img'] = sr_img + if 'gt' in visuals: + gt_img = tensor2img([visuals['gt']]) + metric_data['img2'] = gt_img + del self.gt + + # tentative for out of GPU memory + del self.lq + del self.output + torch.cuda.empty_cache() + + if save_img: + if self.opt['is_train']: + save_img_path = osp.join(self.opt['path']['visualization'], img_name, + f'{img_name}_{current_iter}.png') + else: + if self.opt['val']['suffix']: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["val"]["suffix"]}.png') + else: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["name"]}.png') + # save img + imwrite(sr_img, save_img_path) + + if with_metrics: + # calculate metrics + for name, opt_ in self.opt['val']['metrics'].items(): + self.metric_results[name] += calculate_metric(metric_data, opt_) + if use_pbar: + pbar.update(1) + pbar.set_description(f'Test {img_name}') + if use_pbar: + pbar.close() + + if with_metrics: + for metric in self.metric_results.keys(): + self.metric_results[metric] /= (idx + 1) + # update the best metric result + self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter) + + self._log_validation_metric_values(current_iter, dataset_name, tb_logger) + + def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger): + log_str = f'Validation {dataset_name}\n' + for metric, value in self.metric_results.items(): + log_str += f'\t # {metric}: {value:.4f}' + if hasattr(self, 'best_metric_results'): + log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ ' + f'{self.best_metric_results[dataset_name][metric]["iter"]} iter') + log_str += '\n' + + logger = get_root_logger() + logger.info(log_str) + if tb_logger: + for metric, value in self.metric_results.items(): + tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter) + + def get_current_visuals(self): + out_dict = OrderedDict() + out_dict['lq'] = self.lq.detach().cpu() + out_dict['result'] = self.output.detach().cpu() + if hasattr(self, 'gt'): + out_dict['gt'] = self.gt.detach().cpu() + return out_dict + + def save(self, epoch, current_iter): + if hasattr(self, 'net_g_ema'): + self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema']) + else: + self.save_network(self.net_g, 'net_g', current_iter) + self.save_training_state(epoch, current_iter) diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/test_img.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/test_img.py new file mode 100644 index 0000000000000000000000000000000000000000..4ac2930178acd25aec4b69127f5a3ce3ebac1345 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/test_img.py @@ -0,0 +1,48 @@ +import logging +import torch +from os import path as osp +from basicsr.data import build_dataloader, build_dataset +from basicsr.models import build_model +from basicsr.utils import get_root_logger, get_time_str, make_exp_dirs +from basicsr.utils.options import dict2str, parse_options + + +def image_sr(args): + # parse options, set distributed setting, set ramdom seed + opt, _ = parse_options(args.root_path, is_train=False) + torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = True + + # create test dataset and dataloader + test_loaders = [] + for _, dataset_opt in sorted(opt['datasets'].items()): + dataset_opt['dataroot_lq'] = osp.join(args.output_dir, f'temp_LR') + if args.SR == 'x4': + opt['upscale'] = opt['network_g']['upscale'] = 4 + opt['val']['suffix'] = 'x4' + opt['path']['pretrain_network_g'] = osp.join(args.root_path, f'experiments/pretrained_models/RGT_x4.pth') + if args.SR == 'x2': + opt['upscale'] = opt['network_g']['upscale'] = 2 + opt['val']['suffix'] = 'x2' + + test_set = build_dataset(dataset_opt) + test_loader = build_dataloader( + test_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed']) + test_loaders.append(test_loader) + + opt['path']['pretrain_network_g'] = args.ckpt_path + opt['val']['use_chop'] = args.use_chop + opt['path']['visualization'] = osp.join(args.output_dir, f'temp_results') + opt['path']['results_root'] = osp.join(args.output_dir, f'temp_results') + + # create model + model = build_model(opt) + for test_loader in test_loaders: + test_set_name = test_loader.dataset.opt['name'] + model.validation(test_loader, current_iter=opt['name'], tb_logger=None, save_img=opt['val']['save_img']) + + +if __name__ == '__main__': + root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) + # print(root_path) + # image_sr(root_path) diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/utils/__init__.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4860d526119edc892ea348ae212ad4ed65cd0019 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/utils/__init__.py @@ -0,0 +1,30 @@ +from .file_client import FileClient +from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img +from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger +from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt + +__all__ = [ + # file_client.py + 'FileClient', + # img_util.py + 'img2tensor', + 'tensor2img', + 'imfrombytes', + 'imwrite', + 'crop_border', + # logger.py + 'MessageLogger', + 'AvgTimer', + 'init_tb_logger', + 'init_wandb_logger', + 'get_root_logger', + 'get_env_info', + # misc.py + 'set_random_seed', + 'get_time_str', + 'mkdir_and_rename', + 'make_exp_dirs', + 'scandir', + 'check_resume', + 'sizeof_fmt', +] diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/utils/dist_util.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/utils/dist_util.py new file mode 100644 index 0000000000000000000000000000000000000000..0fab887b2cb1ce8533d2e8fdee72ae0c24f68fd0 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/utils/dist_util.py @@ -0,0 +1,82 @@ +# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 +import functools +import os +import subprocess +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + + +def init_dist(launcher, backend='nccl', **kwargs): + if mp.get_start_method(allow_none=True) is None: + mp.set_start_method('spawn') + if launcher == 'pytorch': + _init_dist_pytorch(backend, **kwargs) + elif launcher == 'slurm': + _init_dist_slurm(backend, **kwargs) + else: + raise ValueError(f'Invalid launcher type: {launcher}') + + +def _init_dist_pytorch(backend, **kwargs): + rank = int(os.environ['RANK']) + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + dist.init_process_group(backend=backend, **kwargs) + + +def _init_dist_slurm(backend, port=None): + """Initialize slurm distributed training environment. + + If argument ``port`` is not specified, then the master port will be system + environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system + environment variable, then a default port ``29500`` will be used. + + Args: + backend (str): Backend of torch.distributed. + port (int, optional): Master port. Defaults to None. + """ + proc_id = int(os.environ['SLURM_PROCID']) + ntasks = int(os.environ['SLURM_NTASKS']) + node_list = os.environ['SLURM_NODELIST'] + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(proc_id % num_gpus) + addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1') + # specify master port + if port is not None: + os.environ['MASTER_PORT'] = str(port) + elif 'MASTER_PORT' in os.environ: + pass # use MASTER_PORT in the environment variable + else: + # 29500 is torch.distributed default port + os.environ['MASTER_PORT'] = '29500' + os.environ['MASTER_ADDR'] = addr + os.environ['WORLD_SIZE'] = str(ntasks) + os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) + os.environ['RANK'] = str(proc_id) + dist.init_process_group(backend=backend) + + +def get_dist_info(): + if dist.is_available(): + initialized = dist.is_initialized() + else: + initialized = False + if initialized: + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + rank = 0 + world_size = 1 + return rank, world_size + + +def master_only(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + rank, _ = get_dist_info() + if rank == 0: + return func(*args, **kwargs) + + return wrapper diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/utils/file_client.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/utils/file_client.py new file mode 100644 index 0000000000000000000000000000000000000000..89d83ab9e0d4314f8cdf2393908a561c6d1dca92 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/utils/file_client.py @@ -0,0 +1,167 @@ +# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501 +from abc import ABCMeta, abstractmethod + + +class BaseStorageBackend(metaclass=ABCMeta): + """Abstract class of storage backends. + + All backends need to implement two apis: ``get()`` and ``get_text()``. + ``get()`` reads the file as a byte stream and ``get_text()`` reads the file + as texts. + """ + + @abstractmethod + def get(self, filepath): + pass + + @abstractmethod + def get_text(self, filepath): + pass + + +class MemcachedBackend(BaseStorageBackend): + """Memcached storage backend. + + Attributes: + server_list_cfg (str): Config file for memcached server list. + client_cfg (str): Config file for memcached client. + sys_path (str | None): Additional path to be appended to `sys.path`. + Default: None. + """ + + def __init__(self, server_list_cfg, client_cfg, sys_path=None): + if sys_path is not None: + import sys + sys.path.append(sys_path) + try: + import mc + except ImportError: + raise ImportError('Please install memcached to enable MemcachedBackend.') + + self.server_list_cfg = server_list_cfg + self.client_cfg = client_cfg + self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg) + # mc.pyvector servers as a point which points to a memory cache + self._mc_buffer = mc.pyvector() + + def get(self, filepath): + filepath = str(filepath) + import mc + self._client.Get(filepath, self._mc_buffer) + value_buf = mc.ConvertBuffer(self._mc_buffer) + return value_buf + + def get_text(self, filepath): + raise NotImplementedError + + +class HardDiskBackend(BaseStorageBackend): + """Raw hard disks storage backend.""" + + def get(self, filepath): + filepath = str(filepath) + with open(filepath, 'rb') as f: + value_buf = f.read() + return value_buf + + def get_text(self, filepath): + filepath = str(filepath) + with open(filepath, 'r') as f: + value_buf = f.read() + return value_buf + + +class LmdbBackend(BaseStorageBackend): + """Lmdb storage backend. + + Args: + db_paths (str | list[str]): Lmdb database paths. + client_keys (str | list[str]): Lmdb client keys. Default: 'default'. + readonly (bool, optional): Lmdb environment parameter. If True, + disallow any write operations. Default: True. + lock (bool, optional): Lmdb environment parameter. If False, when + concurrent access occurs, do not lock the database. Default: False. + readahead (bool, optional): Lmdb environment parameter. If False, + disable the OS filesystem readahead mechanism, which may improve + random read performance when a database is larger than RAM. + Default: False. + + Attributes: + db_paths (list): Lmdb database path. + _client (list): A list of several lmdb envs. + """ + + def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs): + try: + import lmdb + except ImportError: + raise ImportError('Please install lmdb to enable LmdbBackend.') + + if isinstance(client_keys, str): + client_keys = [client_keys] + + if isinstance(db_paths, list): + self.db_paths = [str(v) for v in db_paths] + elif isinstance(db_paths, str): + self.db_paths = [str(db_paths)] + assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, ' + f'but received {len(client_keys)} and {len(self.db_paths)}.') + + self._client = {} + for client, path in zip(client_keys, self.db_paths): + self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs) + + def get(self, filepath, client_key): + """Get values according to the filepath from one lmdb named client_key. + + Args: + filepath (str | obj:`Path`): Here, filepath is the lmdb key. + client_key (str): Used for distinguishing different lmdb envs. + """ + filepath = str(filepath) + assert client_key in self._client, (f'client_key {client_key} is not in lmdb clients.') + client = self._client[client_key] + with client.begin(write=False) as txn: + value_buf = txn.get(filepath.encode('ascii')) + return value_buf + + def get_text(self, filepath): + raise NotImplementedError + + +class FileClient(object): + """A general file client to access files in different backend. + + The client loads a file or text in a specified backend from its path + and return it as a binary file. it can also register other backend + accessor with a given name and backend class. + + Attributes: + backend (str): The storage backend type. Options are "disk", + "memcached" and "lmdb". + client (:obj:`BaseStorageBackend`): The backend object. + """ + + _backends = { + 'disk': HardDiskBackend, + 'memcached': MemcachedBackend, + 'lmdb': LmdbBackend, + } + + def __init__(self, backend='disk', **kwargs): + if backend not in self._backends: + raise ValueError(f'Backend {backend} is not supported. Currently supported ones' + f' are {list(self._backends.keys())}') + self.backend = backend + self.client = self._backends[backend](**kwargs) + + def get(self, filepath, client_key='default'): + # client_key is used only for lmdb, where different fileclients have + # different lmdb environments. + if self.backend == 'lmdb': + return self.client.get(filepath, client_key) + else: + return self.client.get(filepath) + + def get_text(self, filepath): + return self.client.get_text(filepath) diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/utils/img_util.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/utils/img_util.py new file mode 100644 index 0000000000000000000000000000000000000000..3a5f1da0911d9b12f9c6164df6c6e14e3c1aef88 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/utils/img_util.py @@ -0,0 +1,172 @@ +import cv2 +import math +import numpy as np +import os +import torch +from torchvision.utils import make_grid + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. + """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + if img.dtype == 'float64': + img = img.astype('float32') + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + else: + return _totensor(imgs, bgr2rgb, float32) + + +def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): + """Convert torch Tensors into image numpy arrays. + + After clamping to [min, max], values will be normalized to [0, 1]. + + Args: + tensor (Tensor or list[Tensor]): Accept shapes: + 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); + 2) 3D Tensor of shape (3/1 x H x W); + 3) 2D Tensor of shape (H x W). + Tensor channel should be in RGB order. + rgb2bgr (bool): Whether to change rgb to bgr. + out_type (numpy type): output types. If ``np.uint8``, transform outputs + to uint8 type with range [0, 255]; otherwise, float type with + range [0, 1]. Default: ``np.uint8``. + min_max (tuple[int]): min and max values for clamp. + + Returns: + (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of + shape (H x W). The channel order is BGR. + """ + if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): + raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}') + + if torch.is_tensor(tensor): + tensor = [tensor] + result = [] + for _tensor in tensor: + _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) + _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) + + n_dim = _tensor.dim() + if n_dim == 4: + img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy() + img_np = img_np.transpose(1, 2, 0) + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 3: + img_np = _tensor.numpy() + img_np = img_np.transpose(1, 2, 0) + if img_np.shape[2] == 1: # gray image + img_np = np.squeeze(img_np, axis=2) + else: + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 2: + img_np = _tensor.numpy() + else: + raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}') + if out_type == np.uint8: + # Unlike MATLAB, numpy.unit8() WILL NOT round by default. + img_np = (img_np * 255.0).round() + img_np = img_np.astype(out_type) + result.append(img_np) + if len(result) == 1: + result = result[0] + return result + + +def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)): + """This implementation is slightly faster than tensor2img. + It now only supports torch tensor with shape (1, c, h, w). + + Args: + tensor (Tensor): Now only support torch tensor with (1, c, h, w). + rgb2bgr (bool): Whether to change rgb to bgr. Default: True. + min_max (tuple[int]): min and max values for clamp. + """ + output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0) + output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255 + output = output.type(torch.uint8).cpu().numpy() + if rgb2bgr: + output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + return output + + +def imfrombytes(content, flag='color', float32=False): + """Read an image from bytes. + + Args: + content (bytes): Image bytes got from files or other streams. + flag (str): Flags specifying the color type of a loaded image, + candidates are `color`, `grayscale` and `unchanged`. + float32 (bool): Whether to change to float32., If True, will also norm + to [0, 1]. Default: False. + + Returns: + ndarray: Loaded image array. + """ + img_np = np.frombuffer(content, np.uint8) + imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED} + img = cv2.imdecode(img_np, imread_flags[flag]) + if float32: + img = img.astype(np.float32) / 255. + return img + + +def imwrite(img, file_path, params=None, auto_mkdir=True): + """Write image to file. + + Args: + img (ndarray): Image array to be written. + file_path (str): Image file path. + params (None or list): Same as opencv's :func:`imwrite` interface. + auto_mkdir (bool): If the parent folder of `file_path` does not exist, + whether to create it automatically. + + Returns: + bool: Successful or not. + """ + if auto_mkdir: + dir_name = os.path.abspath(os.path.dirname(file_path)) + os.makedirs(dir_name, exist_ok=True) + ok = cv2.imwrite(file_path, img, params) + if not ok: + raise IOError('Failed in writing images.') + + +def crop_border(imgs, crop_border): + """Crop borders of images. + + Args: + imgs (list[ndarray] | ndarray): Images with shape (h, w, c). + crop_border (int): Crop border for each end of height and weight. + + Returns: + list[ndarray]: Cropped images. + """ + if crop_border == 0: + return imgs + else: + if isinstance(imgs, list): + return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs] + else: + return imgs[crop_border:-crop_border, crop_border:-crop_border, ...] diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/utils/logger.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..73553dc664781a061737e94880ea1c6788c09043 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/utils/logger.py @@ -0,0 +1,213 @@ +import datetime +import logging +import time + +from .dist_util import get_dist_info, master_only + +initialized_logger = {} + + +class AvgTimer(): + + def __init__(self, window=200): + self.window = window # average window + self.current_time = 0 + self.total_time = 0 + self.count = 0 + self.avg_time = 0 + self.start() + + def start(self): + self.start_time = self.tic = time.time() + + def record(self): + self.count += 1 + self.toc = time.time() + self.current_time = self.toc - self.tic + self.total_time += self.current_time + # calculate average time + self.avg_time = self.total_time / self.count + + # reset + if self.count > self.window: + self.count = 0 + self.total_time = 0 + + self.tic = time.time() + + def get_current_time(self): + return self.current_time + + def get_avg_time(self): + return self.avg_time + + +class MessageLogger(): + """Message logger for printing. + + Args: + opt (dict): Config. It contains the following keys: + name (str): Exp name. + logger (dict): Contains 'print_freq' (str) for logger interval. + train (dict): Contains 'total_iter' (int) for total iters. + use_tb_logger (bool): Use tensorboard logger. + start_iter (int): Start iter. Default: 1. + tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None. + """ + + def __init__(self, opt, start_iter=1, tb_logger=None): + self.exp_name = opt['name'] + self.interval = opt['logger']['print_freq'] + self.start_iter = start_iter + self.max_iters = opt['train']['total_iter'] + self.use_tb_logger = opt['logger']['use_tb_logger'] + self.tb_logger = tb_logger + self.start_time = time.time() + self.logger = get_root_logger() + + def reset_start_time(self): + self.start_time = time.time() + + @master_only + def __call__(self, log_vars): + """Format logging message. + + Args: + log_vars (dict): It contains the following keys: + epoch (int): Epoch number. + iter (int): Current iter. + lrs (list): List for learning rates. + + time (float): Iter time. + data_time (float): Data time for each iter. + """ + # epoch, iter, learning rates + epoch = log_vars.pop('epoch') + current_iter = log_vars.pop('iter') + lrs = log_vars.pop('lrs') + + message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, iter:{current_iter:8,d}, lr:(') + for v in lrs: + message += f'{v:.3e},' + message += ')] ' + + # time and estimated time + if 'time' in log_vars.keys(): + iter_time = log_vars.pop('time') + data_time = log_vars.pop('data_time') + + total_time = time.time() - self.start_time + time_sec_avg = total_time / (current_iter - self.start_iter + 1) + eta_sec = time_sec_avg * (self.max_iters - current_iter - 1) + eta_str = str(datetime.timedelta(seconds=int(eta_sec))) + message += f'[eta: {eta_str}, ' + message += f'time (data): {iter_time:.3f} ({data_time:.3f})] ' + + # other items, especially losses + for k, v in log_vars.items(): + message += f'{k}: {v:.4e} ' + # tensorboard logger + if self.use_tb_logger and 'debug' not in self.exp_name: + if k.startswith('l_'): + self.tb_logger.add_scalar(f'losses/{k}', v, current_iter) + else: + self.tb_logger.add_scalar(k, v, current_iter) + self.logger.info(message) + + +@master_only +def init_tb_logger(log_dir): + from torch.utils.tensorboard import SummaryWriter + tb_logger = SummaryWriter(log_dir=log_dir) + return tb_logger + + +@master_only +def init_wandb_logger(opt): + """We now only use wandb to sync tensorboard log.""" + import wandb + logger = get_root_logger() + + project = opt['logger']['wandb']['project'] + resume_id = opt['logger']['wandb'].get('resume_id') + if resume_id: + wandb_id = resume_id + resume = 'allow' + logger.warning(f'Resume wandb logger with id={wandb_id}.') + else: + wandb_id = wandb.util.generate_id() + resume = 'never' + + wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True) + + logger.info(f'Use wandb logger with id={wandb_id}; project={project}.') + + +def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None): + """Get the root logger. + + The logger will be initialized if it has not been initialized. By default a + StreamHandler will be added. If `log_file` is specified, a FileHandler will + also be added. + + Args: + logger_name (str): root logger name. Default: 'basicsr'. + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the root logger. + log_level (int): The root logger level. Note that only the process of + rank 0 is affected, while other processes will set the level to + "Error" and be silent most of the time. + + Returns: + logging.Logger: The root logger. + """ + logger = logging.getLogger(logger_name) + # if the logger has been initialized, just return it + if logger_name in initialized_logger: + return logger + + format_str = '%(asctime)s %(levelname)s: %(message)s' + stream_handler = logging.StreamHandler() + stream_handler.setFormatter(logging.Formatter(format_str)) + logger.addHandler(stream_handler) + logger.propagate = False + rank, _ = get_dist_info() + if rank != 0: + logger.setLevel('ERROR') + elif log_file is not None: + logger.setLevel(log_level) + # add file handler + file_handler = logging.FileHandler(log_file, 'w') + file_handler.setFormatter(logging.Formatter(format_str)) + file_handler.setLevel(log_level) + logger.addHandler(file_handler) + initialized_logger[logger_name] = True + return logger + + +def get_env_info(): + """Get environment information. + + Currently, only log the software version. + """ + import torch + import torchvision + + from basicsr.version import __version__ + msg = r""" + ____ _ _____ ____ + / __ ) ____ _ _____ (_)_____/ ___/ / __ \ + / __ |/ __ `// ___// // ___/\__ \ / /_/ / + / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/ + /_____/ \__,_//____//_/ \___//____//_/ |_| + ______ __ __ __ __ + / ____/____ ____ ____/ / / / __ __ _____ / /__ / / + / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / / + / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/ + \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_) + """ + msg += ('\nVersion Information: ' + f'\n\tBasicSR: {__version__}' + f'\n\tPyTorch: {torch.__version__}' + f'\n\tTorchVision: {torchvision.__version__}') + return msg diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/utils/matlab_functions.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/utils/matlab_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..f9f1a83bc8beee468dd7c9ca734966e926fd9fde --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/utils/matlab_functions.py @@ -0,0 +1,359 @@ +import math +import numpy as np +import torch + + +def cubic(x): + """cubic function used for calculate_weights_indices.""" + absx = torch.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5 * absx3 - 2.5 * absx2 + 1) * ( + (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) * + (absx <= 2)).type_as(absx)) + + +def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): + """Calculate weights and indices, used for imresize function. + + Args: + in_length (int): Input length. + out_length (int): Output length. + scale (float): Scale factor. + kernel_width (int): Kernel width. + antialisaing (bool): Whether to apply anti-aliasing when downsampling. + """ + + if (scale < 1) and antialiasing: + # Use a modified kernel (larger kernel width) to simultaneously + # interpolate and antialias + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5 + scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + p = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand( + out_length, p) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices + + # apply cubic kernel + if (scale < 1) and antialiasing: + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, p) + + # If a column in weights is all zero, get rid of it. only consider the + # first and last column. + weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, p - 2) + weights = weights.narrow(1, 1, p - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, p - 2) + weights = weights.narrow(1, 0, p - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + + +@torch.no_grad() +def imresize(img, scale, antialiasing=True): + """imresize function same as MATLAB. + + It now only supports bicubic. + The same scale applies for both height and width. + + Args: + img (Tensor | Numpy array): + Tensor: Input image with shape (c, h, w), [0, 1] range. + Numpy: Input image with shape (h, w, c), [0, 1] range. + scale (float): Scale factor. The same scale applies for both height + and width. + antialisaing (bool): Whether to apply anti-aliasing when downsampling. + Default: True. + + Returns: + Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round. + """ + squeeze_flag = False + if type(img).__module__ == np.__name__: # numpy type + numpy_type = True + if img.ndim == 2: + img = img[:, :, None] + squeeze_flag = True + img = torch.from_numpy(img.transpose(2, 0, 1)).float() + else: + numpy_type = False + if img.ndim == 2: + img = img.unsqueeze(0) + squeeze_flag = True + + in_c, in_h, in_w = img.size() + out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale) + kernel_width = 4 + kernel = 'cubic' + + # get weights and indices + weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width, + antialiasing) + weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width, + antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w) + img_aug.narrow(1, sym_len_hs, in_h).copy_(img) + + sym_patch = img[:, :sym_len_hs, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv) + + sym_patch = img[:, -sym_len_he:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(in_c, out_h, in_w) + kernel_width = weights_h.size(1) + for i in range(out_h): + idx = int(indices_h[i][0]) + for j in range(in_c): + out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we) + out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1) + + sym_patch = out_1[:, :, :sym_len_ws] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, :, -sym_len_we:] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(in_c, out_h, out_w) + kernel_width = weights_w.size(1) + for i in range(out_w): + idx = int(indices_w[i][0]) + for j in range(in_c): + out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i]) + + if squeeze_flag: + out_2 = out_2.squeeze(0) + if numpy_type: + out_2 = out_2.numpy() + if not squeeze_flag: + out_2 = out_2.transpose(1, 2, 0) + + return out_2 + + +def rgb2ycbcr(img, y_only=False): + """Convert a RGB image to YCbCr image. + + This function produces the same results as Matlab's `rgb2ycbcr` function. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0 + else: + out_img = np.matmul( + img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def bgr2ycbcr(img, y_only=False): + """Convert a BGR image to YCbCr image. + + The bgr version of rgb2ycbcr. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0 + else: + out_img = np.matmul( + img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def ycbcr2rgb(img): + """Convert a YCbCr image to RGB image. + + This function produces the same results as Matlab's ycbcr2rgb function. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + ndarray: The converted RGB image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) * 255 + out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] # noqa: E126 + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def ycbcr2bgr(img): + """Convert a YCbCr image to BGR image. + + The bgr version of ycbcr2rgb. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + ndarray: The converted BGR image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) * 255 + out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0], + [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921] # noqa: E126 + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def _convert_input_type_range(img): + """Convert the type and range of the input image. + + It converts the input image to np.float32 type and range of [0, 1]. + It is mainly used for pre-processing the input image in colorspace + conversion functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + (ndarray): The converted image with type of np.float32 and range of + [0, 1]. + """ + img_type = img.dtype + img = img.astype(np.float32) + if img_type == np.float32: + pass + elif img_type == np.uint8: + img /= 255. + else: + raise TypeError(f'The img type should be np.float32 or np.uint8, but got {img_type}') + return img + + +def _convert_output_type_range(img, dst_type): + """Convert the type and range of the image according to dst_type. + + It converts the image to desired type and range. If `dst_type` is np.uint8, + images will be converted to np.uint8 type with range [0, 255]. If + `dst_type` is np.float32, it converts the image to np.float32 type with + range [0, 1]. + It is mainly used for post-processing images in colorspace conversion + functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The image to be converted with np.float32 type and + range [0, 255]. + dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it + converts the image to np.uint8 type with range [0, 255]. If + dst_type is np.float32, it converts the image to np.float32 type + with range [0, 1]. + + Returns: + (ndarray): The converted image with desired type and range. + """ + if dst_type not in (np.uint8, np.float32): + raise TypeError(f'The dst_type should be np.float32 or np.uint8, but got {dst_type}') + if dst_type == np.uint8: + img = img.round() + else: + img /= 255. + return img.astype(dst_type) diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/utils/misc.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..728fef857d0071875c82ffcbc8c74b6fbe029e22 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/utils/misc.py @@ -0,0 +1,141 @@ +import numpy as np +import os +import random +import time +import torch +from os import path as osp + +from .dist_util import master_only + + +def set_random_seed(seed): + """Set random seeds.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_time_str(): + return time.strftime('%Y%m%d_%H%M%S', time.localtime()) + + +def mkdir_and_rename(path): + """mkdirs. If path exists, rename it with timestamp and create a new one. + + Args: + path (str): Folder path. + """ + if osp.exists(path): + new_name = path + '_archived_' + get_time_str() + print(f'Path already exists. Rename it to {new_name}', flush=True) + os.rename(path, new_name) + os.makedirs(path, exist_ok=True) + + +@master_only +def make_exp_dirs(opt): + """Make dirs for experiments.""" + path_opt = opt['path'].copy() + if opt['is_train']: + mkdir_and_rename(path_opt.pop('experiments_root')) + else: + mkdir_and_rename(path_opt.pop('results_root')) + for key, path in path_opt.items(): + if ('strict_load' in key) or ('pretrain_network' in key) or ('resume' in key) or ('param_key' in key): + continue + else: + os.makedirs(path, exist_ok=True) + + +def scandir(dir_path, suffix=None, recursive=False, full_path=False): + """Scan a directory to find the interested files. + + Args: + dir_path (str): Path of the directory. + suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + full_path (bool, optional): If set to True, include the dir_path. + Default: False. + + Returns: + A generator for all the interested files with relative paths. + """ + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') + + root = dir_path + + def _scandir(dir_path, suffix, recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + if full_path: + return_path = entry.path + else: + return_path = osp.relpath(entry.path, root) + + if suffix is None: + yield return_path + elif return_path.endswith(suffix): + yield return_path + else: + if recursive: + yield from _scandir(entry.path, suffix=suffix, recursive=recursive) + else: + continue + + return _scandir(dir_path, suffix=suffix, recursive=recursive) + + +def check_resume(opt, resume_iter): + """Check resume states and pretrain_network paths. + + Args: + opt (dict): Options. + resume_iter (int): Resume iteration. + """ + if opt['path']['resume_state']: + # get all the networks + networks = [key for key in opt.keys() if key.startswith('network_')] + flag_pretrain = False + for network in networks: + if opt['path'].get(f'pretrain_{network}') is not None: + flag_pretrain = True + if flag_pretrain: + print('pretrain_network path will be ignored during resuming.') + # set pretrained model paths + for network in networks: + name = f'pretrain_{network}' + basename = network.replace('network_', '') + if opt['path'].get('ignore_resume_networks') is None or (network + not in opt['path']['ignore_resume_networks']): + opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth') + print(f"Set {name} to {opt['path'][name]}") + + # change param_key to params in resume + param_keys = [key for key in opt['path'].keys() if key.startswith('param_key')] + for param_key in param_keys: + if opt['path'][param_key] == 'params_ema': + opt['path'][param_key] = 'params' + print(f'Set {param_key} to params') + + +def sizeof_fmt(size, suffix='B'): + """Get human readable file size. + + Args: + size (int): File size. + suffix (str): Suffix. Default: 'B'. + + Return: + str: Formatted file siz. + """ + for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: + if abs(size) < 1024.0: + return f'{size:3.1f} {unit}{suffix}' + size /= 1024.0 + return f'{size:3.1f} Y{suffix}' diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/utils/options.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/utils/options.py new file mode 100644 index 0000000000000000000000000000000000000000..1925644dd62fc7b6b1e47bf2641f9d11251f3142 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/utils/options.py @@ -0,0 +1,200 @@ +import argparse +import random +import torch +import yaml +from collections import OrderedDict +from os import path as osp + +from basicsr.utils import set_random_seed +from basicsr.utils.dist_util import get_dist_info, init_dist, master_only + + +def ordered_yaml(): + """Support OrderedDict for yaml. + + Returns: + yaml Loader and Dumper. + """ + try: + from yaml import CDumper as Dumper + from yaml import CLoader as Loader + except ImportError: + from yaml import Dumper, Loader + + _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG + + def dict_representer(dumper, data): + return dumper.represent_dict(data.items()) + + def dict_constructor(loader, node): + return OrderedDict(loader.construct_pairs(node)) + + Dumper.add_representer(OrderedDict, dict_representer) + Loader.add_constructor(_mapping_tag, dict_constructor) + return Loader, Dumper + + +def dict2str(opt, indent_level=1): + """dict to string for printing options. + + Args: + opt (dict): Option dict. + indent_level (int): Indent level. Default: 1. + + Return: + (str): Option string for printing. + """ + msg = '\n' + for k, v in opt.items(): + if isinstance(v, dict): + msg += ' ' * (indent_level * 2) + k + ':[' + msg += dict2str(v, indent_level + 1) + msg += ' ' * (indent_level * 2) + ']\n' + else: + msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' + return msg + + +def _postprocess_yml_value(value): + # None + if value == '~' or value.lower() == 'none': + return None + # bool + if value.lower() == 'true': + return True + elif value.lower() == 'false': + return False + # !!float number + if value.startswith('!!float'): + return float(value.replace('!!float', '')) + # number + if value.isdigit(): + return int(value) + elif value.replace('.', '', 1).isdigit() and value.count('.') < 2: + return float(value) + # list + if value.startswith('['): + return eval(value) + # str + return value + + +def parse_options(root_path, SR, is_train=True): + parser = argparse.ArgumentParser() + # parser.add_argument('-opt', type=str, default = 'options/test/test_RGT_S_x2.yml',required=True, help='Path to option YAML file.') + if SR == 'x4': + file_path = osp.join(root_path,'options/test/test_RGT_x4.yml') + if SR == 'x2': + file_path = osp.join(root_path,'options/test/test_RGT_x2.yml') + parser.add_argument('-opt', type=str, default = file_path, help='Path to option YAML file.') + parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher') + parser.add_argument('--auto_resume', action='store_true') + parser.add_argument('--debug', action='store_true') + parser.add_argument('--local_rank', type=int, default=0) + parser.add_argument( + '--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999') + args = parser.parse_args() + + # parse yml to dict + with open(args.opt, mode='r') as f: + # print(args.opt) + opt = yaml.load(f, Loader=ordered_yaml()[0]) + + # distributed settings + if args.launcher == 'none': + opt['dist'] = False + print('Disable distributed.', flush=True) + else: + opt['dist'] = True + if args.launcher == 'slurm' and 'dist_params' in opt: + init_dist(args.launcher, **opt['dist_params']) + else: + init_dist(args.launcher) + opt['rank'], opt['world_size'] = get_dist_info() + + # random seed + seed = opt.get('manual_seed') + if seed is None: + seed = random.randint(1, 10000) + opt['manual_seed'] = seed + set_random_seed(seed + opt['rank']) + + # force to update yml options + if args.force_yml is not None: + for entry in args.force_yml: + # now do not support creating new keys + keys, value = entry.split('=') + keys, value = keys.strip(), value.strip() + value = _postprocess_yml_value(value) + eval_str = 'opt' + for key in keys.split(':'): + eval_str += f'["{key}"]' + eval_str += '=value' + # using exec function + exec(eval_str) + + opt['auto_resume'] = args.auto_resume + opt['is_train'] = is_train + + # debug setting + if args.debug and not opt['name'].startswith('debug'): + opt['name'] = 'debug_' + opt['name'] + + if opt['num_gpu'] == 'auto': + opt['num_gpu'] = torch.cuda.device_count() + + # datasets + for phase, dataset in opt['datasets'].items(): + # for multiple datasets, e.g., val_1, val_2; test_1, test_2 + phase = phase.split('_')[0] + dataset['phase'] = phase + if 'scale' in opt: + dataset['scale'] = opt['scale'] + if dataset.get('dataroot_gt') is not None: + dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt']) + if dataset.get('dataroot_lq') is not None: + dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq']) + + # paths + for key, val in opt['path'].items(): + if (val is not None) and ('resume_state' in key or 'pretrain_network' in key): + opt['path'][key] = osp.expanduser(val) + + if is_train: + experiments_root = osp.join(root_path, 'experiments', opt['name']) + opt['path']['experiments_root'] = experiments_root + opt['path']['models'] = osp.join(experiments_root, 'models') + opt['path']['training_states'] = osp.join(experiments_root, 'training_states') + opt['path']['log'] = experiments_root + opt['path']['visualization'] = osp.join(experiments_root, 'visualization') + + # change some options for debug mode + if 'debug' in opt['name']: + if 'val' in opt: + opt['val']['val_freq'] = 8 + opt['logger']['print_freq'] = 1 + opt['logger']['save_checkpoint_freq'] = 8 + else: # test + results_root = osp.join(root_path, 'results', opt['name']) + opt['path']['results_root'] = results_root + opt['path']['log'] = results_root + opt['path']['visualization'] = osp.join(results_root, 'visualization') + + return opt, args + + +@master_only +def copy_opt_file(opt_file, experiments_root): + # copy the yml file to the experiment root + import sys + import time + from shutil import copyfile + cmd = ' '.join(sys.argv) + filename = osp.join(experiments_root, osp.basename(opt_file)) + copyfile(opt_file, filename) + + with open(filename, 'r+') as f: + lines = f.readlines() + lines.insert(0, f'# GENERATE TIME: {time.asctime()}\n# CMD:\n# {cmd}\n\n') + f.seek(0) + f.writelines(lines) diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/utils/registry.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/utils/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..655753b3b9cbd0cfe73fe93a77cf1fcc3db6d827 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/basicsr/utils/registry.py @@ -0,0 +1,82 @@ +# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501 + + +class Registry(): + """ + The registry that provides name -> object mapping, to support third-party + users' custom modules. + + To create a registry (e.g. a backbone registry): + + .. code-block:: python + + BACKBONE_REGISTRY = Registry('BACKBONE') + + To register an object: + + .. code-block:: python + + @BACKBONE_REGISTRY.register() + class MyBackbone(): + ... + + Or: + + .. code-block:: python + + BACKBONE_REGISTRY.register(MyBackbone) + """ + + def __init__(self, name): + """ + Args: + name (str): the name of this registry + """ + self._name = name + self._obj_map = {} + + def _do_register(self, name, obj): + assert (name not in self._obj_map), (f"An object named '{name}' was already registered " + f"in '{self._name}' registry!") + self._obj_map[name] = obj + + def register(self, obj=None): + """ + Register the given object under the the name `obj.__name__`. + Can be used as either a decorator or not. + See docstring of this class for usage. + """ + if obj is None: + # used as a decorator + def deco(func_or_class): + name = func_or_class.__name__ + self._do_register(name, func_or_class) + return func_or_class + + return deco + + # used as a function call + name = obj.__name__ + self._do_register(name, obj) + + def get(self, name): + ret = self._obj_map.get(name) + if ret is None: + raise KeyError(f"No object named '{name}' found in '{self._name}' registry!") + return ret + + def __contains__(self, name): + return name in self._obj_map + + def __iter__(self): + return iter(self._obj_map.items()) + + def keys(self): + return self._obj_map.keys() + + +DATASET_REGISTRY = Registry('dataset') +ARCH_REGISTRY = Registry('arch') +MODEL_REGISTRY = Registry('model') +LOSS_REGISTRY = Registry('loss') +METRIC_REGISTRY = Registry('metric') diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/options/__init__.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/options/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/options/test/__init__.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/options/test/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/options/test/test_RGT_x2.yml b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/options/test/test_RGT_x2.yml new file mode 100644 index 0000000000000000000000000000000000000000..4f88a04d696328ce21641fc733900cd2bb4bf262 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/options/test/test_RGT_x2.yml @@ -0,0 +1,94 @@ +# general settings +name: test_RGT_x2 +model_type: RGTModel +scale: 2 +num_gpu: 1 +manual_seed: 10 + +datasets: + test_1: # the 1st test dataset + task: SR + name: Set5 + type: PairedImageDataset + dataroot_gt: datasets/benchmark/Set5/HR + dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X2 + filename_tmpl: '{}x2' + io_backend: + type: disk + + # test_2: # the 2st test dataset + # task: SR + # name: Set14 + # type: PairedImageDataset + # dataroot_gt: datasets/benchmark/Set14/HR + # dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X2 + # filename_tmpl: '{}x2' + # io_backend: + # type: disk + + # test_3: # the 3st test dataset + # task: SR + # name: B100 + # type: PairedImageDataset + # dataroot_gt: datasets/benchmark/B100/HR + # dataroot_lq: datasets/benchmark/B100/LR_bicubic/X2 + # filename_tmpl: '{}x2' + # io_backend: + # type: disk + + # test_4: # the 4st test dataset + # task: SR + # name: Urban100 + # type: PairedImageDataset + # dataroot_gt: datasets/benchmark/Urban100/HR + # dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X2 + # filename_tmpl: '{}x2' + # io_backend: + # type: disk + + # test_5: # the 5st test dataset + # task: SR + # name: Manga109 + # type: PairedImageDataset + # dataroot_gt: datasets/benchmark/Manga109/HR + # dataroot_lq: datasets/benchmark/Manga109/LR_bicubic/X2 + # filename_tmpl: '{}_LRBI_x2' + # io_backend: + # type: disk + + +# network structures +network_g: + type: RGT + upscale: 2 + in_chans: 3 + img_size: 64 + img_range: 1. + depth: [6,6,6,6,6,6,6,6] + embed_dim: 180 + num_heads: [6,6,6,6,6,6,6,6] + mlp_ratio: 2 + resi_connection: '1conv' + split_size: [8,32] + c_ratio: 0.5 + +# path +path: + pretrain_network_g: /remote-home/lzy/RGT/experiments/pretrained_models/RGT_x2.pth + strict_load_g: True + +# validation settings +val: + save_img: True + suffix: ~ # add suffix to saved images, if None, use exp name + use_chop: False # True to save memory, if img too large + + metrics: + psnr: # metric name, can be arbitrary + type: calculate_psnr + crop_border: 2 + test_y_channel: True + ssim: + type: calculate_ssim + crop_border: 2 + test_y_channel: True \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/options/test/test_RGT_x4.yml b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/options/test/test_RGT_x4.yml new file mode 100644 index 0000000000000000000000000000000000000000..a776fa5f82b1b5f8c57241c060219c1489a14a2f --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/options/test/test_RGT_x4.yml @@ -0,0 +1,94 @@ +# general settings +name: test_RGT_x4 +model_type: RGTModel +scale: 4 +num_gpu: 1 +manual_seed: 10 + +datasets: + test_1: # the 1st test dataset + task: SR + name: Set5 + type: PairedImageDataset + dataroot_gt: datasets/benchmark/Set5/HR + dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4 + filename_tmpl: '{}x4' + io_backend: + type: disk + + # test_2: # the 2st test dataset + # task: SR + # name: Set14 + # type: PairedImageDataset + # dataroot_gt: datasets/benchmark/Set14/HR + # dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X4 + # filename_tmpl: '{}x4' + # io_backend: + # type: disk + + # test_3: # the 3st test dataset + # task: SR + # name: B100 + # type: PairedImageDataset + # dataroot_gt: datasets/benchmark/B100/HR + # dataroot_lq: datasets/benchmark/B100/LR_bicubic/X4 + # filename_tmpl: '{}x4' + # io_backend: + # type: disk + + # test_4: # the 4st test dataset + # task: SR + # name: Urban100 + # type: PairedImageDataset + # dataroot_gt: datasets/benchmark/Urban100/HR + # dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X4 + # filename_tmpl: '{}x4' + # io_backend: + # type: disk + + # test_5: # the 5st test dataset + # task: SR + # name: Manga109 + # type: PairedImageDataset + # dataroot_gt: datasets/benchmark/Manga109/HR + # dataroot_lq: datasets/benchmark/Manga109/LR_bicubic/X4 + # filename_tmpl: '{}_LRBI_x4' + # io_backend: + # type: disk + + +# network structures +network_g: + type: RGT + upscale: 4 + in_chans: 3 + img_size: 64 + img_range: 1. + depth: [6,6,6,6,6,6,6,6] + embed_dim: 180 + num_heads: [6,6,6,6,6,6,6,6] + mlp_ratio: 2 + resi_connection: '1conv' + split_size: [8,32] + c_ratio: 0.5 + +# path +path: + pretrain_network_g: /remote-home/lzy/RGT/experiments/pretrained_models/RGT_x4.pth + strict_load_g: True + +# validation settings +val: + save_img: True + suffix: ~ # add suffix to saved images, if None, use exp name + use_chop: False # True to save memory, if img too large + + metrics: + psnr: # metric name, can be arbitrary + type: calculate_psnr + crop_border: 4 + test_y_channel: True + ssim: + type: calculate_ssim + crop_border: 4 + test_y_channel: True \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/options/test/test_single_config.yml b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/options/test/test_single_config.yml new file mode 100644 index 0000000000000000000000000000000000000000..58a0753db579a1ebc6b5e6382c6e091e098880eb --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/options/test/test_single_config.yml @@ -0,0 +1,41 @@ +# general settings +name: test_single +model_type: RGTModel +scale: 2 +num_gpu: 1 +manual_seed: 10 + +datasets: + test_1: # the 1st test dataset + name: Single + type: SingleImageDataset + dataroot_lq: /test + io_backend: + type: disk + + +# network structures +network_g: + type: RGT + upscale: 2 + in_chans: 3 + img_size: 64 + img_range: 1. + depth: [6,6,6,6,6,6,6,6] + embed_dim: 180 + num_heads: [6,6,6,6,6,6,6,6] + mlp_ratio: 2 + resi_connection: '1conv' + split_size: [8,32] + c_ratio: 0.5 + +# path +path: + pretrain_network_g: /test + strict_load_g: True + +# validation settings +val: + save_img: True + suffix: ~ # add suffix to saved images, if None, use exp name + use_chop: False # True to save memory, if img too large \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/run.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/run.py new file mode 100644 index 0000000000000000000000000000000000000000..ac8e2bfcb38196c2262a96d270282d547f95f488 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/super_resolution/run.py @@ -0,0 +1,138 @@ +import cv2 +import argparse +from basicsr.test_img import image_sr +from os import path as osp +import os +import shutil +from PIL import Image +import re +import imageio.v2 as imageio +import threading +from concurrent.futures import ThreadPoolExecutor +import time + +def replace_filename(original_path, suffix): + + directory = os.path.dirname(original_path) + old_filename = os.path.basename(original_path) + name_part, file_extension = os.path.splitext(old_filename) + new_filename = f"{name_part}{suffix}{file_extension}" + new_path = os.path.join(directory, new_filename) + + return new_path + +def create_temp_folder(folder_path): + + if os.path.exists(folder_path): + shutil.rmtree(folder_path) + os.makedirs(folder_path) + +def delete_temp_folder(folder_path): + shutil.rmtree(folder_path) + +def extract_number(filename): + s = re.findall(r'\d+', filename) + return int(s[0]) if s else -1 + +def bicubic_upsample_opencv(input_image_path, output_image_path, scale_factor): + + img = cv2.imread(input_image_path) + + original_height, original_width = img.shape[:2] + + new_width = int(original_width * scale_factor) + new_height = int(original_height * scale_factor) + + upsampled_img = cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_CUBIC) + cv2.imwrite(output_image_path, upsampled_img) + + +def process_frame(frame_count, frame, temp_LR_folder_path, temp_HR_folder_path, SR): + frame_path = os.path.join(temp_LR_folder_path, f"frame_{frame_count}{SR}.png") + cv2.imwrite(frame_path, frame) + HR_frame_path = os.path.join(temp_HR_folder_path, f"frame_{frame_count}.png") + + if SR == 'x4': + bicubic_upsample_opencv(frame_path, HR_frame_path, 4) + elif SR == 'x2': + bicubic_upsample_opencv(frame_path, HR_frame_path, 2) + +def video_sr(args): + file_name = os.path.basename(args.input_dir) + video_output_path = os.path.join(args.output_dir,file_name) + + if args.SR == 'x4': + temp_LR_folder_path = os.path.join(args.output_dir, f'temp_LR/X4') + video_output_path = replace_filename(video_output_path, '_x4') + result_temp = osp.join(args.root_path, f'results/test_RGT_x4/visualization/Set5') + if args.SR == 'x2': + temp_LR_folder_path = os.path.join(args.output_dir, f'temp_LR/X2') + video_output_path = replace_filename(video_output_path, '_x2') + result_temp = osp.join(args.root_path, f'results/test_RGT_x2/visualization/Set5') + + temp_HR_folder_path = os.path.join(args.output_dir, f'temp_HR') + + # create_temp_folder(result_temp) + create_temp_folder(temp_LR_folder_path) + create_temp_folder(temp_HR_folder_path) + + cap = cv2.VideoCapture(args.input_dir) + if not cap.isOpened(): + print("Error opening video file.") + return + + t1 = time.time() + frame_count = 0 + frames_to_process = [] + while cap.isOpened(): + ret, frame = cap.read() + if not ret: + break + frames_to_process.append((frame_count, frame)) + frame_count += 1 + + with ThreadPoolExecutor(max_workers = args.mul_numwork) as executor: + for frame_count, frame in frames_to_process: + executor.submit(process_frame, frame_count, frame, temp_LR_folder_path, temp_HR_folder_path, args.SR) + + print("total frames:",frame_count) + print("fps :",cap.get(cv2.CAP_PROP_FPS)) + + t2 = time.time() + print('mul threads: ',t2 - t1,'s') + # progress all frames in video + image_sr(args) + + t3 = time.time() + print('image super resolution: ',t3 - t2,'s') + # recover video form all frames + frame_files = sorted(os.listdir(result_temp), key=extract_number) + video_frames = [imageio.imread(os.path.join(result_temp, frame_file)) for frame_file in frame_files] + fps = cap.get(cv2.CAP_PROP_FPS) + imageio.mimwrite(video_output_path, video_frames, fps=fps, quality=9) + + t4 = time.time() + print('tranformer frames to video: ',t4 - t3,'s') + # release all resources + cap.release() + delete_temp_folder(os.path.dirname(temp_LR_folder_path)) + delete_temp_folder(temp_HR_folder_path) + delete_temp_folder(os.path.join(args.root_path, f'results')) + + t5 = time.time() + print('delete time: ',t5 - t4,'s') + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="RGT for Video Super-Resolution") + # make sure you SR is match with the ckpt_path + parser.add_argument("--SR", type=str, choices=['x2', 'x4'], default='x4', help='image resolution') + parser.add_argument("--ckpt_path", type=str, default = "/remote-home/lzy/RGT/experiments/pretrained_models/RGT_x4.pth") + + parser.add_argument("--root_path", type=str, default = "/remote-home/lzy/RGT") + parser.add_argument("--input_dir", type=str, default= "/remote-home/lzy/RGT/datasets/video/video_test1.mp4") + parser.add_argument("--output_dir", type=str, default= "/remote-home/lzy/RGT/datasets/video_output") + + parser.add_argument("--mul_numwork", type=int, default = 16, help ='max_workers to execute Multi') + parser.add_argument("--use_chop", type= bool, default = True, help ='use_chop: True # True to save memory, if img too large') + args = parser.parse_args() + video_sr(args) diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/text_encoder/__init__.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/text_encoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7169e00423ff8ca4a36bc31a4dea08043f2fc9e6 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/text_encoder/__init__.py @@ -0,0 +1,48 @@ +import torch +from torch import nn +from transformers import T5EncoderModel, CLIPModel, CLIPProcessor + +from videogen_hub.pipelines.opensora_plan.opensora.utils.utils import get_precision + + +class T5Wrapper(nn.Module): + def __init__(self, args, **kwargs): + super(T5Wrapper, self).__init__() + self.model_name = args.text_encoder_name + self.text_enc = T5EncoderModel.from_pretrained(self.model_name, cache_dir=args.cache_dir, **kwargs).eval() + + def forward(self, input_ids, attention_mask): + text_encoder_embs = self.text_enc(input_ids=input_ids, attention_mask=attention_mask)['last_hidden_state'] + return text_encoder_embs.detach() + +class CLIPWrapper(nn.Module): + def __init__(self, args): + super(CLIPWrapper, self).__init__() + self.model_name = args.text_encoder_name + dtype = get_precision(args) + model_kwargs = {'cache_dir': args.cache_dir, 'low_cpu_mem_usage': True, 'torch_dtype': dtype} + self.text_enc = CLIPModel.from_pretrained(self.model_name, **model_kwargs).eval() + + def forward(self, input_ids, attention_mask): + text_encoder_embs = self.text_enc.get_text_features(input_ids=input_ids, attention_mask=attention_mask) + return text_encoder_embs.detach() + + + +text_encoder = { + 'DeepFloyd/t5-v1_1-xxl': T5Wrapper, + 'openai/clip-vit-large-patch14': CLIPWrapper +} + + +def get_text_enc(args): + """deprecation""" + text_enc = text_encoder.get(args.text_encoder_name, None) + assert text_enc is not None + return text_enc(args) + +def get_text_warpper(text_encoder_name): + """deprecation""" + text_enc = text_encoder.get(text_encoder_name, None) + assert text_enc is not None + return text_enc diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/text_encoder/clip.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/text_encoder/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..11bf4041e217858eef767b913921d1b173e28bd4 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/text_encoder/clip.py @@ -0,0 +1,222 @@ +# -*- coding: utf-8 -*- +import os +import re +import ftfy +import torch +import html +from PIL import Image +from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer, CLIPTextModel + +class CLIPEmbedder: + """ + A class for embedding texts and images using a pretrained CLIP model. + """ + + def __init__(self, device='cuda', model_name='openai/clip-vit-base-patch32', cache_dir='./cache_dir', use_text_preprocessing=True, max_length=77): + """ + Initializes the CLIPEmbedder with specified model and configurations. + """ + self.device = torch.device(device) + self.model_name = model_name + self.cache_dir = cache_dir + self.use_text_preprocessing = use_text_preprocessing + self.max_length = max_length + + os.makedirs(self.cache_dir, exist_ok=True) + + self.processor = CLIPProcessor.from_pretrained(model_name, cache_dir=self.cache_dir) + self.model = CLIPModel.from_pretrained(model_name, cache_dir=self.cache_dir).to(self.device).eval() + self.tokenizer = CLIPTokenizer.from_pretrained(model_name) + self.text_model = CLIPTextModel.from_pretrained(model_name, cache_dir=self.cache_dir).to(self.device).eval() + + for param in self.text_model.parameters(): + param.requires_grad = False + + def get_text_embeddings(self, texts): + """ + Generates embeddings for a list of text prompts. + """ + self._validate_input_list(texts, str) + + if self.use_text_preprocessing: + texts = [self._clean_text(text) for text in texts] + + inputs = self.processor(text=texts, return_tensors="pt", padding=True, truncation=True, max_length=self.max_length).to(self.device) + + with torch.no_grad(): + embeddings = self.model.get_text_features(**inputs) + + return embeddings + + def encode_text(self, texts): + """ + Encodes texts into embeddings and returns the last hidden state and pooled output. + """ + self._validate_input_list(texts, str) + + batch_encoding = self.tokenizer(texts, return_tensors="pt", truncation=True, max_length=self.max_length, padding="max_length").to(self.device) + + with torch.no_grad(): + outputs = self.text_model(**batch_encoding) + + return outputs.last_hidden_state, outputs.pooler_output + + def get_image_embeddings(self, image_paths): + """ + Generates embeddings for a list of image file paths. + """ + self._validate_input_list(image_paths, str) + images = [self._load_image(path) for path in image_paths] + + inputs = self.processor(images=images, return_tensors="pt").to(self.device) + + with torch.no_grad(): + embeddings = self.model.get_image_features(**inputs) + + return embeddings + + def _validate_input_list(self, input_list, expected_type): + """ + Validates that the input is a list of expected type. + """ + if not isinstance(input_list, list) or not all(isinstance(item, expected_type) for item in input_list): + raise ValueError(f"Input must be a list of {expected_type.__name__}.") + + def _clean_text(self, text): + """ + Applies basic cleaning and formatting to a text string. + """ + text = ftfy.fix_text(text) + text = html.unescape(text) + return text.strip() + + def _load_image(self, image_path): + """ + Loads and preprocesses an image from a file path. + """ + try: + image = Image.open(image_path).convert("RGB") + except FileNotFoundError: + raise FileNotFoundError(f"Image file not found: {image_path}") + except Exception as e: + raise Exception(f"Error loading image {image_path}: {e}") + return image + + def clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub('', 'person', caption) + # urls: + caption = re.sub( + r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa + '', caption) # regex for urls + caption = re.sub( + r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa + '', caption) # regex for urls + + caption = BeautifulSoup(caption, features='html.parser').text + + + caption = re.sub(r'@[\w\d]+\b', '', caption) + + caption = re.sub(r'[\u31c0-\u31ef]+', '', caption) + caption = re.sub(r'[\u31f0-\u31ff]+', '', caption) + caption = re.sub(r'[\u3200-\u32ff]+', '', caption) + caption = re.sub(r'[\u3300-\u33ff]+', '', caption) + caption = re.sub(r'[\u3400-\u4dbf]+', '', caption) + caption = re.sub(r'[\u4dc0-\u4dff]+', '', caption) + caption = re.sub(r'[\u4e00-\u9fff]+', '', caption) + + caption = re.sub( + r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+', # noqa + '-', caption) + + + caption = re.sub(r'[`´«»“”¨]', '"', caption) + caption = re.sub(r'[‘’]', "'", caption) + + + caption = re.sub(r'"?', '', caption) + + caption = re.sub(r'&', '', caption) + + + caption = re.sub(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}', ' ', caption) + + + caption = re.sub(r'\d:\d\d\s+$', '', caption) + + + caption = re.sub(r'\\n', ' ', caption) + + + caption = re.sub(r'#\d{1,3}\b', '', caption) + + caption = re.sub(r'#\d{5,}\b', '', caption) + caption = re.sub(r'\b\d{6,}\b', '', caption) + caption = re.sub(r'[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)', '', caption) + caption = re.sub(r'[\"\']{2,}', r'"', caption) + caption = re.sub(r'[\.]{2,}', r' ', caption) + + caption = re.sub(self.bad_punct_regex, r' ', caption) + caption = re.sub(r'\s+\.\s+', r' ', caption) + regex2 = re.compile(r'(?:\-|\_)') + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, ' ', caption) + caption = self.basic_clean(caption) + caption = re.sub(r'\b[a-zA-Z]{1,3}\d{3,15}\b', '', caption) # jc6640 + caption = re.sub(r'\b[a-zA-Z]+\d+[a-zA-Z]+\b', '', caption) # jc6640vc + caption = re.sub(r'\b\d+[a-zA-Z]+\d+\b', '', caption) # 6640vc231 + + caption = re.sub(r'(worldwide\s+)?(free\s+)?shipping', '', caption) + caption = re.sub(r'(free\s)?download(\sfree)?', '', caption) + caption = re.sub(r'\bclick\b\s(?:for|on)\s\w+', '', caption) + caption = re.sub(r'\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?', '', caption) + caption = re.sub(r'\bpage\s+\d+\b', '', caption) + + caption = re.sub(r'\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b', r' ', caption) # j2d1a2a... + + caption = re.sub(r'\b\d+\.?\d*[xх×]\d+\.?\d*\b', '', caption) + + caption = re.sub(r'\b\s+\:\s+', r': ', caption) + caption = re.sub(r'(\D[,\./])\b', r'\1 ', caption) + caption = re.sub(r'\s+', ' ', caption) + + caption.strip() + + caption = re.sub(r'^[\"\']([\w\W]+)[\"\']$', r'\1', caption) + caption = re.sub(r'^[\'\_,\-\:;]', r'', caption) + caption = re.sub(r'[\'\_,\-\:\-\+]$', r'', caption) + caption = re.sub(r'^\.\S+$', '', caption) + + return caption.strip() + + @staticmethod + def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + +if __name__ == '__main__': + + clip_embedder = CLIPEmbedder() + + # Example + text_prompts = [ + 'A photo of a cute puppy playing with a ball.', + 'An image of a beautiful sunset over the ocean.', + 'A scene depicting a busy city street.' + ] + text_embeddings = clip_embedder.get_text_embeddings(text_prompts) + print(f"Text embeddings shape: {text_embeddings.shape}") + + image_paths = ['image1.jpg', 'image2.png'] + try: + image_embeddings = clip_embedder.get_image_embeddings(image_paths) + print(f"Image embeddings shape: {image_embeddings.shape}") + except FileNotFoundError as e: + print(e) + except Exception as e: + print(f"An error occurred: {e}") + diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/models/text_encoder/t5.py b/src/videogen_hub/pipelines/opensora_plan/opensora/models/text_encoder/t5.py new file mode 100644 index 0000000000000000000000000000000000000000..cbc782072e58d9fd3041d36a2a1ded089b61329d --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/models/text_encoder/t5.py @@ -0,0 +1,197 @@ +# -*- coding: utf-8 -*- +import os +import re +import html +import urllib.parse as ul + +import ftfy +import torch +from bs4 import BeautifulSoup +from transformers import T5EncoderModel, AutoTokenizer +from huggingface_hub import hf_hub_download + +class T5Embedder: + + available_models = ['t5-v1_1-xxl'] + bad_punct_regex = re.compile(r'['+'#®•©™&@·º½¾¿¡§~'+'\)'+'\('+'\]'+'\['+'\}'+'\{'+'\|'+'\\'+'\/'+'\*' + r']{1,}') # noqa + + def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, cache_dir='./cache_dir', hf_token=None, use_text_preprocessing=True, + t5_model_kwargs=None, torch_dtype=None, model_max_length=120): + self.device = torch.device(device) + self.torch_dtype = torch_dtype or torch.bfloat16 + if t5_model_kwargs is None: + t5_model_kwargs = {'low_cpu_mem_usage': True, 'torch_dtype': self.torch_dtype} + t5_model_kwargs['device_map'] = {'shared': self.device, 'encoder': self.device} + + self.use_text_preprocessing = use_text_preprocessing + self.hf_token = hf_token + self.cache_dir = cache_dir + self.dir_or_name = dir_or_name + cache_dir = os.path.join(self.cache_dir, 't5-v1_1-xxl') + for filename in ['config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json', + 'pytorch_model-00001-of-00002.bin', 'pytorch_model-00002-of-00002.bin', 'pytorch_model.bin.index.json']: + hf_hub_download(repo_id='DeepFloyd/t5-v1_1-xxl', filename=filename, cache_dir=cache_dir, + force_filename=filename, token=self.hf_token) + + print(cache_dir) + self.tokenizer = AutoTokenizer.from_pretrained(cache_dir) + self.model = T5EncoderModel.from_pretrained(cache_dir, **t5_model_kwargs).eval() + self.model_max_length = model_max_length + + def get_text_embeddings(self, texts): + texts = [self.text_preprocessing(text) for text in texts] + + text_tokens_and_mask = self.tokenizer( + texts, + max_length=self.model_max_length, + padding='max_length', + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors='pt' + ) + + text_tokens_and_mask['input_ids'] = text_tokens_and_mask['input_ids'] + text_tokens_and_mask['attention_mask'] = text_tokens_and_mask['attention_mask'] + + with torch.no_grad(): + text_encoder_embs = self.model( + input_ids=text_tokens_and_mask['input_ids'].to(self.device), + attention_mask=text_tokens_and_mask['attention_mask'].to(self.device), + )['last_hidden_state'].detach() + return text_encoder_embs, text_tokens_and_mask['attention_mask'].to(self.device) + + def text_preprocessing(self, text): + if self.use_text_preprocessing: + # The exact text cleaning as was in the training stage: + text = self.clean_caption(text) + text = self.clean_caption(text) + return text + else: + return text.lower().strip() + + @staticmethod + def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + def clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub('', 'person', caption) + # urls: + caption = re.sub( + r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa + '', caption) # regex for urls + caption = re.sub( + r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa + '', caption) # regex for urls + # html: + caption = BeautifulSoup(caption, features='html.parser').text + + # @ + caption = re.sub(r'@[\w\d]+\b', '', caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r'[\u31c0-\u31ef]+', '', caption) + caption = re.sub(r'[\u31f0-\u31ff]+', '', caption) + caption = re.sub(r'[\u3200-\u32ff]+', '', caption) + caption = re.sub(r'[\u3300-\u33ff]+', '', caption) + caption = re.sub(r'[\u3400-\u4dbf]+', '', caption) + caption = re.sub(r'[\u4dc0-\u4dff]+', '', caption) + caption = re.sub(r'[\u4e00-\u9fff]+', '', caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+', # noqa + '-', caption) + + # кавычки к одному стандарту + caption = re.sub(r'[`´«»“”¨]', '"', caption) + caption = re.sub(r'[‘’]', "'", caption) + + # " + caption = re.sub(r'"?', '', caption) + # & + caption = re.sub(r'&', '', caption) + + # ip adresses: + caption = re.sub(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}', ' ', caption) + + # article ids: + caption = re.sub(r'\d:\d\d\s+$', '', caption) + + # \n + caption = re.sub(r'\\n', ' ', caption) + + # "#123" + caption = re.sub(r'#\d{1,3}\b', '', caption) + # "#12345.." + caption = re.sub(r'#\d{5,}\b', '', caption) + # "123456.." + caption = re.sub(r'\b\d{6,}\b', '', caption) + # filenames: + caption = re.sub(r'[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)', '', caption) + + # + caption = re.sub(r'[\"\']{2,}', r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r'[\.]{2,}', r' ', caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r' ', caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r'\s+\.\s+', r' ', caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r'(?:\-|\_)') + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, ' ', caption) + + caption = self.basic_clean(caption) + + caption = re.sub(r'\b[a-zA-Z]{1,3}\d{3,15}\b', '', caption) # jc6640 + caption = re.sub(r'\b[a-zA-Z]+\d+[a-zA-Z]+\b', '', caption) # jc6640vc + caption = re.sub(r'\b\d+[a-zA-Z]+\d+\b', '', caption) # 6640vc231 + + caption = re.sub(r'(worldwide\s+)?(free\s+)?shipping', '', caption) + caption = re.sub(r'(free\s)?download(\sfree)?', '', caption) + caption = re.sub(r'\bclick\b\s(?:for|on)\s\w+', '', caption) + caption = re.sub(r'\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?', '', caption) + caption = re.sub(r'\bpage\s+\d+\b', '', caption) + + caption = re.sub(r'\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b', r' ', caption) # j2d1a2a... + + caption = re.sub(r'\b\d+\.?\d*[xх×]\d+\.?\d*\b', '', caption) + + caption = re.sub(r'\b\s+\:\s+', r': ', caption) + caption = re.sub(r'(\D[,\./])\b', r'\1 ', caption) + caption = re.sub(r'\s+', ' ', caption) + + caption.strip() + + caption = re.sub(r'^[\"\']([\w\W]+)[\"\']$', r'\1', caption) + caption = re.sub(r'^[\'\_,\-\:;]', r'', caption) + caption = re.sub(r'[\'\_,\-\:\-\+]$', r'', caption) + caption = re.sub(r'^\.\S+$', '', caption) + + return caption.strip() + +if __name__ == '__main__': + t5 = T5Embedder(device="cuda", cache_dir='./cache_dir', torch_dtype=torch.float) + device = t5.device + prompts = ['I am a test caption', 'Test twice'] + with torch.no_grad(): + caption_embs, emb_masks = t5.get_text_embeddings(prompts) + emb_dict = { + 'caption_feature': caption_embs.float().cpu().data.numpy(), + 'attention_mask': emb_masks.cpu().data.numpy(), + } + import ipdb;ipdb.set_trace() + print() \ No newline at end of file diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/pipeline_videogen.py b/src/videogen_hub/pipelines/opensora_plan/opensora/pipeline_videogen.py new file mode 100644 index 0000000000000000000000000000000000000000..42334d85a03e04cadea7d5e4038b22f7cf0317a8 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/pipeline_videogen.py @@ -0,0 +1,758 @@ +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import html +import inspect +import math +import re +import urllib.parse as ul +from dataclasses import dataclass +from typing import Callable, List, Optional, Tuple, Union + +import torch +from diffusers.models import AutoencoderKL, Transformer2DModel +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import DPMSolverMultistepScheduler +from diffusers.utils import ( + BACKENDS_MAPPING, + is_bs4_available, + is_ftfy_available, + logging, + replace_example_docstring, +) +from diffusers.utils import BaseOutput +from diffusers.utils.torch_utils import randn_tensor +from transformers import T5EncoderModel, T5Tokenizer + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import PixArtAlphaPipeline + + >>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too. + >>> pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16) + >>> # Enable memory optimizations. + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "A small cactus with a happy face in the Sahara desert." + >>> image = pipe(prompt).images[0] + ``` +""" + + +@dataclass +class VideoPipelineOutput(BaseOutput): + video: torch.Tensor + + +class VideoGenPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using PixArt-Alpha. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. PixArt-Alpha uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`Transformer2DModel`]): + A text conditioned `Transformer2DModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + bad_punct_regex = re.compile( + r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder"] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKL, + transformer: Transformer2DModel, + scheduler: DPMSolverMultistepScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + # self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py + def mask_text_embeddings(self, emb, mask): + if emb.shape[0] == 1: + keep_index = mask.sum().item() + return emb[:, :, :keep_index, :], keep_index # 1, 120, 4096 -> 1 7 4096 + else: + masked_feature = emb * mask[:, None, :, None] # 1 120 4096 + return masked_feature, emb.shape[2] + + # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + clean_caption: bool = False, + mask_feature: bool = True, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + PixArt-Alpha, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the "" + string. + clean_caption (bool, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + mask_feature: (bool, defaults to `True`): + If `True`, the function will mask the text embeddings. + """ + embeds_initially_provided = prompt_embeds is not None and negative_prompt_embeds is not None + + if device is None: + device = self.text_encoder.device or self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # See Section 3.1. of the paper. + max_length = 300 + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1: -1]) + logger.warning( + "The following part of your input was truncated because the model can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds_attention_mask = attention_mask + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds_attention_mask = torch.ones_like(prompt_embeds) + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_embeds_attention_mask = prompt_embeds_attention_mask.view(bs_embed, -1) + prompt_embeds_attention_mask = prompt_embeds_attention_mask.repeat(num_images_per_prompt, 1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens = [negative_prompt] * batch_size + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + else: + negative_prompt_embeds = None + + # print(prompt_embeds.shape) # 1 120 4096 + # print(negative_prompt_embeds.shape) # 1 120 4096 + + # Perform additional masking. + if mask_feature and not embeds_initially_provided: + prompt_embeds = prompt_embeds.unsqueeze(1) + masked_prompt_embeds, keep_indices = self.mask_text_embeddings(prompt_embeds, prompt_embeds_attention_mask) + masked_prompt_embeds = masked_prompt_embeds.squeeze(1) + masked_negative_prompt_embeds = ( + negative_prompt_embeds[:, :keep_indices, :] if negative_prompt_embeds is not None else None + ) + + # import torch.nn.functional as F + + # padding = (0, 0, 0, 113) # (左, 右, 下, 上) + # masked_prompt_embeds_ = F.pad(masked_prompt_embeds, padding, "constant", 0) + # masked_negative_prompt_embeds_ = F.pad(masked_negative_prompt_embeds, padding, "constant", 0) + + # print(masked_prompt_embeds == masked_prompt_embeds_[:, :masked_negative_prompt_embeds.shape[1], ...]) + + return masked_prompt_embeds, masked_negative_prompt_embeds + # return masked_prompt_embeds_, masked_negative_prompt_embeds_ + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_steps, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warn("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warn("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", + # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", + # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", + # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, + latents=None): + shape = ( + batch_size, + num_channels_latents, + (math.ceil((int(num_frames) - 1) / self.vae.vae_scale_factor[0]) + 1) if int(num_frames) % 2 == 1 else math.ceil(int(num_frames) / self.vae.vae_scale_factor[0]), + math.ceil(int(height) / self.vae.vae_scale_factor[1]), + math.ceil(int(width) / self.vae.vae_scale_factor[2]), + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: str = "", + num_inference_steps: int = 20, + timesteps: List[int] = None, + guidance_scale: float = 4.5, + num_images_per_prompt: Optional[int] = 1, + num_frames: Optional[int] = None, + height: Optional[int] = None, + width: Optional[int] = None, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + clean_caption: bool = True, + mask_feature: bool = True, + enable_temporal_attentions: bool = True, + ) -> Union[VideoPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + # 1. Check inputs. Raise error if not correct + # height = height or self.transformer.config.sample_size * self.vae_scale_factor + # width = width or self.transformer.config.sample_size * self.vae_scale_factor + self.check_inputs( + prompt, height, width, negative_prompt, callback_steps, prompt_embeds, negative_prompt_embeds + ) + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self.text_encoder.device or self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clean_caption=clean_caption, + mask_feature=mask_feature, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Prepare micro-conditions. + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + # if self.transformer.config.sample_size == 128: + # resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1) + # aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1) + # resolution = resolution.to(dtype=prompt_embeds.dtype, device=device) + # aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device) + # added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio} + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + current_timestep = t + if not torch.is_tensor(current_timestep): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = latent_model_input.device.type == "mps" + if isinstance(current_timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) + elif len(current_timestep.shape) == 0: + current_timestep = current_timestep[None].to(latent_model_input.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + current_timestep = current_timestep.expand(latent_model_input.shape[0]) + + if prompt_embeds.ndim == 3: + prompt_embeds = prompt_embeds.unsqueeze(1) # b l d -> b 1 l d + # if prompt_attention_mask.ndim == 2: + # prompt_attention_mask = prompt_attention_mask.unsqueeze(1) # b l -> b 1 l + # predict noise model_output + noise_pred = self.transformer( + latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=current_timestep, + added_cond_kwargs=added_cond_kwargs, + enable_temporal_attentions=enable_temporal_attentions, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # learned sigma + if self.transformer.config.out_channels // 2 == latent_channels: + noise_pred = noise_pred.chunk(2, dim=1)[0] + else: + noise_pred = noise_pred + + # compute previous image: x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == 'latents': + video = self.decode_latents(latents) + video = video[:, :num_frames, :height, :width] + else: + video = latents + return VideoPipelineOutput(video=video) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return VideoPipelineOutput(video=video) + + def decode_latents(self, latents): + video = self.vae.decode(latents) # b t c h w + # b t c h w -> b t h w c + video = ((video / 2.0 + 0.5).clamp(0, 1) * 255).to(dtype=torch.uint8).cpu().permute(0, 1, 3, 4, 2).contiguous() + return video diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/sample_t2v.py b/src/videogen_hub/pipelines/opensora_plan/opensora/sample_t2v.py new file mode 100644 index 0000000000000000000000000000000000000000..06ca056e6210a6e35b8a56307d5c1fea9401efe1 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/sample_t2v.py @@ -0,0 +1,170 @@ +import argparse +import math +import os +import sys + +import torch +from diffusers.schedulers import (DDIMScheduler, DDPMScheduler, PNDMScheduler, + EulerDiscreteScheduler, DPMSolverMultistepScheduler, + HeunDiscreteScheduler, EulerAncestralDiscreteScheduler, + DEISMultistepScheduler, KDPM2AncestralDiscreteScheduler) +from diffusers.schedulers.scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler +from torchvision.utils import save_image +from transformers import T5EncoderModel, T5Tokenizer + +from .models.ae import ae_stride_config, getae_wrapper +from .models.diffusion.latte.modeling_latte import LatteT2V +from .utils.utils import save_video_grid + +sys.path.append(os.path.split(sys.path[0])[0]) +from pipeline_videogen import VideoGenPipeline + +import imageio + + +def parse_args(arg_list): + parser = argparse.ArgumentParser() + parser.add_argument("--model_path", type=str, default='LanguageBind/Open-Sora-Plan-v1.0.0') + parser.add_argument("--version", type=str, default=None, choices=[None, '65x512x512', '221x512x512', '513x512x512']) + parser.add_argument("--num_frames", type=int, default=1) + parser.add_argument("--height", type=int, default=512) + parser.add_argument("--width", type=int, default=512) + parser.add_argument("--cache_dir", type=str, default='./cache_dir') + parser.add_argument("--ae", type=str, default='CausalVAEModel_4x8x8') + parser.add_argument("--ae_path", type=str, default='CausalVAEModel_4x8x8') + parser.add_argument("--text_encoder_name", type=str, default='DeepFloyd/t5-v1_1-xxl') + parser.add_argument("--save_img_path", type=str, + default="./src/videogen_hub/pipelines/opensora_plan/sample_videos/t2v") + parser.add_argument("--guidance_scale", type=float, default=7.5) + parser.add_argument("--sample_method", type=str, default="PNDM") + parser.add_argument("--num_sampling_steps", type=int, default=50) + parser.add_argument("--fps", type=int, default=24) + parser.add_argument("--run_time", type=int, default=0) + parser.add_argument("--text_prompt", nargs='+') + parser.add_argument('--force_images', action='store_true') + parser.add_argument('--tile_overlap_factor', type=float, default=0.25) + parser.add_argument('--enable_tiling', action='store_true') + args = parser.parse_args(arg_list) + return args + + +class OpenSoraPlanPipeline(): + def __init__(self, arg_list, device): + self.args = parse_args(arg_list) + self.device = device + + def inference(self, save_output): + # torch.manual_seed(args.seed) + args = self.args + torch.set_grad_enabled(False) + device = self.device + + vae = getae_wrapper(args.ae)(args.model_path, subfolder="vae", cache_dir=args.cache_dir).to(device, + dtype=torch.float16) + # vae = getae_wrapper(args.ae)(args.ae_path).to(device, dtype=torch.float16) + if args.enable_tiling: + vae.vae.enable_tiling() + vae.vae.tile_overlap_factor = args.tile_overlap_factor + vae.vae_scale_factor = ae_stride_config[args.ae] + # Load model: + transformer_model = LatteT2V.from_pretrained(args.model_path, subfolder=args.version, cache_dir=args.cache_dir, + torch_dtype=torch.float16).to(device) + # transformer_model = LatteT2V.from_pretrained(args.model_path, low_cpu_mem_usage=False, device_map=None, torch_dtype=torch.float16).to(device) + + transformer_model.force_images = args.force_images + tokenizer = T5Tokenizer.from_pretrained(args.text_encoder_name, cache_dir=args.cache_dir) + text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_name, cache_dir=args.cache_dir, + torch_dtype=torch.float16).to(device) + + if args.force_images: + ext = 'jpg' + else: + ext = 'mp4' + + # set eval mode + transformer_model.eval() + vae.eval() + text_encoder.eval() + + if args.sample_method == 'DDIM': ######### + scheduler = DDIMScheduler() + elif args.sample_method == 'EulerDiscrete': + scheduler = EulerDiscreteScheduler() + elif args.sample_method == 'DDPM': ############# + scheduler = DDPMScheduler() + elif args.sample_method == 'DPMSolverMultistep': + scheduler = DPMSolverMultistepScheduler() + elif args.sample_method == 'DPMSolverSinglestep': + scheduler = DPMSolverSinglestepScheduler() + elif args.sample_method == 'PNDM': + scheduler = PNDMScheduler() + elif args.sample_method == 'HeunDiscrete': ######## + scheduler = HeunDiscreteScheduler() + elif args.sample_method == 'EulerAncestralDiscrete': + scheduler = EulerAncestralDiscreteScheduler() + elif args.sample_method == 'DEISMultistep': + scheduler = DEISMultistepScheduler() + elif args.sample_method == 'KDPM2AncestralDiscrete': ######### + scheduler = KDPM2AncestralDiscreteScheduler() + print('videogen_pipeline', device) + videogen_pipeline = VideoGenPipeline(vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer_model).to(device=device) + # videogen_pipeline.enable_xformers_memory_efficient_attention() + + if not os.path.exists(args.save_img_path): + os.makedirs(args.save_img_path) + + video_grids = [] + if not isinstance(args.text_prompt, list): + args.text_prompt = [args.text_prompt] + if len(args.text_prompt) == 1 and args.text_prompt[0].endswith('txt'): + text_prompt = open(args.text_prompt[0], 'r').readlines() + args.text_prompt = [i.strip() for i in text_prompt] + for idx, prompt in enumerate(args.text_prompt): + print('Processing the ({}) prompt'.format(prompt)) + videos = videogen_pipeline(prompt, + num_frames=args.num_frames, + height=args.height, + width=args.width, + num_inference_steps=args.num_sampling_steps, + guidance_scale=args.guidance_scale, + enable_temporal_attentions=not args.force_images, + num_images_per_prompt=1, + mask_feature=True, + ).video + print(videos.shape) + try: + if args.force_images: + videos = videos[:, 0].permute(0, 3, 1, 2) # b t h w c -> b c h w + save_image(videos / 255.0, os.path.join(args.save_img_path, f'{idx}.{ext}'), + nrow=1, normalize=True, value_range=(0, 1)) # t c h w + + elif save_output: + imageio.mimwrite( + os.path.join( + args.save_img_path, f'{idx}.{ext}'), videos[0], + fps=args.fps, quality=9) # highest quality is 10, lowest is 0 + except: + print('Error when saving {}'.format(prompt)) + video_grids.append(videos) + video_grids = torch.cat(video_grids, dim=0) + + # torchvision.io.write_video(args.save_img_path + '_%04d' % args.run_time + '-.mp4', video_grids, fps=6) + if args.force_images: + save_image(video_grids / 255.0, os.path.join(args.save_img_path, + f'{args.sample_method}_gs{args.guidance_scale}_s{args.num_sampling_steps}.{ext}'), + nrow=math.ceil(math.sqrt(len(video_grids))), normalize=True, value_range=(0, 1)) + print('save path {}'.format(args.save_img_path)) + elif save_output: + video_grids = save_video_grid(video_grids) + imageio.mimwrite(os.path.join(args.save_img_path, + f'{args.sample_method}_gs{args.guidance_scale}_s{args.num_sampling_steps}.{ext}'), + video_grids, fps=args.fps, quality=9) + print('save path {}'.format(args.save_img_path)) + else: + return video_grids + + # save_videos_grid(video, f"./{prompt}.gif") diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/utils/__init__.py b/src/videogen_hub/pipelines/opensora_plan/opensora/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/utils/dataset_utils.py b/src/videogen_hub/pipelines/opensora_plan/opensora/utils/dataset_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ac623fe4d4d0c1414ff55dcd04909eb0493c3ecd --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/utils/dataset_utils.py @@ -0,0 +1,37 @@ +import decord + +IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG'] + + +class DecordInit(object): + """Using Decord(https://github.com/dmlc/decord) to initialize the video_reader.""" + + def __init__(self, num_threads=1): + self.num_threads = num_threads + self.ctx = decord.cpu(0) + + def __call__(self, filename): + """Perform the Decord initialization. + Args: + results (dict): The resulting dict to be modified and passed + to the next transform in pipeline. + """ + reader = decord.VideoReader(filename, + ctx=self.ctx, + num_threads=self.num_threads) + return reader + + def __repr__(self): + repr_str = (f'{self.__class__.__name__}(' + f'sr={self.sr},' + f'num_threads={self.num_threads})') + return repr_str + + +def pad_to_multiple(number, ds_stride): + remainder = number % ds_stride + if remainder == 0: + return number + else: + padding = ds_stride - remainder + return number + padding diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/utils/downloader.py b/src/videogen_hub/pipelines/opensora_plan/opensora/utils/downloader.py new file mode 100644 index 0000000000000000000000000000000000000000..f2ac4017b10033f67b5affce0906c0defadb38cf --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/utils/downloader.py @@ -0,0 +1,18 @@ +import gdown +import os + +opensora_cache_home = os.path.expanduser( + os.getenv("OPENSORA_HOME", os.path.join("~/.cache", "opensora")) +) + + +def gdown_download(id, fname, cache_dir=None): + cache_dir = opensora_cache_home if not cache_dir else cache_dir + + os.makedirs(cache_dir, exist_ok=True) + destination = os.path.join(cache_dir, fname) + if os.path.exists(destination): + return destination + + gdown.download(id=id, output=destination, quiet=False) + return destination diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/utils/taming_download.py b/src/videogen_hub/pipelines/opensora_plan/opensora/utils/taming_download.py new file mode 100644 index 0000000000000000000000000000000000000000..5a62be7781ae7eb166d5a80383f30a83af093895 --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/utils/taming_download.py @@ -0,0 +1,145 @@ +"""Modified from https://github.com/CompVis/taming-transformers.git""" + +import os, hashlib +import requests +from tqdm import tqdm + +URL_MAP = { + "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" +} + +CKPT_MAP = { + "vgg_lpips": "vgg.pth" +} + +MD5_MAP = { + "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" +} + + +def download(url, local_path, chunk_size=1024): + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True) as r: + total_size = int(r.headers.get("content-length", 0)) + with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: + with open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + + +def md5_hash(path): + with open(path, "rb") as f: + content = f.read() + return hashlib.md5(content).hexdigest() + + +def get_ckpt_path(name, root, check=False): + assert name in URL_MAP + path = os.path.join(root, CKPT_MAP[name]) + if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): + print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) + download(URL_MAP[name], path) + md5 = md5_hash(path) + assert md5 == MD5_MAP[name], md5 + return path + + +class KeyNotFoundError(Exception): + def __init__(self, cause, keys=None, visited=None): + self.cause = cause + self.keys = keys + self.visited = visited + messages = list() + if keys is not None: + messages.append("Key not found: {}".format(keys)) + if visited is not None: + messages.append("Visited: {}".format(visited)) + messages.append("Cause:\n{}".format(cause)) + message = "\n".join(messages) + super().__init__(message) + + +def retrieve( + list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False +): + """Given a nested list or dict return the desired value at key expanding + callable nodes if necessary and :attr:`expand` is ``True``. The expansion + is done in-place. + + Parameters + ---------- + list_or_dict : list or dict + Possibly nested list or dictionary. + key : str + key/to/value, path like string describing all keys necessary to + consider to get to the desired value. List indices can also be + passed here. + splitval : str + String that defines the delimiter between keys of the + different depth levels in `key`. + default : obj + Value returned if :attr:`key` is not found. + expand : bool + Whether to expand callable nodes on the path or not. + + Returns + ------- + The desired value or if :attr:`default` is not ``None`` and the + :attr:`key` is not found returns ``default``. + + Raises + ------ + Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is + ``None``. + """ + + keys = key.split(splitval) + + success = True + try: + visited = [] + parent = None + last_key = None + for key in keys: + if callable(list_or_dict): + if not expand: + raise KeyNotFoundError( + ValueError( + "Trying to get past callable node with expand=False." + ), + keys=keys, + visited=visited, + ) + list_or_dict = list_or_dict() + parent[last_key] = list_or_dict + + last_key = key + parent = list_or_dict + + try: + if isinstance(list_or_dict, dict): + list_or_dict = list_or_dict[key] + else: + list_or_dict = list_or_dict[int(key)] + except (KeyError, IndexError, ValueError) as e: + raise KeyNotFoundError(e, keys=keys, visited=visited) + + visited += [key] + # final expansion of retrieved value + if expand and callable(list_or_dict): + list_or_dict = list_or_dict() + parent[last_key] = list_or_dict + except KeyNotFoundError as e: + if default is None: + raise e + else: + list_or_dict = default + success = False + + if not pass_success: + return list_or_dict + else: + return list_or_dict, success + diff --git a/src/videogen_hub/pipelines/opensora_plan/opensora/utils/utils.py b/src/videogen_hub/pipelines/opensora_plan/opensora/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1a10ce960e019111ef603eca60470eec761282eb --- /dev/null +++ b/src/videogen_hub/pipelines/opensora_plan/opensora/utils/utils.py @@ -0,0 +1,170 @@ +import html +import math +import re +import urllib.parse as ul +from typing import Union, Iterable + +import torch +from diffusers.utils import is_bs4_available, is_ftfy_available + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + +bad_punct_regex = re.compile( + r'[' + '#®•©™&@·º½¾¿¡§~' + '\)' + '\(' + '\]' + '\[' + '\}' + '\{' + '\|' + '\\' + '\/' + '\*' + r']{1,}') # noqa + +_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]] + + +################################################################################# +# Training Clip Gradients # +################################################################################# +def get_precision(args): + if args.mixed_precision == "bf16": + dtype = torch.bfloat16 + elif args.mixed_precision == "fp16": + dtype = torch.float16 + else: + dtype = torch.float32 + return dtype + + +def save_video_grid(video, nrow=None): + b, t, h, w, c = video.shape + + if nrow is None: + nrow = math.ceil(math.sqrt(b)) + ncol = math.ceil(b / nrow) + padding = 1 + video_grid = torch.zeros((t, (padding + h) * nrow + padding, + (padding + w) * ncol + padding, c), dtype=torch.uint8) + + print(video_grid.shape) + for i in range(b): + r = i // ncol + c = i % ncol + start_r = (padding + h) * r + start_c = (padding + w) * c + video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i] + + return video_grid + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def clean_caption(caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub('', 'person', caption) + # urls: + caption = re.sub( + r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', + # noqa + '', caption) # regex for urls + caption = re.sub( + r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', + # noqa + '', caption) # regex for urls + # html: + caption = BeautifulSoup(caption, features='html.parser').text + + # @ + caption = re.sub(r'@[\w\d]+\b', '', caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r'[\u31c0-\u31ef]+', '', caption) + caption = re.sub(r'[\u31f0-\u31ff]+', '', caption) + caption = re.sub(r'[\u3200-\u32ff]+', '', caption) + caption = re.sub(r'[\u3300-\u33ff]+', '', caption) + caption = re.sub(r'[\u3400-\u4dbf]+', '', caption) + caption = re.sub(r'[\u4dc0-\u4dff]+', '', caption) + caption = re.sub(r'[\u4e00-\u9fff]+', '', caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+', + # noqa + '-', caption) + + # кавычки к одному стандарту + caption = re.sub(r'[`´«»“”¨]', '"', caption) + caption = re.sub(r'[‘’]', "'", caption) + + # " + caption = re.sub(r'"?', '', caption) + # & + caption = re.sub(r'&', '', caption) + + # ip adresses: + caption = re.sub(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}', ' ', caption) + + # article ids: + caption = re.sub(r'\d:\d\d\s+$', '', caption) + + # \n + caption = re.sub(r'\\n', ' ', caption) + + # "#123" + caption = re.sub(r'#\d{1,3}\b', '', caption) + # "#12345.." + caption = re.sub(r'#\d{5,}\b', '', caption) + # "123456.." + caption = re.sub(r'\b\d{6,}\b', '', caption) + # filenames: + caption = re.sub(r'[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)', '', caption) + + # + caption = re.sub(r'[\"\']{2,}', r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r'[\.]{2,}', r' ', caption) # """AUSVERKAUFT""" + + caption = re.sub(bad_punct_regex, r' ', caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r'\s+\.\s+', r' ', caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r'(?:\-|\_)') + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, ' ', caption) + + caption = basic_clean(caption) + + caption = re.sub(r'\b[a-zA-Z]{1,3}\d{3,15}\b', '', caption) # jc6640 + caption = re.sub(r'\b[a-zA-Z]+\d+[a-zA-Z]+\b', '', caption) # jc6640vc + caption = re.sub(r'\b\d+[a-zA-Z]+\d+\b', '', caption) # 6640vc231 + + caption = re.sub(r'(worldwide\s+)?(free\s+)?shipping', '', caption) + caption = re.sub(r'(free\s)?download(\sfree)?', '', caption) + caption = re.sub(r'\bclick\b\s(?:for|on)\s\w+', '', caption) + caption = re.sub(r'\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?', '', caption) + caption = re.sub(r'\bpage\s+\d+\b', '', caption) + + caption = re.sub(r'\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b', r' ', caption) # j2d1a2a... + + caption = re.sub(r'\b\d+\.?\d*[xх×]\d+\.?\d*\b', '', caption) + + caption = re.sub(r'\b\s+\:\s+', r': ', caption) + caption = re.sub(r'(\D[,\./])\b', r'\1 ', caption) + caption = re.sub(r'\s+', ' ', caption) + + caption.strip() + + caption = re.sub(r'^[\"\']([\w\W]+)[\"\']$', r'\1', caption) + caption = re.sub(r'^[\'\_,\-\:;]', r'', caption) + caption = re.sub(r'[\'\_,\-\:\-\+]$', r'', caption) + caption = re.sub(r'^\.\S+$', '', caption) + + return caption.strip() diff --git a/src/videogen_hub/pipelines/seine/SEINEPipeline.py b/src/videogen_hub/pipelines/seine/SEINEPipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..4bb14a3e89499dbc5924e5ee0a64d0f76317e356 --- /dev/null +++ b/src/videogen_hub/pipelines/seine/SEINEPipeline.py @@ -0,0 +1,135 @@ +from typing import List +from torch import _validate_compressed_sparse_indices +from torchvision.utils import save_image + +from videogen_hub import MODEL_PATH +from with_mask_sample import * + + +class SEINEPipeline(): + def __init__(self, seine_path: str = os.path.join(MODEL_PATH, "SEINE", "seine.pt"), + pretrained_model_path: str = os.path.join(MODEL_PATH, "SEINE", "stable-diffusion-v1-4"), + config_path: str = "src/videogen_hub/pipelines/seine/sample_i2v.yaml"): + """ + Load the configuration file and set the paths of models. + Args: + seine_path: The path of the downloaded seine pretrained model. + pretrained_model_path: The path of the downloaded stable diffusion pretrained model. + config_path: The path of the configuration file. + """ + self.config = OmegaConf.load(config_path) + self.config.ckpt = seine_path + self.config.pretrained_model_path = pretrained_model_path + + def infer_one_video(self, input_image, + text_prompt: List = [], + output_size: List = [240, 560], + num_frames: int = 16, + num_sampling_steps: int = 250, + seed: int = 42, + save_video: bool = False): + """ + Generate video based on provided input_image and text_prompt. + Args: + input_image: The input image to generate video. + text_prompt: The text prompt to generate video. + output_size: The size of the generated video. Defaults to [240, 560]. + num_frames: number of frames of the generated video. Defaults to 16. + num_sampling_steps: number of sampling steps to generate the video. Defaults to 250. + seed: The random seed for video generation. Defaults to 42. + save_video: save the video to the path in config if it is True. Not save if it is False. Defaults to False. + + Returns: + The generated video as tensor with shape (num_frames, channels, height, width). + + """ + + self.config.image_size = output_size + self.config.num_frames = num_frames + self.config.num_sampling_steps = num_sampling_steps + self.config.seed = seed + self.config.text_prompt = text_prompt + print(input_image, type(input_image) == str) + if type(input_image) == str: + self.config.input_path = input_image + else: + assert torch.is_tensor(input_image) + assert len(input_image.shape) == 3 + assert input_image.shape[0] == 3 + save_image(input_image, "src/videogen_hub/pipelines/seine/input_image.png") + + args = self.config + + # Setup PyTorch: + if args.seed: + torch.manual_seed(args.seed) + torch.set_grad_enabled(False) + device = "cuda" if torch.cuda.is_available() else "cpu" + # device = "cpu" + + if args.ckpt is None: + raise ValueError("Please specify a checkpoint path using --ckpt ") + + # Load model: + latent_h = args.image_size[0] // 8 + latent_w = args.image_size[1] // 8 + args.image_h = args.image_size[0] + args.image_w = args.image_size[1] + args.latent_h = latent_h + args.latent_w = latent_w + print('loading model') + model = get_models(args).to(device) + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + model.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + # load model + ckpt_path = args.ckpt + state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)['ema'] + model.load_state_dict(state_dict) + print('loading succeed') + + model.eval() + pretrained_model_path = args.pretrained_model_path + diffusion = create_diffusion(str(args.num_sampling_steps)) + vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device) + text_encoder = TextEmbedder(pretrained_model_path).to(device) + if args.use_fp16: + print('Warnning: using half percision for inferencing!') + vae.to(dtype=torch.float16) + model.to(dtype=torch.float16) + text_encoder.to(dtype=torch.float16) + + # prompt: + prompt = args.text_prompt + if prompt is None or prompt == []: + prompt = args.input_path.split('/')[-1].split('.')[0].replace('_', ' ') + else: + prompt = prompt[0] + prompt_base = prompt.replace(' ', '_') + prompt = prompt + args.additional_prompt + + if save_video: + if not os.path.exists(os.path.join(args.save_path)): + os.makedirs(os.path.join(args.save_path)) + + video_input, researve_frames = get_input(args) # f,c,h,w + video_input = video_input.to(device).unsqueeze(0) # b,f,c,h,w + + mask = mask_generation_before(args.mask_type, video_input.shape, video_input.dtype, device) # b,f,c,h,w + masked_video = video_input * (mask == 0) + + video_clip = auto_inpainting(args, video_input, masked_video, mask, prompt, vae, text_encoder, diffusion, model, + device, ) + video_ = ((video_clip * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, + 1) + + if save_video: + save_video_path = os.path.join(args.save_path, prompt_base + '.mp4') + torchvision.io.write_video(save_video_path, video_, fps=8) + print(f'save in {save_video_path}') + + return video_.permute(0, 3, 1, 2) diff --git a/src/videogen_hub/pipelines/seine/__init__.py b/src/videogen_hub/pipelines/seine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..11956a9f66fcb9dab1a3f093a693722783e36ebe --- /dev/null +++ b/src/videogen_hub/pipelines/seine/__init__.py @@ -0,0 +1,2 @@ +import sys +sys.path.insert(0, './src/videogen_hub/pipelines/seine/') diff --git a/src/videogen_hub/pipelines/seine/datasets_seine/__init__.py b/src/videogen_hub/pipelines/seine/datasets_seine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/seine/datasets_seine/video_transforms.py b/src/videogen_hub/pipelines/seine/datasets_seine/video_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..3beee9f92ed281ff5342a9951f1e746065ac3b7f --- /dev/null +++ b/src/videogen_hub/pipelines/seine/datasets_seine/video_transforms.py @@ -0,0 +1,381 @@ +import torch +import numbers +import numpy as np +from PIL import Image + +def _is_tensor_video_clip(clip): + if not torch.is_tensor(clip): + raise TypeError("clip should be Tensor. Got %s" % type(clip)) + + if not clip.ndimension() == 4: + raise ValueError("clip should be 4D. Got %dD" % clip.dim()) + + return True + + +def center_crop_arr(pil_image, image_size): + """ + Center cropping implementation from ADM. + https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 + """ + while min(*pil_image.size) >= 2 * image_size: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + scale = image_size / min(*pil_image.size) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + + arr = np.array(pil_image) + crop_y = (arr.shape[0] - image_size) // 2 + crop_x = (arr.shape[1] - image_size) // 2 + return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) + + +def crop(clip, i, j, h, w): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + """ + if len(clip.size()) != 4: + raise ValueError("clip should be a 4D tensor") + return clip[..., i : i + h, j : j + w] + + +def resize(clip, target_size, interpolation_mode): + if len(target_size) != 2: + raise ValueError(f"target size should be tuple (height, width), instead got {target_size}") + return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False) + +def resize_scale(clip, target_size, interpolation_mode): + if len(target_size) != 2: + raise ValueError(f"target size should be tuple (height, width), instead got {target_size}") + H, W = clip.size(-2), clip.size(-1) + scale_ = target_size[0] / min(H, W) + return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False) + +def resize_with_scale_factor(clip, scale_factor, interpolation_mode): + return torch.nn.functional.interpolate(clip, scale_factor=scale_factor, mode=interpolation_mode, align_corners=False) + +def resize_scale_with_height(clip, target_size, interpolation_mode): + H, W = clip.size(-2), clip.size(-1) + scale_ = target_size / H + return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False) + +def resize_scale_with_weight(clip, target_size, interpolation_mode): + H, W = clip.size(-2), clip.size(-1) + scale_ = target_size / W + return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False) + + +def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"): + """ + Do spatial cropping and resizing to the video clip + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + i (int): i in (i,j) i.e coordinates of the upper left corner. + j (int): j in (i,j) i.e coordinates of the upper left corner. + h (int): Height of the cropped region. + w (int): Width of the cropped region. + size (tuple(int, int)): height and width of resized clip + Returns: + clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W) + """ + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + clip = crop(clip, i, j, h, w) + clip = resize(clip, size, interpolation_mode) + return clip + + +def center_crop(clip, crop_size): + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + h, w = clip.size(-2), clip.size(-1) + # print(clip.shape) + th, tw = crop_size + if h < th or w < tw: + # print(h, w) + raise ValueError("height {} and width {} must be no smaller than crop_size".format(h, w)) + + i = int(round((h - th) / 2.0)) + j = int(round((w - tw) / 2.0)) + return crop(clip, i, j, th, tw) + + +def center_crop_using_short_edge(clip): + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + h, w = clip.size(-2), clip.size(-1) + if h < w: + th, tw = h, h + i = 0 + j = int(round((w - tw) / 2.0)) + else: + th, tw = w, w + i = int(round((h - th) / 2.0)) + j = 0 + return crop(clip, i, j, th, tw) + + +def random_shift_crop(clip): + ''' + Slide along the long edge, with the short edge as crop size + ''' + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + h, w = clip.size(-2), clip.size(-1) + + if h <= w: + long_edge = w + short_edge = h + else: + long_edge = h + short_edge =w + + th, tw = short_edge, short_edge + + i = torch.randint(0, h - th + 1, size=(1,)).item() + j = torch.randint(0, w - tw + 1, size=(1,)).item() + return crop(clip, i, j, th, tw) + + +def to_tensor(clip): + """ + Convert tensor data type from uint8 to float, divide value by 255.0 and + permute the dimensions of clip tensor + Args: + clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) + Return: + clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) + """ + _is_tensor_video_clip(clip) + if not clip.dtype == torch.uint8: + raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype)) + # return clip.float().permute(3, 0, 1, 2) / 255.0 + return clip.float() / 255.0 + + +def normalize(clip, mean, std, inplace=False): + """ + Args: + clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W) + mean (tuple): pixel RGB mean. Size is (3) + std (tuple): pixel standard deviation. Size is (3) + Returns: + normalized clip (torch.tensor): Size is (T, C, H, W) + """ + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + if not inplace: + clip = clip.clone() + mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device) + # print(mean) + std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device) + clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) + return clip + + +def hflip(clip): + """ + Args: + clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W) + Returns: + flipped clip (torch.tensor): Size is (T, C, H, W) + """ + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + return clip.flip(-1) + + +class RandomCropVideo: + def __init__(self, size): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + Returns: + torch.tensor: randomly cropped video clip. + size is (T, C, OH, OW) + """ + i, j, h, w = self.get_params(clip) + return crop(clip, i, j, h, w) + + def get_params(self, clip): + h, w = clip.shape[-2:] + th, tw = self.size + + if h < th or w < tw: + raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}") + + if w == tw and h == th: + return 0, 0, h, w + + i = torch.randint(0, h - th + 1, size=(1,)).item() + j = torch.randint(0, w - tw + 1, size=(1,)).item() + + return i, j, th, tw + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size})" + +class CenterCropResizeVideo: + ''' + First use the short side for cropping length, + center crop video, then resize to the specified size + ''' + def __init__( + self, + size, + interpolation_mode="bilinear", + ): + if isinstance(size, tuple): + if len(size) != 2: + raise ValueError(f"size should be tuple (height, width), instead got {size}") + self.size = size + else: + self.size = (size, size) + + self.interpolation_mode = interpolation_mode + + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + Returns: + torch.tensor: scale resized / center cropped video clip. + size is (T, C, crop_size, crop_size) + """ + # print(clip.shape) + clip_center_crop = center_crop_using_short_edge(clip) + # print(clip_center_crop.shape) 320 512 + clip_center_crop_resize = resize(clip_center_crop, target_size=self.size, interpolation_mode=self.interpolation_mode) + return clip_center_crop_resize + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" + + +class CenterCropVideo: + def __init__( + self, + size, + interpolation_mode="bilinear", + ): + if isinstance(size, tuple): + if len(size) != 2: + raise ValueError(f"size should be tuple (height, width), instead got {size}") + self.size = size + else: + self.size = (size, size) + + self.interpolation_mode = interpolation_mode + + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + Returns: + torch.tensor: center cropped video clip. + size is (T, C, crop_size, crop_size) + """ + clip_center_crop = center_crop(clip, self.size) + return clip_center_crop + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" + + +class NormalizeVideo: + """ + Normalize the video clip by mean subtraction and division by standard deviation + Args: + mean (3-tuple): pixel RGB mean + std (3-tuple): pixel RGB standard deviation + inplace (boolean): whether do in-place normalization + """ + + def __init__(self, mean, std, inplace=False): + self.mean = mean + self.std = std + self.inplace = inplace + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): video clip must be normalized. Size is (C, T, H, W) + """ + return normalize(clip, self.mean, self.std, self.inplace) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})" + + +class ToTensorVideo: + """ + Convert tensor data type from uint8 to float, divide value by 255.0 and + permute the dimensions of clip tensor + """ + + def __init__(self): + pass + + def __call__(self, clip): + """ + Args: + clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) + Return: + clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) + """ + return to_tensor(clip) + + def __repr__(self) -> str: + return self.__class__.__name__ + + +class ResizeVideo(): + ''' + First use the short side for cropping length, + center crop video, then resize to the specified size + ''' + def __init__( + self, + size, + interpolation_mode="bilinear", + ): + if isinstance(size, tuple): + if len(size) != 2: + raise ValueError(f"size should be tuple (height, width), instead got {size}") + self.size = size + else: + self.size = (size, size) + + self.interpolation_mode = interpolation_mode + + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + Returns: + torch.tensor: scale resized / center cropped video clip. + size is (T, C, crop_size, crop_size) + """ + clip_resize = resize(clip, target_size=self.size, interpolation_mode=self.interpolation_mode) + return clip_resize + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" + +# ------------------------------------------------------------ +# --------------------- Sampling --------------------------- +# ------------------------------------------------------------ diff --git a/src/videogen_hub/pipelines/seine/diffusion/__init__.py b/src/videogen_hub/pipelines/seine/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c9dbf6cf0bd6b9d1a8f65e0a31e9a84cacc03189 --- /dev/null +++ b/src/videogen_hub/pipelines/seine/diffusion/__init__.py @@ -0,0 +1,47 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +from . import gaussian_diffusion as gd +from .respace import SpacedDiffusion, space_timesteps + + +def create_diffusion( + timestep_respacing, + noise_schedule="linear", + use_kl=False, + sigma_small=False, + predict_xstart=False, + # learn_sigma=True, + learn_sigma=False, # for unet + rescale_learned_sigmas=False, + diffusion_steps=1000 +): + betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) + if use_kl: + loss_type = gd.LossType.RESCALED_KL + elif rescale_learned_sigmas: + loss_type = gd.LossType.RESCALED_MSE + else: + loss_type = gd.LossType.MSE + if timestep_respacing is None or timestep_respacing == "": + timestep_respacing = [diffusion_steps] + return SpacedDiffusion( + use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), + betas=betas, + model_mean_type=( + gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X + ), + model_var_type=( + ( + gd.ModelVarType.FIXED_LARGE + if not sigma_small + else gd.ModelVarType.FIXED_SMALL + ) + if not learn_sigma + else gd.ModelVarType.LEARNED_RANGE + ), + loss_type=loss_type + # rescale_timesteps=rescale_timesteps, + ) diff --git a/src/videogen_hub/pipelines/seine/diffusion/diffusion_utils.py b/src/videogen_hub/pipelines/seine/diffusion/diffusion_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e493a6a3ecb91e553a53cc7eadee5cc0d1753060 --- /dev/null +++ b/src/videogen_hub/pipelines/seine/diffusion/diffusion_utils.py @@ -0,0 +1,88 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +import torch as th +import numpy as np + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, th.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for th.exp(). + logvar1, logvar2 = [ + x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + th.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * th.exp(-logvar2) + ) + + +def approx_standard_normal_cdf(x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. + """ + return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) + + +def continuous_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a continuous Gaussian distribution. + :param x: the targets + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + centered_x = x - means + inv_stdv = th.exp(-log_scales) + normalized_x = centered_x * inv_stdv + log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) + return log_probs + + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. + :param x: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + assert x.shape == means.shape == log_scales.shape + centered_x = x - means + inv_stdv = th.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = approx_standard_normal_cdf(min_in) + log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = th.where( + x < -0.999, + log_cdf_plus, + th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == x.shape + return log_probs diff --git a/src/videogen_hub/pipelines/seine/diffusion/gaussian_diffusion.py b/src/videogen_hub/pipelines/seine/diffusion/gaussian_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..fe1b571e088fa245a072d2bb4320eea3d240df02 --- /dev/null +++ b/src/videogen_hub/pipelines/seine/diffusion/gaussian_diffusion.py @@ -0,0 +1,931 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + + +import math + +import numpy as np +import torch as th +import enum + +from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +class ModelMeanType(enum.Enum): + """ + Which type of output the model predicts. + """ + + PREVIOUS_X = enum.auto() # the model predicts x_{t-1} + START_X = enum.auto() # the model predicts x_0 + EPSILON = enum.auto() # the model predicts epsilon + + +class ModelVarType(enum.Enum): + """ + What is used as the model's output variance. + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. + """ + + LEARNED = enum.auto() + FIXED_SMALL = enum.auto() + FIXED_LARGE = enum.auto() + LEARNED_RANGE = enum.auto() + + +class LossType(enum.Enum): + MSE = enum.auto() # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = ( + enum.auto() + ) # use raw MSE loss (with RESCALED_KL when learning variances) + KL = enum.auto() # use the variational lower-bound + RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + + +def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + warmup_time = int(num_diffusion_timesteps * warmup_frac) + betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64) + return betas + + +def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): + """ + This is the deprecated API for creating beta schedules. + See get_named_beta_schedule() for the new library of schedules. + """ + if beta_schedule == "quad": + betas = ( + np.linspace( + beta_start ** 0.5, + beta_end ** 0.5, + num_diffusion_timesteps, + dtype=np.float64, + ) + ** 2 + ) + elif beta_schedule == "linear": + betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "warmup10": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) + elif beta_schedule == "warmup50": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) + elif beta_schedule == "const": + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 + betas = 1.0 / np.linspace( + num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 + ) + else: + raise NotImplementedError(beta_schedule) + assert betas.shape == (num_diffusion_timesteps,) + return betas + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + """ + Get a pre-defined beta schedule for the given name. + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. + scale = 1000 / num_diffusion_timesteps + return get_beta_schedule( + "linear", + beta_start=scale * 0.0001, + beta_end=scale * 0.02, + # diffuser stable diffusion + # beta_start=scale * 0.00085, + # beta_end=scale * 0.012, + num_diffusion_timesteps=num_diffusion_timesteps, + ) + elif schedule_name == "squaredcos_cap_v2": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +class GaussianDiffusion: + """ + Utilities for training and sampling diffusion models. + Original ported from this codebase: + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + """ + + def __init__( + self, + *, + betas, + model_mean_type, + model_var_type, + loss_type + ): + + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + self.loss_type = loss_type + + # Use float64 for accuracy. + betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = ( + betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = np.log( + np.append(self.posterior_variance[1], self.posterior_variance[1:]) + ) if len(self.posterior_variance) > 1 else np.array([]) + + self.posterior_mean_coef1 = ( + betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + self.posterior_mean_coef2 = ( + (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) + ) + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + In other words, sample from q(x_t | x_0). + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. + """ + if noise is None: + noise = th.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None, + mask=None, x_start=None, use_concat=False): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. + """ + if model_kwargs is None: + model_kwargs = {} + + B, F, C = x.shape[:3] + assert t.shape == (B,) + if use_concat: + model_output = model(th.concat([x, mask, x_start], dim=1), t, **model_kwargs) + else: + model_output = model(x, t, **model_kwargs) + try: + model_output = model_output.sample # for tav unet + except: + pass + # model_output = model(x, t, **model_kwargs) + if isinstance(model_output, tuple): + model_output, extra = model_output + else: + extra = None + + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + assert model_output.shape == (B, F, C * 2, *x.shape[3:]) + model_output, model_var_values = th.split(model_output, C, dim=2) + min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) + max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = th.exp(model_log_variance) + else: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. + ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart( + self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) + ) + model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) + + assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + "extra": extra, + } + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + ) + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). + """ + gradient = cond_fn(x, t, **model_kwargs) + new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() + return new_mean + + def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + See condition_mean() for details on cond_fn. + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). + """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t) + return out + + def p_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + mask=None, + x_start=None, + use_concat=False + ): + """ + Sample x_{t-1} from the model at the given timestep. + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + mask=mask, + x_start=x_start, + use_concat=use_concat + ) + noise = th.randn_like(x) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + if cond_fn is not None: + out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) + sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + mask=None, + x_start=None, + use_concat=False, + ): + """ + Generate samples from the model. + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). + :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. + :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :return: a non-differentiable batch of samples. + """ + final = None + for sample in self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + mask=mask, + x_start=x_start, + use_concat=use_concat + ): + final = sample + return final["sample"] + + def p_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + mask=None, + x_start=None, + use_concat=False + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + Arguments are the same as p_sample_loop(). + Returns a generator over dicts, where each dict is the return value of + p_sample(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.p_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + mask=mask, + x_start=x_start, + use_concat=use_concat + ) + yield out + img = out["sample"] + + def ddim_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + mask=None, + x_start=None, + use_concat=False + ): + """ + Sample x_{t-1} from the model using DDIM. + Same usage as p_sample(). + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + mask=mask, + x_start=x_start, + use_concat=use_concat + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + sigma = ( + eta + * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) + * th.sqrt(1 - alpha_bar / alpha_bar_prev) + ) + # Equation 12. + noise = th.randn_like(x) + mean_pred = ( + out["pred_xstart"] * th.sqrt(alpha_bar_prev) + + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps + ) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def ddim_reverse_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t+1} from the model using DDIM reverse ODE. + """ + assert eta == 0.0, "Reverse ODE only for deterministic path" + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x + - out["pred_xstart"] + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) + alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) + + # Equation 12. reversed + mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps + + return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} + + def ddim_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + mask=None, + x_start=None, + use_concat=False + ): + """ + Generate samples from the model using DDIM. + Same usage as p_sample_loop(). + """ + final = None + for sample in self.ddim_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, + mask=mask, + x_start=x_start, + use_concat=use_concat + ): + final = sample + return final["sample"] + + def ddim_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + mask=None, + x_start=None, + use_concat=False + ): + """ + Use DDIM to sample from the model and yield intermediate samples from + each timestep of DDIM. + Same usage as p_sample_loop_progressive(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.ddim_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + eta=eta, + mask=mask, + x_start=x_start, + use_concat=use_concat + ) + yield out + img = out["sample"] + + def _vb_terms_bpd( + self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None + ): + """ + Get a term for the variational lower-bound. + The resulting units are bits (rather than nats, as one might expect). + This allows for comparison to other papers. + :return: a dict with the following keys: + - 'output': a shape [N] tensor of NLLs or KLs. + - 'pred_xstart': the x_0 predictions. + """ + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + ) + out = self.p_mean_variance( + model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs + ) + kl = normal_kl( + true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] + ) + kl = mean_flat(kl) / np.log(2.0) + + decoder_nll = -discretized_gaussian_log_likelihood( + x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] + ) + assert decoder_nll.shape == x_start.shape + decoder_nll = mean_flat(decoder_nll) / np.log(2.0) + + # At the first timestep return the decoder NLL, + # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) + output = th.where((t == 0), decoder_nll, kl) + return {"output": output, "pred_xstart": out["pred_xstart"]} + + def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, use_mask=False): + """ + Compute training losses for a single timestep. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. + """ + if model_kwargs is None: + model_kwargs = {} + if noise is None: + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start, t, noise=noise) + if use_mask: + x_t = th.cat([x_t[:, :4], x_start[:, 4:]], dim=1) + terms = {} + + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + terms["loss"] = self._vb_terms_bpd( + model=model, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + model_kwargs=model_kwargs, + )["output"] + if self.loss_type == LossType.RESCALED_KL: + terms["loss"] *= self.num_timesteps + elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + model_output = model(x_t, t, **model_kwargs) + try: + # model_output = model(x_t, t, **model_kwargs).sample + model_output = model_output.sample # for tav unet + except: + pass + # model_output = model(x_t, t, **model_kwargs) + + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, F, C = x_t.shape[:3] + assert model_output.shape == (B, F, C * 2, *x_t.shape[3:]) + model_output, model_var_values = th.split(model_output, C, dim=2) + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = th.cat([model_output.detach(), model_var_values], dim=2) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. + terms["vb"] *= self.num_timesteps / 1000.0 + + target = { + ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + )[0], + ModelMeanType.START_X: x_start, + ModelMeanType.EPSILON: noise, + }[self.model_mean_type] + # assert model_output.shape == target.shape == x_start.shape + if use_mask: + terms["mse"] = mean_flat((target[:,:4] - model_output) ** 2) + else: + terms["mse"] = mean_flat((target - model_output) ** 2) + if "vb" in terms: + terms["loss"] = terms["mse"] + terms["vb"] + else: + terms["loss"] = terms["mse"] + else: + raise NotImplementedError(self.loss_type) + + return terms + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl( + mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 + ) + return mean_flat(kl_prior) / np.log(2.0) + + def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): + """ + Compute the entire variational lower-bound, measured in bits-per-dim, + as well as other related quantities. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param clip_denoised: if True, clip denoised samples. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - total_bpd: the total variational lower-bound, per batch element. + - prior_bpd: the prior term in the lower-bound. + - vb: an [N x T] tensor of terms in the lower-bound. + - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. + - mse: an [N x T] tensor of epsilon MSEs for each timestep. + """ + device = x_start.device + batch_size = x_start.shape[0] + + vb = [] + xstart_mse = [] + mse = [] + for t in list(range(self.num_timesteps))[::-1]: + t_batch = th.tensor([t] * batch_size, device=device) + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) + # Calculate VLB term at the current timestep + with th.no_grad(): + out = self._vb_terms_bpd( + model, + x_start=x_start, + x_t=x_t, + t=t_batch, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + ) + vb.append(out["output"]) + xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) + eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) + mse.append(mean_flat((eps - noise) ** 2)) + + vb = th.stack(vb, dim=1) + xstart_mse = th.stack(xstart_mse, dim=1) + mse = th.stack(mse, dim=1) + + prior_bpd = self._prior_bpd(x_start) + total_bpd = vb.sum(dim=1) + prior_bpd + return { + "total_bpd": total_bpd, + "prior_bpd": prior_bpd, + "vb": vb, + "xstart_mse": xstart_mse, + "mse": mse, + } + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. + """ + res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res + th.zeros(broadcast_shape, device=timesteps.device) diff --git a/src/videogen_hub/pipelines/seine/diffusion/respace.py b/src/videogen_hub/pipelines/seine/diffusion/respace.py new file mode 100644 index 0000000000000000000000000000000000000000..31c0092380367db477f154a39bf172ff64712fc4 --- /dev/null +++ b/src/videogen_hub/pipelines/seine/diffusion/respace.py @@ -0,0 +1,130 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +import torch +import numpy as np +import torch as th + +from .gaussian_diffusion import GaussianDiffusion + + +def space_timesteps(num_timesteps, section_counts): + """ + Create a list of timesteps to use from an original diffusion process, + given the number of timesteps we want to take from equally-sized portions + of the original process. + For example, if there's 300 timesteps and the section counts are [10,15,20] + then the first 100 timesteps are strided to be 10 timesteps, the second 100 + are strided to be 15 timesteps, and the final 100 are strided to be 20. + If the stride is a string starting with "ddim", then the fixed striding + from the DDIM paper is used, and only one section is allowed. + :param num_timesteps: the number of diffusion steps in the original + process to divide up. + :param section_counts: either a list of numbers, or a string containing + comma-separated numbers, indicating the step count + per section. As a special case, use "ddimN" where N + is a number of steps to use the striding from the + DDIM paper. + :return: a set of diffusion steps from the original process to use. + """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError( + f"cannot create exactly {num_timesteps} steps with an integer stride" + ) + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError( + f"cannot divide section of {size} steps into {section_count}" + ) + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +class SpacedDiffusion(GaussianDiffusion): + """ + A diffusion process which can skip steps in a base diffusion process. + :param use_timesteps: a collection (sequence or set) of timesteps from the + original diffusion process to retain. + :param kwargs: the kwargs to create the base diffusion process. + """ + + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + + base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = np.array(new_betas) + super().__init__(**kwargs) + + def p_mean_variance( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + # @torch.compile + def training_losses( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().training_losses(self._wrap_model(model), *args, **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) + + def _wrap_model(self, model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel( + model, self.timestep_map, self.original_num_steps + ) + + def _scale_timesteps(self, t): + # Scaling is done by the wrapped model. + return t + + +class _WrappedModel: + def __init__(self, model, timestep_map, original_num_steps): + self.model = model + self.timestep_map = timestep_map + # self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def __call__(self, x, ts, **kwargs): + map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) + new_ts = map_tensor[ts] + # if self.rescale_timesteps: + # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) + return self.model(x, new_ts, **kwargs) diff --git a/src/videogen_hub/pipelines/seine/diffusion/timestep_sampler.py b/src/videogen_hub/pipelines/seine/diffusion/timestep_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..a3f369847677d8dbaaadb8297691b1be92cf189f --- /dev/null +++ b/src/videogen_hub/pipelines/seine/diffusion/timestep_sampler.py @@ -0,0 +1,150 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +from abc import ABC, abstractmethod + +import numpy as np +import torch as th +import torch.distributed as dist + + +def create_named_schedule_sampler(name, diffusion): + """ + Create a ScheduleSampler from a library of pre-defined samplers. + :param name: the name of the sampler. + :param diffusion: the diffusion object to sample for. + """ + if name == "uniform": + return UniformSampler(diffusion) + elif name == "loss-second-moment": + return LossSecondMomentResampler(diffusion) + else: + raise NotImplementedError(f"unknown schedule sampler: {name}") + + +class ScheduleSampler(ABC): + """ + A distribution over timesteps in the diffusion process, intended to reduce + variance of the objective. + By default, samplers perform unbiased importance sampling, in which the + objective's mean is unchanged. + However, subclasses may override sample() to change how the resampled + terms are reweighted, allowing for actual changes in the objective. + """ + + @abstractmethod + def weights(self): + """ + Get a numpy array of weights, one per diffusion step. + The weights needn't be normalized, but must be positive. + """ + + def sample(self, batch_size, device): + """ + Importance-sample timesteps for a batch. + :param batch_size: the number of timesteps. + :param device: the torch device to save to. + :return: a tuple (timesteps, weights): + - timesteps: a tensor of timestep indices. + - weights: a tensor of weights to scale the resulting losses. + """ + w = self.weights() + p = w / np.sum(w) + indices_np = np.random.choice(len(p), size=(batch_size,), p=p) + indices = th.from_numpy(indices_np).long().to(device) + weights_np = 1 / (len(p) * p[indices_np]) + weights = th.from_numpy(weights_np).float().to(device) + return indices, weights + + +class UniformSampler(ScheduleSampler): + def __init__(self, diffusion): + self.diffusion = diffusion + self._weights = np.ones([diffusion.num_timesteps]) + + def weights(self): + return self._weights + + +class LossAwareSampler(ScheduleSampler): + def update_with_local_losses(self, local_ts, local_losses): + """ + Update the reweighting using losses from a model. + Call this method from each rank with a batch of timesteps and the + corresponding losses for each of those timesteps. + This method will perform synchronization to make sure all of the ranks + maintain the exact same reweighting. + :param local_ts: an integer Tensor of timesteps. + :param local_losses: a 1D Tensor of losses. + """ + batch_sizes = [ + th.tensor([0], dtype=th.int32, device=local_ts.device) + for _ in range(dist.get_world_size()) + ] + dist.all_gather( + batch_sizes, + th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), + ) + + # Pad all_gather batches to be the maximum batch size. + batch_sizes = [x.item() for x in batch_sizes] + max_bs = max(batch_sizes) + + timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] + loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] + dist.all_gather(timestep_batches, local_ts) + dist.all_gather(loss_batches, local_losses) + timesteps = [ + x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] + ] + losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] + self.update_with_all_losses(timesteps, losses) + + @abstractmethod + def update_with_all_losses(self, ts, losses): + """ + Update the reweighting using losses from a model. + Sub-classes should override this method to update the reweighting + using losses from the model. + This method directly updates the reweighting without synchronizing + between workers. It is called by update_with_local_losses from all + ranks with identical arguments. Thus, it should have deterministic + behavior to maintain state across workers. + :param ts: a list of int timesteps. + :param losses: a list of float losses, one per timestep. + """ + + +class LossSecondMomentResampler(LossAwareSampler): + def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): + self.diffusion = diffusion + self.history_per_term = history_per_term + self.uniform_prob = uniform_prob + self._loss_history = np.zeros( + [diffusion.num_timesteps, history_per_term], dtype=np.float64 + ) + self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) + + def weights(self): + if not self._warmed_up(): + return np.ones([self.diffusion.num_timesteps], dtype=np.float64) + weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) + weights /= np.sum(weights) + weights *= 1 - self.uniform_prob + weights += self.uniform_prob / len(weights) + return weights + + def update_with_all_losses(self, ts, losses): + for t, loss in zip(ts, losses): + if self._loss_counts[t] == self.history_per_term: + # Shift out the oldest loss term. + self._loss_history[t, :-1] = self._loss_history[t, 1:] + self._loss_history[t, -1] = loss + else: + self._loss_history[t, self._loss_counts[t]] = loss + self._loss_counts[t] += 1 + + def _warmed_up(self): + return (self._loss_counts == self.history_per_term).all() diff --git a/src/videogen_hub/pipelines/seine/models/__init__.py b/src/videogen_hub/pipelines/seine/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e25fcc16f08c39021758964d5cd7358ff36f1c6a --- /dev/null +++ b/src/videogen_hub/pipelines/seine/models/__init__.py @@ -0,0 +1,33 @@ +import os +import sys +sys.path.append(os.path.split(sys.path[0])[0]) + +from .unet import UNet3DConditionModel +from torch.optim.lr_scheduler import LambdaLR + +def customized_lr_scheduler(optimizer, warmup_steps=5000): # 5000 from u-vit + from torch.optim.lr_scheduler import LambdaLR + def fn(step): + if warmup_steps > 0: + return min(step / warmup_steps, 1) + else: + return 1 + return LambdaLR(optimizer, fn) + + +def get_lr_scheduler(optimizer, name, **kwargs): + if name == 'warmup': + return customized_lr_scheduler(optimizer, **kwargs) + elif name == 'cosine': + from torch.optim.lr_scheduler import CosineAnnealingLR + return CosineAnnealingLR(optimizer, **kwargs) + else: + raise NotImplementedError(name) + +def get_models(args): + if 'UNet' in args.model: + pretrained_model_path = args.pretrained_model_path + return UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", use_concat=args.use_mask) + else: + raise '{} Model Not Supported!'.format(args.model) + \ No newline at end of file diff --git a/src/videogen_hub/pipelines/seine/models/attention.py b/src/videogen_hub/pipelines/seine/models/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..689f017d36b431f730257d3c94413a0eddc7287c --- /dev/null +++ b/src/videogen_hub/pipelines/seine/models/attention.py @@ -0,0 +1,968 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py +import os +import sys +sys.path.append(os.path.split(sys.path[0])[0]) + +from dataclasses import dataclass +from typing import Optional + +import math +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput +from diffusers.utils.import_utils import is_xformers_available +from diffusers.models.attention import FeedForward, AdaLayerNorm +from rotary_embedding_torch import RotaryEmbedding +from typing import Callable, Optional +from einops import rearrange, repeat + +try: + from diffusers.models.modeling_utils import ModelMixin +except: + from diffusers.modeling_utils import ModelMixin # 0.11.1 + + +@dataclass +class Transformer3DModelOutput(BaseOutput): + sample: torch.FloatTensor + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + +def exists(x): + return x is not None + + +class CrossAttention(nn.Module): + r""" + copy from diffuser 0.11.1 + A cross attention layer. + Parameters: + query_dim (`int`): The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + use_relative_position: bool = False, + ): + super().__init__() + # print('num head', heads) + inner_dim = dim_head * heads + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + + self.scale = dim_head**-0.5 + + self.heads = heads + self.dim_head = dim_head + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + self._slice_size = None + self._use_memory_efficient_attention_xformers = False + self.added_kv_proj_dim = added_kv_proj_dim + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) + else: + self.group_norm = None + + self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(inner_dim, query_dim)) + self.to_out.append(nn.Dropout(dropout)) + + # print(use_relative_position) + self.use_relative_position = use_relative_position + if self.use_relative_position: + self.rotary_emb = RotaryEmbedding(min(32, dim_head)) + # # print(dim_head) + # # print(heads) + # # adopt https://github.com/huggingface/transformers/blob/8a817e1ecac6a420b1bdc701fcc33535a3b96ff5/src/transformers/models/bert/modeling_bert.py#L265 + # self.max_position_embeddings = 32 + # self.distance_embedding = nn.Embedding(2 * self.max_position_embeddings - 1, dim_head) + + # self.dropout = nn.Dropout(dropout) + + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def reshape_for_scores(self, tensor): + # split heads and dims + # tensor should be [b (h w)] f (d nd) + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).contiguous() + return tensor + + def same_batch_dim_to_heads(self, tensor): + batch_size, head_size, seq_len, dim = tensor.shape # [b (h w)] nd f d + tensor = tensor.reshape(batch_size, seq_len, dim * head_size) + return tensor + + def set_attention_slice(self, slice_size): + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + self._slice_size = slice_size + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, use_image_num=None): + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) # [b (h w)] f (nd * d) + + # print('before reshpape query shape', query.shape) + dim = query.shape[-1] + if not self.use_relative_position: + query = self.reshape_heads_to_batch_dim(query) # [b (h w) nd] f d + # print('after reshape query shape', query.shape) + + if self.added_kv_proj_dim is not None: + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) + + key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) + else: + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + if not self.use_relative_position: + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + + def _attention(self, query, key, value, attention_mask=None): + if self.upcast_attention: + query = query.float() + key = key.float() + + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + # print('query shape', query.shape) + # print('key shape', key.shape) + # print('value shape', value.shape) + + if attention_mask is not None: + # print('attention_mask', attention_mask.shape) + # print('attention_scores', attention_scores.shape) + # exit() + attention_scores = attention_scores + attention_mask + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + # print(attention_probs.shape) + + # cast back to the original dtype + attention_probs = attention_probs.to(value.dtype) + # print(attention_probs.shape) + + # compute attention output + hidden_states = torch.bmm(attention_probs, value) + # print(hidden_states.shape) + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + # print(hidden_states.shape) + # exit() + return hidden_states + + def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask): + batch_size_attention = query.shape[0] + hidden_states = torch.zeros( + (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype + ) + slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] + for i in range(hidden_states.shape[0] // slice_size): + start_idx = i * slice_size + end_idx = (i + 1) * slice_size + + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + + if self.upcast_attention: + query_slice = query_slice.float() + key_slice = key_slice.float() + + attn_slice = torch.baddbmm( + torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device), + query_slice, + key_slice.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + if attention_mask is not None: + attn_slice = attn_slice + attention_mask[start_idx:end_idx] + + if self.upcast_softmax: + attn_slice = attn_slice.float() + + attn_slice = attn_slice.softmax(dim=-1) + + # cast back to the original dtype + attn_slice = attn_slice.to(value.dtype) + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def _memory_efficient_attention_xformers(self, query, key, value, attention_mask): + # TODO attention_mask + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + +class Transformer3DModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + use_first_frame: bool = False, + use_relative_position: bool = False, + rotary_emb: bool = None, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # Define input layers + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + + # Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + if use_linear_projection: + self.proj_out = nn.Linear(in_channels, inner_dim) + else: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, use_image_num=None, return_dict: bool = True): + # Input + assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." + if self.training: + video_length = hidden_states.shape[2] - use_image_num + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous() + encoder_hidden_states_length = encoder_hidden_states.shape[1] + encoder_hidden_states_video = encoder_hidden_states[:, :encoder_hidden_states_length - use_image_num, ...] + encoder_hidden_states_video = repeat(encoder_hidden_states_video, 'b m n c -> b (m f) n c', f=video_length).contiguous() + encoder_hidden_states_image = encoder_hidden_states[:, encoder_hidden_states_length - use_image_num:, ...] + encoder_hidden_states = torch.cat([encoder_hidden_states_video, encoder_hidden_states_image], dim=1) + encoder_hidden_states = rearrange(encoder_hidden_states, 'b m n c -> (b m) n c').contiguous() + else: + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous() + encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length).contiguous() + + batch, channel, height, weight = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = self.proj_in(hidden_states) + + # Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + video_length=video_length, + use_image_num=use_image_num, + ) + + # Output + if not self.use_linear_projection: + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + + output = hidden_states + residual + + output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length + use_image_num).contiguous() + if not return_dict: + return (output,) + + return Transformer3DModelOutput(sample=output) + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + use_first_frame: bool = False, + use_relative_position: bool = False, + rotary_emb: bool = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + # print(only_cross_attention) + self.use_ada_layer_norm = num_embeds_ada_norm is not None + # print(self.use_ada_layer_norm) + self.use_first_frame = use_first_frame + + # Spatial-Attn + self.attn1 = CrossAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=None, + upcast_attention=upcast_attention, + ) + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + # # SC-Attn + # self.attn1 = SparseCausalAttention( + # query_dim=dim, + # heads=num_attention_heads, + # dim_head=attention_head_dim, + # dropout=dropout, + # bias=attention_bias, + # cross_attention_dim=cross_attention_dim if only_cross_attention else None, + # upcast_attention=upcast_attention, + # ) + # self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + # Text Cross-Attn + if cross_attention_dim is not None: + self.attn2 = CrossAttention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + else: + self.attn2 = None + + if cross_attention_dim is not None: + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + else: + self.norm2 = None + + # # Temp Frame-Cross-Attn; add tahn scale factor + # self.attn_fcross = SparseCausalAttention( + # query_dim=dim, + # heads=num_attention_heads, + # dim_head=attention_head_dim, + # dropout=dropout, + # bias=attention_bias, + # cross_attention_dim=cross_attention_dim if only_cross_attention else None, + # upcast_attention=upcast_attention, + # ) + # self.norm_fcross = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + # nn.init.zeros_(self.attn_fcross.to_out[0].weight.data) + + # Temp + self.attn_temp = TemporalAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=None, + upcast_attention=upcast_attention, + rotary_emb=rotary_emb, + ) + self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + nn.init.zeros_(self.attn_temp.to_out[0].weight.data) + + + # Feed-forward + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) + self.norm3 = nn.LayerNorm(dim) + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, op=None): + + if not is_xformers_available(): + print("Here is how to install it") + raise ModuleNotFoundError( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers", + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" + " available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + if self.attn2 is not None: + self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + # self.attn_fcross._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None, use_image_num=None): + # SparseCausal-Attention + norm_hidden_states = ( + self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) + ) + + if self.only_cross_attention: + hidden_states = ( + self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states + ) + else: + hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, use_image_num=use_image_num) + hidden_states + + # # SparseCausal-Attention + # norm_hidden_states = ( + # self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) + # ) + + # if self.only_cross_attention: + # hidden_states = ( + # self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states + # ) + # else: + # hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states + + # # Temporal FrameCross Attention + # norm_hidden_states = ( + # self.norm_fcross(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_fcross(hidden_states) + # ) + # hidden_states = self.attn_fcross( + # norm_hidden_states, attention_mask=attention_mask, video_length=video_length, use_image_num=use_image_num) + hidden_states + + if self.attn2 is not None: + # Cross-Attention + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + hidden_states = ( + self.attn2( + norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + ) + + hidden_states + ) + + # Temporal Attention + if self.training: + d = hidden_states.shape[1] + hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length + use_image_num).contiguous() + hidden_states_video = hidden_states[:, :video_length, :] + hidden_states_image = hidden_states[:, video_length:, :] + # print(hidden_states_video.shape) + # print(hidden_states_image.shape) + # if self.training: + # # prepare attention mask; mask images in temporal attention + # attention_mask_shape = (video_length + use_image_num) // 8 + 1 + # video_image_length = video_length + use_image_num + # attention_mask = torch.zeros([8 * attention_mask_shape, 8 * attention_mask_shape], + # dtype=hidden_states.dtype, device=hidden_states.device)[:video_image_length, :video_image_length] + # attention_mask[:, video_length:] = -math.inf + norm_hidden_states_video = ( + self.norm_temp(hidden_states_video, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states_video) + ) + # print(norm_hidden_states.shape) + hidden_states_video = self.attn_temp(norm_hidden_states_video) + hidden_states_video + hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1) + hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d).contiguous() + else: + d = hidden_states.shape[1] + hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length + use_image_num).contiguous() + norm_hidden_states = ( + self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states) + ) + # print(norm_hidden_states.shape) + hidden_states = self.attn_temp(norm_hidden_states) + hidden_states + hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d).contiguous() + + # Feed-forward + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + + return hidden_states + + +class SparseCausalAttention(CrossAttention): + def forward_video(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) + dim = query.shape[-1] + query = self.reshape_heads_to_batch_dim(query) + + if self.added_kv_proj_dim is not None: + raise NotImplementedError + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + former_frame_index = torch.arange(video_length) - 1 + former_frame_index[0] = 0 + + key = rearrange(key, "(b f) d c -> b f d c", f=video_length).contiguous() + key = torch.cat([key[:, [0] * video_length], key[:, former_frame_index]], dim=2) + key = rearrange(key, "b f d c -> (b f) d c").contiguous() + + value = rearrange(value, "(b f) d c -> b f d c", f=video_length).contiguous() + value = torch.cat([value[:, [0] * video_length], value[:, former_frame_index]], dim=2) + value = rearrange(value, "b f d c -> (b f) d c").contiguous() + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + def forward_image(self, hidden_states, encoder_hidden_states=None, attention_mask=None, use_image_num=None): + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) # [b (h w)] f (nd * d) + # if self.use_relative_position: + # print('before attention query shape', query.shape) + dim = query.shape[-1] + if not self.use_relative_position: + query = self.reshape_heads_to_batch_dim(query) # [b (h w) nd] f d + # if self.use_relative_position: + # print('before attention query shape', query.shape) + + if self.added_kv_proj_dim is not None: + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) + + key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) + else: + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + if not self.use_relative_position: + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, use_image_num=None): + if self.training: + # print(use_image_num) + hidden_states = rearrange(hidden_states, "(b f) d c -> b f d c", f=video_length + use_image_num).contiguous() + hidden_states_video = hidden_states[:, :video_length, ...] + hidden_states_image = hidden_states[:, video_length:, ...] + hidden_states_video = rearrange(hidden_states_video, 'b f d c -> (b f) d c').contiguous() + hidden_states_image = rearrange(hidden_states_image, 'b f d c -> (b f) d c').contiguous() + hidden_states_video = self.forward_video(hidden_states=hidden_states_video, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + video_length=video_length) + # print('hidden_states_video', hidden_states_video.shape) + hidden_states_image = self.forward_image(hidden_states=hidden_states_image, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask) + # print('hidden_states_image', hidden_states_image.shape) + hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=0) + return hidden_states + # exit() + else: + return self.forward_video(hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + video_length=video_length) + +class TemporalAttention(CrossAttention): + def __init__(self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + rotary_emb=None): + super().__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, added_kv_proj_dim, norm_num_groups) + # relative time positional embeddings + self.time_rel_pos_bias = RelativePositionBias(heads=heads, max_distance=32) # realistically will not be able to generate that many frames of video... yet + self.rotary_emb = rotary_emb + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): + time_rel_pos_bias = self.time_rel_pos_bias(hidden_states.shape[1], device=hidden_states.device) + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) # [b (h w)] f (nd * d) + dim = query.shape[-1] + + if self.added_kv_proj_dim is not None: + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) + + key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) + else: + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask, time_rel_pos_bias) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + + def _attention(self, query, key, value, attention_mask=None, time_rel_pos_bias=None): + if self.upcast_attention: + query = query.float() + key = key.float() + + # print('query shape', query.shape) + # print('key shape', key.shape) + # print('value shape', value.shape) + # reshape for adding time positional bais + query = self.scale * rearrange(query, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads + key = rearrange(key, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads + value = rearrange(value, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads + # print('query shape', query.shape) + # print('key shape', key.shape) + # print('value shape', value.shape) + + # torch.baddbmm only accepte 3-D tensor + # https://runebook.dev/zh/docs/pytorch/generated/torch.baddbmm + # attention_scores = self.scale * torch.matmul(query, key.transpose(-1, -2)) + if exists(self.rotary_emb): + query = self.rotary_emb.rotate_queries_or_keys(query) + key = self.rotary_emb.rotate_queries_or_keys(key) + + attention_scores = torch.einsum('... h i d, ... h j d -> ... h i j', query, key) + # print('attention_scores shape', attention_scores.shape) + # print('time_rel_pos_bias shape', time_rel_pos_bias.shape) + # print('attention_mask shape', attention_mask.shape) + + attention_scores = attention_scores + time_rel_pos_bias + # print(attention_scores.shape) + + # bert from huggin face + # attention_scores = attention_scores / math.sqrt(self.dim_head) + + # # Normalize the attention scores to probabilities. + # attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + if attention_mask is not None: + # add attention mask + attention_scores = attention_scores + attention_mask + + # vdm + attention_scores = attention_scores - attention_scores.amax(dim = -1, keepdim = True).detach() + + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + # print(attention_probs[0][0]) + + # cast back to the original dtype + attention_probs = attention_probs.to(value.dtype) + + # compute attention output + # hidden_states = torch.matmul(attention_probs, value) + hidden_states = torch.einsum('... h i j, ... h j d -> ... h i d', attention_probs, value) + # print(hidden_states.shape) + # hidden_states = self.same_batch_dim_to_heads(hidden_states) + hidden_states = rearrange(hidden_states, 'b h f d -> b f (h d)') + # print(hidden_states.shape) + # exit() + return hidden_states + +class RelativePositionBias(nn.Module): + def __init__( + self, + heads=8, + num_buckets=32, + max_distance=128, + ): + super().__init__() + self.num_buckets = num_buckets + self.max_distance = max_distance + self.relative_attention_bias = nn.Embedding(num_buckets, heads) + + @staticmethod + def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128): + ret = 0 + n = -relative_position + + num_buckets //= 2 + ret += (n < 0).long() * num_buckets + n = torch.abs(n) + + max_exact = num_buckets // 2 + is_small = n < max_exact + + val_if_large = max_exact + ( + torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) + ).long() + val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + + ret += torch.where(is_small, n, val_if_large) + return ret + + def forward(self, n, device): + q_pos = torch.arange(n, dtype = torch.long, device = device) + k_pos = torch.arange(n, dtype = torch.long, device = device) + rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1') + rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance) + values = self.relative_attention_bias(rp_bucket) + return rearrange(values, 'i j h -> h i j') # num_heads, num_frames, num_frames \ No newline at end of file diff --git a/src/videogen_hub/pipelines/seine/models/clip.py b/src/videogen_hub/pipelines/seine/models/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..249ff2a1680d580d94d4a18c1db5f538a81c043d --- /dev/null +++ b/src/videogen_hub/pipelines/seine/models/clip.py @@ -0,0 +1,123 @@ +import numpy +import torch.nn as nn +from transformers import CLIPTokenizer, CLIPTextModel + +import transformers +transformers.logging.set_verbosity_error() + +""" +Will encounter following warning: +- This IS expected if you are initializing CLIPTextModel from the checkpoint of a model trained on another task +or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model). +- This IS NOT expected if you are initializing CLIPTextModel from the checkpoint of a model +that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model). + +https://github.com/CompVis/stable-diffusion/issues/97 +according to this issue, this warning is safe. + +This is expected since the vision backbone of the CLIP model is not needed to run Stable Diffusion. +You can safely ignore the warning, it is not an error. + +This clip usage is from U-ViT and same with Stable Diffusion. +""" + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class FrozenCLIPEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from Hugging Face)""" + # def __init__(self, version="openai/clip-vit-huge-patch14", device="cuda", max_length=77): + def __init__(self, path, device="cuda", max_length=77): + super().__init__() + self.tokenizer = CLIPTokenizer.from_pretrained(path, subfolder="tokenizer") + self.transformer = CLIPTextModel.from_pretrained(path, subfolder='text_encoder') + self.device = device + self.max_length = max_length + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens) + + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + + +class TextEmbedder(nn.Module): + """ + Embeds text prompt into vector representations. Also handles text dropout for classifier-free guidance. + """ + def __init__(self, path, dropout_prob=0.1): + super().__init__() + self.text_encodder = FrozenCLIPEmbedder(path=path) + self.dropout_prob = dropout_prob + + def token_drop(self, text_prompts, force_drop_ids=None): + """ + Drops text to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = numpy.random.uniform(0, 1, len(text_prompts)) < self.dropout_prob + else: + # TODO + drop_ids = force_drop_ids == 1 + labels = list(numpy.where(drop_ids, "", text_prompts)) + # print(labels) + return labels + + def forward(self, text_prompts, train, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + text_prompts = self.token_drop(text_prompts, force_drop_ids) + embeddings = self.text_encodder(text_prompts) + return embeddings + + +if __name__ == '__main__': + + r""" + Returns: + + Examples from CLIPTextModel: + + ```python + >>> from transformers import AutoTokenizer, CLIPTextModel + + >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32") + >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") + + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled (EOS token) states + ```""" + + import torch + + device = "cuda" if torch.cuda.is_available() else "cpu" + + text_encoder = TextEmbedder(path='/mnt/petrelfs/maxin/work/pretrained/stable-diffusion-2-1-base', + dropout_prob=0.00001).to(device) + + text_prompt = [["a photo of a cat", "a photo of a cat"], ["a photo of a dog", "a photo of a cat"], ['a photo of a dog human', "a photo of a cat"]] + # text_prompt = ('None', 'None', 'None') + output = text_encoder(text_prompts=text_prompt, train=False) + # print(output) + print(output.shape) + # print(output.shape) \ No newline at end of file diff --git a/src/videogen_hub/pipelines/seine/models/resnet.py b/src/videogen_hub/pipelines/seine/models/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..c1307ad9a0fe26b6b1eca913b024c188df2dd86e --- /dev/null +++ b/src/videogen_hub/pipelines/seine/models/resnet.py @@ -0,0 +1,212 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py +import os +import sys +sys.path.append(os.path.split(sys.path[0])[0]) + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange + + +class InflatedConv3d(nn.Conv2d): + def forward(self, x): + video_length = x.shape[2] + + x = rearrange(x, "b c f h w -> (b f) c h w") + x = super().forward(x) + x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) + + return x + + +class Upsample3D(nn.Module): + def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + conv = None + if use_conv_transpose: + raise NotImplementedError + elif use_conv: + conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) + + if name == "conv": + self.conv = conv + else: + self.Conv2d_0 = conv + + def forward(self, hidden_states, output_size=None): + assert hidden_states.shape[1] == self.channels + + if self.use_conv_transpose: + raise NotImplementedError + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.float32) + + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + # if `output_size` is passed we force the interpolation output + # size and do not make use of `scale_factor=2` + if output_size is None: + hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") + else: + hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(dtype) + + if self.use_conv: + if self.name == "conv": + hidden_states = self.conv(hidden_states) + else: + hidden_states = self.Conv2d_0(hidden_states) + + return hidden_states + + +class Downsample3D(nn.Module): + def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if use_conv: + conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) + else: + raise NotImplementedError + + if name == "conv": + self.Conv2d_0 = conv + self.conv = conv + elif name == "Conv2d_0": + self.conv = conv + else: + self.conv = conv + + def forward(self, hidden_states): + assert hidden_states.shape[1] == self.channels + if self.use_conv and self.padding == 0: + raise NotImplementedError + + assert hidden_states.shape[1] == self.channels + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class ResnetBlock3D(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + non_linearity="swish", + time_embedding_norm="default", + output_scale_factor=1.0, + use_in_shortcut=None, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.time_embedding_norm = time_embedding_norm + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + + self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if temb_channels is not None: + if self.time_embedding_norm == "default": + time_emb_proj_out_channels = out_channels + elif self.time_embedding_norm == "scale_shift": + time_emb_proj_out_channels = out_channels * 2 + else: + raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") + + self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) + else: + self.time_emb_proj = None + + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + + self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, input_tensor, temb): + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.conv1(hidden_states) + + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] + + if temb is not None and self.time_embedding_norm == "default": + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + + if temb is not None and self.time_embedding_norm == "scale_shift": + scale, shift = torch.chunk(temb, 2, dim=1) + hidden_states = hidden_states * (1 + scale) + shift + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + + +class Mish(torch.nn.Module): + def forward(self, hidden_states): + return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) \ No newline at end of file diff --git a/src/videogen_hub/pipelines/seine/models/unet.py b/src/videogen_hub/pipelines/seine/models/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..24f193b194bed049205134bc121d97dc03252541 --- /dev/null +++ b/src/videogen_hub/pipelines/seine/models/unet.py @@ -0,0 +1,691 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import os +import sys +sys.path.append(os.path.split(sys.path[0])[0]) + +import math +import json +import torch +import einops +import torch.nn as nn +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput, logging +from diffusers.models.embeddings import TimestepEmbedding, Timesteps + +try: + from diffusers.models.modeling_utils import ModelMixin +except: + from diffusers.modeling_utils import ModelMixin # 0.11.1 + +try: + from .unet_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, + ) + from .resnet import InflatedConv3d +except: + from unet_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, + ) + from resnet import InflatedConv3d + +from rotary_embedding_torch import RotaryEmbedding + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +class RelativePositionBias(nn.Module): + def __init__( + self, + heads=8, + num_buckets=32, + max_distance=128, + ): + super().__init__() + self.num_buckets = num_buckets + self.max_distance = max_distance + self.relative_attention_bias = nn.Embedding(num_buckets, heads) + + @staticmethod + def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128): + ret = 0 + n = -relative_position + + num_buckets //= 2 + ret += (n < 0).long() * num_buckets + n = torch.abs(n) + + max_exact = num_buckets // 2 + is_small = n < max_exact + + val_if_large = max_exact + ( + torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) + ).long() + val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + + ret += torch.where(is_small, n, val_if_large) + return ret + + def forward(self, n, device): + q_pos = torch.arange(n, dtype = torch.long, device = device) + k_pos = torch.arange(n, dtype = torch.long, device = device) + rel_pos = einops.rearrange(k_pos, 'j -> 1 j') - einops.rearrange(q_pos, 'i -> i 1') + rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance) + values = self.relative_attention_bias(rp_bucket) + return einops.rearrange(values, 'i j h -> h i j') # num_heads, num_frames, num_frames + +@dataclass +class UNet3DConditionOutput(BaseOutput): + sample: torch.FloatTensor + + +class UNet3DConditionModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, # 64 + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + mid_block_type: str = "UNetMidBlock3DCrossAttn", + up_block_types: Tuple[str] = ( + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D" + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: Union[int, Tuple[int]] = 8, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + use_first_frame: bool = False, + use_relative_position: bool = False, + ): + super().__init__() + + # print(use_first_frame) + + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + + # input + self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) + + # time + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + else: + self.class_embedding = None + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + # print(only_cross_attention) + # print(type(only_cross_attention)) + # exit() + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + # print(only_cross_attention) + # exit() + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + # print(attention_head_dim) + # exit() + + rotary_emb = RotaryEmbedding(32) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock3DCrossAttn": + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the videos + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + only_cross_attention = list(reversed(only_cross_attention)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) + self.conv_act = nn.SiLU() + self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + + # relative time positional embeddings + self.use_relative_position = use_relative_position + if self.use_relative_position: + self.time_rel_pos_bias = RelativePositionBias(heads=8, max_distance=32) # realistically will not be able to generate that many frames of video... yet + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + + num_slicable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_slicable_layers * [1] + + slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor = None, + class_labels: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + use_image_num: int = 0, + return_dict: bool = True, + ) -> Union[UNet3DConditionOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # time + timesteps = timestep + if not torch.is_tensor(timesteps): + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + # print(emb.shape) # torch.Size([3, 1280]) + # print(class_emb.shape) # torch.Size([3, 1280]) + emb = emb + class_emb + + if self.use_relative_position: + frame_rel_pos_bias = self.time_rel_pos_bias(sample.shape[2], device=sample.device) + else: + frame_rel_pos_bias = None + + # pre-process + sample = self.conv_in(sample) + + # down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + use_image_num=use_image_num, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # mid + sample = self.mid_block( + sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, use_image_num=use_image_num, + ) + + # up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + use_image_num=use_image_num, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + # print(sample.shape) + + if not return_dict: + return (sample,) + sample = UNet3DConditionOutput(sample=sample) + return sample + + def forward_with_cfg(self, + x, + t, + encoder_hidden_states = None, + class_labels: Optional[torch.Tensor] = None, + cfg_scale=4.0, + use_fp16=False): + """ + Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance. + """ + # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + half = x[: len(x) // 2] + combined = torch.cat([half, half], dim=0) + if use_fp16: + combined = combined.to(dtype=torch.float16) + model_out = self.forward(combined, t, encoder_hidden_states, class_labels).sample + # For exact reproducibility reasons, we apply classifier-free guidance on only + # three channels by default. The standard approach to cfg applies it to all channels. + # This can be done by uncommenting the following line and commenting-out the line following that. + eps, rest = model_out[:, :4], model_out[:, 4:] + # eps, rest = model_out[:, :3], model_out[:, 3:] # b c f h w + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + return torch.cat([eps, rest], dim=1) + + @classmethod + def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, use_concat=False): + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + + + # the content of the config file + # { + # "_class_name": "UNet2DConditionModel", + # "_diffusers_version": "0.2.2", + # "act_fn": "silu", + # "attention_head_dim": 8, + # "block_out_channels": [ + # 320, + # 640, + # 1280, + # 1280 + # ], + # "center_input_sample": false, + # "cross_attention_dim": 768, + # "down_block_types": [ + # "CrossAttnDownBlock2D", + # "CrossAttnDownBlock2D", + # "CrossAttnDownBlock2D", + # "DownBlock2D" + # ], + # "downsample_padding": 1, + # "flip_sin_to_cos": true, + # "freq_shift": 0, + # "in_channels": 4, + # "layers_per_block": 2, + # "mid_block_scale_factor": 1, + # "norm_eps": 1e-05, + # "norm_num_groups": 32, + # "out_channels": 4, + # "sample_size": 64, + # "up_block_types": [ + # "UpBlock2D", + # "CrossAttnUpBlock2D", + # "CrossAttnUpBlock2D", + # "CrossAttnUpBlock2D" + # ] + # } + config_file = os.path.join(pretrained_model_path, 'config.json') + if not os.path.isfile(config_file): + raise RuntimeError(f"{config_file} does not exist") + with open(config_file, "r") as f: + config = json.load(f) + config["_class_name"] = cls.__name__ + config["down_block_types"] = [ + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D" + ] + config["up_block_types"] = [ + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D" + ] + + # config["use_first_frame"] = True + + config["use_first_frame"] = False + if use_concat: + config["in_channels"] = 9 + # config["use_relative_position"] = True + + # # tmp + # config["class_embed_type"] = "timestep" + # config["num_class_embeds"] = 100 + + from diffusers.utils import WEIGHTS_NAME # diffusion_pytorch_model.bin + + # {'_class_name': 'UNet3DConditionModel', + # '_diffusers_version': '0.2.2', + # 'act_fn': 'silu', + # 'attention_head_dim': 8, + # 'block_out_channels': [320, 640, 1280, 1280], + # 'center_input_sample': False, + # 'cross_attention_dim': 768, + # 'down_block_types': + # ['CrossAttnDownBlock3D', + # 'CrossAttnDownBlock3D', + # 'CrossAttnDownBlock3D', + # 'DownBlock3D'], + # 'downsample_padding': 1, + # 'flip_sin_to_cos': True, + # 'freq_shift': 0, + # 'in_channels': 4, + # 'layers_per_block': 2, + # 'mid_block_scale_factor': 1, + # 'norm_eps': 1e-05, + # 'norm_num_groups': 32, + # 'out_channels': 4, + # 'sample_size': 64, + # 'up_block_types': + # ['UpBlock3D', + # 'CrossAttnUpBlock3D', + # 'CrossAttnUpBlock3D', + # 'CrossAttnUpBlock3D']} + + model = cls.from_config(config) + model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) + if not os.path.isfile(model_file): + raise RuntimeError(f"{model_file} does not exist") + state_dict = torch.load(model_file, map_location="cpu") + + if use_concat: + new_state_dict = {} + conv_in_weight = state_dict["conv_in.weight"] + new_conv_weight = torch.zeros((conv_in_weight.shape[0], 9, *conv_in_weight.shape[2:]), dtype=conv_in_weight.dtype) + + for i, j in zip([0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 6, 7, 8]): + new_conv_weight[:, j] = conv_in_weight[:, i] + new_state_dict["conv_in.weight"] = new_conv_weight + new_state_dict["conv_in.bias"] = state_dict["conv_in.bias"] + for k, v in model.state_dict().items(): + # print(k) + if '_temp.' in k: + new_state_dict.update({k: v}) + if 'attn_fcross' in k: # conpy parms of attn1 to attn_fcross + k = k.replace('attn_fcross', 'attn1') + state_dict.update({k: state_dict[k]}) + if 'norm_fcross' in k: + k = k.replace('norm_fcross', 'norm1') + state_dict.update({k: state_dict[k]}) + + if 'conv_in' in k: + continue + else: + new_state_dict[k] = v + # # tmp + # if 'class_embedding' in k: + # state_dict.update({k: v}) + # breakpoint() + model.load_state_dict(new_state_dict) + else: + for k, v in model.state_dict().items(): + # print(k) + if '_temp' in k: + state_dict.update({k: v}) + if 'attn_fcross' in k: # conpy parms of attn1 to attn_fcross + k = k.replace('attn_fcross', 'attn1') + state_dict.update({k: state_dict[k]}) + if 'norm_fcross' in k: + k = k.replace('norm_fcross', 'norm1') + state_dict.update({k: state_dict[k]}) + + model.load_state_dict(state_dict) + + return model diff --git a/src/videogen_hub/pipelines/seine/models/unet_blocks.py b/src/videogen_hub/pipelines/seine/models/unet_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..849c10539c7039840c93631c5201069119d3c306 --- /dev/null +++ b/src/videogen_hub/pipelines/seine/models/unet_blocks.py @@ -0,0 +1,648 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py +import os +import sys +sys.path.append(os.path.split(sys.path[0])[0]) + +import torch +from torch import nn + +try: + from .attention import Transformer3DModel + from .resnet import Downsample3D, ResnetBlock3D, Upsample3D +except: + from attention import Transformer3DModel + from resnet import Downsample3D, ResnetBlock3D, Upsample3D + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + use_first_frame=False, + use_relative_position=False, + rotary_emb=False, +): + # print(down_block_type) + # print(use_first_frame) + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlock3D": + return DownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "CrossAttnDownBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") + return CrossAttnDownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + use_first_frame=False, + use_relative_position=False, + rotary_emb=False, +): + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlock3D": + return UpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "CrossAttnUpBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") + return CrossAttnUpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock3DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + use_first_frame=False, + use_relative_position=False, + rotary_emb=False, + ): + super().__init__() + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(num_layers): + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + ) + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, use_image_num=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, use_image_num=use_image_num).sample + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class CrossAttnDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + use_first_frame=False, + use_relative_position=False, + rotary_emb=False, + ): + super().__init__() + resnets = [] + attentions = [] + + # print(use_first_frame) + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, use_image_num=None): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + def create_custom_forward_attn(module, return_dict=None, use_image_num=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict, use_image_num=use_image_num) + else: + return module(*inputs, use_image_num=use_image_num) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward_attn(attn, return_dict=False, use_image_num=use_image_num), + hidden_states, + encoder_hidden_states, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, use_image_num=use_image_num).sample + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + use_first_frame=False, + use_relative_position=False, + rotary_emb=False + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + attention_mask=None, + use_image_num=None, + ): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + def create_custom_forward_attn(module, return_dict=None, use_image_num=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict, use_image_num=use_image_num) + else: + return module(*inputs, use_image_num=use_image_num) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward_attn(attn, return_dict=False, use_image_num=use_image_num), + hidden_states, + encoder_hidden_states, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, use_image_num=use_image_num).sample + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states diff --git a/src/videogen_hub/pipelines/seine/models/utils.py b/src/videogen_hub/pipelines/seine/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..76e205d1709568885e58b56888f62c8805ad4f91 --- /dev/null +++ b/src/videogen_hub/pipelines/seine/models/utils.py @@ -0,0 +1,215 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import os +import math +import torch + +import numpy as np +import torch.nn as nn + +from einops import repeat + + +################################################################################# +# Unet Utils # +################################################################################# + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim).contiguous() + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +# class HybridConditioner(nn.Module): + +# def __init__(self, c_concat_config, c_crossattn_config): +# super().__init__() +# self.concat_conditioner = instantiate_from_config(c_concat_config) +# self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + +# def forward(self, c_concat, c_crossattn): +# c_concat = self.concat_conditioner(c_concat) +# c_crossattn = self.crossattn_conditioner(c_crossattn) +# return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += torch.DoubleTensor([matmul_ops]) + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params \ No newline at end of file diff --git a/src/videogen_hub/pipelines/seine/sample_i2v.yaml b/src/videogen_hub/pipelines/seine/sample_i2v.yaml new file mode 100644 index 0000000000000000000000000000000000000000..82cb065844ec94d360f2bd9ae8ad430ab5e17e5d --- /dev/null +++ b/src/videogen_hub/pipelines/seine/sample_i2v.yaml @@ -0,0 +1,30 @@ +# path config: +ckpt: "./checkpoints/SEINE/seine.pt" +pretrained_model_path: "./checkpoints/SEINE/stable-diffusion-v1-4/" +input_path: './src/videogen_hub/pipelines/seine/input_image.png' +save_path: "./src/videogen_hub/pipelines/seine/results/i2v/" + + +# model config: +model: UNet +num_frames: 16 +image_size: [240, 560] +#image_size: [320, 512] +# image_size: [512, 512] + +# model speedup +use_fp16: True +enable_xformers_memory_efficient_attention: True + +# sample config: +seed: +run_time: 13 +cfg_scale: 8.0 +sample_method: 'ddpm' +num_sampling_steps: 250 +text_prompt: [] +additional_prompt: ", slow motion." +negative_prompt: "" +do_classifier_free_guidance: True +mask_type: "first1" +use_mask: True diff --git a/src/videogen_hub/pipelines/seine/utils.py b/src/videogen_hub/pipelines/seine/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1e54d5a592a9f94894e98c011732b0d91163ea9b --- /dev/null +++ b/src/videogen_hub/pipelines/seine/utils.py @@ -0,0 +1,388 @@ +import os +import math +import torch +import logging +import subprocess +import numpy as np +import torch.distributed as dist + +# from torch._six import inf +from torch import inf +from PIL import Image +from typing import Union, Iterable +from collections import OrderedDict +from torch.utils.tensorboard import SummaryWriter + +_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]] + + +################################################################################# +# Training Helper Functions # +################################################################################# +def fetch_files_by_numbers(start_number, count, file_list): + file_numbers = range(start_number, start_number + count) + found_files = [] + for file_number in file_numbers: + file_number_padded = str(file_number).zfill(2) + for file_name in file_list: + if file_name.endswith(file_number_padded + '.csv'): + found_files.append(file_name) + break # Stop searching once a file is found for the current number + return found_files + + +################################################################################# +# Training Clip Gradients # +################################################################################# + +def get_grad_norm( + parameters: _tensor_or_tensors, norm_type: float = 2.0) -> torch.Tensor: + r""" + Copy from torch.nn.utils.clip_grad_norm_ + + Clips gradient norm of an iterable of parameters. + + The norm is computed over all gradients together, as if they were + concatenated into a single vector. Gradients are modified in-place. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + error_if_nonfinite (bool): if True, an error is thrown if the total + norm of the gradients from :attr:`parameters` is ``nan``, + ``inf``, or ``-inf``. Default: False (will switch to True in the future) + + Returns: + Total norm of the parameter gradients (viewed as a single vector). + """ + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + grads = [p.grad for p in parameters if p.grad is not None] + norm_type = float(norm_type) + if len(grads) == 0: + return torch.tensor(0.) + device = grads[0].device + if norm_type == inf: + norms = [g.detach().abs().max().to(device) for g in grads] + total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms)) + else: + total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type) + return total_norm + + +def clip_grad_norm_( + parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0, + error_if_nonfinite: bool = False, clip_grad=True) -> torch.Tensor: + r""" + Copy from torch.nn.utils.clip_grad_norm_ + + Clips gradient norm of an iterable of parameters. + + The norm is computed over all gradients together, as if they were + concatenated into a single vector. Gradients are modified in-place. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + error_if_nonfinite (bool): if True, an error is thrown if the total + norm of the gradients from :attr:`parameters` is ``nan``, + ``inf``, or ``-inf``. Default: False (will switch to True in the future) + + Returns: + Total norm of the parameter gradients (viewed as a single vector). + """ + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + grads = [p.grad for p in parameters if p.grad is not None] + max_norm = float(max_norm) + norm_type = float(norm_type) + if len(grads) == 0: + return torch.tensor(0.) + device = grads[0].device + if norm_type == inf: + norms = [g.detach().abs().max().to(device) for g in grads] + total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms)) + else: + total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type) + # print(total_norm) + + if clip_grad: + if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()): + raise RuntimeError( + f'The total norm of order {norm_type} for gradients from ' + '`parameters` is non-finite, so it cannot be clipped. To disable ' + 'this error and scale the gradients by the non-finite norm anyway, ' + 'set `error_if_nonfinite=False`') + clip_coef = max_norm / (total_norm + 1e-6) + # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so + # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization + # when the gradients do not reside in CPU memory. + clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + for g in grads: + g.detach().mul_(clip_coef_clamped.to(g.device)) + # gradient_cliped = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type) + # print(gradient_cliped) + return total_norm + + +def separation_content_motion(video_clip): + """ + separate coontent and motion in a given video + Args: + video_clip, a give video clip, [B F C H W] + + Return: + base frame, [B, 1, C, H, W] + motions, [B, F-1, C, H, W], + the first is base frame, + the second is motions based on base frame + """ + total_frames = video_clip.shape[1] + base_frame = video_clip[0] + motions = [video_clip[i] - base_frame for i in range(1, total_frames)] + motions = torch.cat(motions, dim=1) + return base_frame, motions + + +def get_experiment_dir(root_dir, args): + if args.use_compile: + root_dir += '-Compile' # speedup by torch compile + if args.fixed_spatial: + root_dir += '-FixedSpa' + if args.enable_xformers_memory_efficient_attention: + root_dir += '-Xfor' + if args.gradient_checkpointing: + root_dir += '-Gc' + if args.mixed_precision: + root_dir += '-Amp' + if args.image_size == 512: + root_dir += '-512' + return root_dir + + +################################################################################# +# Training Logger # +################################################################################# + +def create_logger(logging_dir): + """ + Create a logger that writes to a log file and stdout. + """ + if dist.get_rank() == 0: # real logger + logging.basicConfig( + level=logging.INFO, + # format='[\033[34m%(asctime)s\033[0m] %(message)s', + format='[%(asctime)s] %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")] + ) + logger = logging.getLogger(__name__) + + else: # dummy logger (does nothing) + logger = logging.getLogger(__name__) + logger.addHandler(logging.NullHandler()) + return logger + + +def create_accelerate_logger(logging_dir, is_main_process=False): + """ + Create a logger that writes to a log file and stdout. + """ + if is_main_process: # real logger + logging.basicConfig( + level=logging.INFO, + # format='[\033[34m%(asctime)s\033[0m] %(message)s', + format='[%(asctime)s] %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")] + ) + logger = logging.getLogger(__name__) + else: # dummy logger (does nothing) + logger = logging.getLogger(__name__) + logger.addHandler(logging.NullHandler()) + return logger + + +def create_tensorboard(tensorboard_dir): + """ + Create a tensorboard that saves losses. + """ + if dist.get_rank() == 0: # real tensorboard + # tensorboard + writer = SummaryWriter(tensorboard_dir) + + return writer + + +def write_tensorboard(writer, *args): + ''' + write the loss information to a tensorboard file. + Only for pytorch DDP mode. + ''' + if dist.get_rank() == 0: # real tensorboard + writer.add_scalar(args[0], args[1], args[2]) + + +################################################################################# +# EMA Update/ DDP Training Utils # +################################################################################# + +@torch.no_grad() +def update_ema(ema_model, model, decay=0.9999): + """ + Step the EMA model towards the current model. + """ + ema_params = OrderedDict(ema_model.named_parameters()) + model_params = OrderedDict(model.named_parameters()) + + for name, param in model_params.items(): + # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed + if param.requires_grad: + ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) + + +def requires_grad(model, flag=True): + """ + Set requires_grad flag for all parameters in a model. + """ + for p in model.parameters(): + p.requires_grad = flag + + +def cleanup(): + """ + End DDP training. + """ + dist.destroy_process_group() + + +def setup_distributed(backend="nccl", port=None): + """Initialize distributed training environment. + support both slurm and torch.distributed.launch + see torch.distributed.init_process_group() for more details + """ + num_gpus = torch.cuda.device_count() + + if "SLURM_JOB_ID" in os.environ: + rank = int(os.environ["SLURM_PROCID"]) + world_size = int(os.environ["SLURM_NTASKS"]) + node_list = os.environ["SLURM_NODELIST"] + addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1") + # specify master port + if port is not None: + os.environ["MASTER_PORT"] = str(port) + elif "MASTER_PORT" not in os.environ: + # os.environ["MASTER_PORT"] = "29566" + os.environ["MASTER_PORT"] = str(29566 + num_gpus) + if "MASTER_ADDR" not in os.environ: + os.environ["MASTER_ADDR"] = addr + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["LOCAL_RANK"] = str(rank % num_gpus) + os.environ["RANK"] = str(rank) + else: + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + # torch.cuda.set_device(rank % num_gpus) + + dist.init_process_group( + backend=backend, + world_size=world_size, + rank=rank, + ) + + +################################################################################# +# Testing Utils # +################################################################################# + +def save_video_grid(video, nrow=None): + b, t, h, w, c = video.shape + + if nrow is None: + nrow = math.ceil(math.sqrt(b)) + ncol = math.ceil(b / nrow) + padding = 1 + video_grid = torch.zeros((t, (padding + h) * nrow + padding, + (padding + w) * ncol + padding, c), dtype=torch.uint8) + + print(video_grid.shape) + for i in range(b): + r = i // ncol + c = i % ncol + start_r = (padding + h) * r + start_c = (padding + w) * c + video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i] + + return video_grid + + +def save_videos_grid_tav(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8): + from einops import rearrange + import imageio + import torchvision + + videos = rearrange(videos, "b c t h w -> t b c h w") + outputs = [] + for x in videos: + x = torchvision.utils.make_grid(x, nrow=n_rows) + x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) + if rescale: + x = (x + 1.0) / 2.0 # -1,1 -> 0,1 + x = (x * 255).numpy().astype(np.uint8) + outputs.append(x) + + # os.makedirs(os.path.dirname(path), exist_ok=True) + imageio.mimsave(path, outputs, fps=fps) + + +################################################################################# +# MMCV Utils # +################################################################################# + + +def collect_env(): + # Copyright (c) OpenMMLab. All rights reserved. + from mmcv.utils import collect_env as collect_base_env + from mmcv.utils import get_git_hash + """Collect the information of the running environments.""" + + env_info = collect_base_env() + env_info['MMClassification'] = get_git_hash()[:7] + + for name, val in env_info.items(): + print(f'{name}: {val}') + + print(torch.cuda.get_arch_list()) + print(torch.version.cuda) + + +################################################################################# +# Long video generation Utils # +################################################################################# + +def mask_generation_before(mask_type, shape, dtype, device, dropout_prob=0.0, use_image_num=0): + b, f, c, h, w = shape + if mask_type.startswith('first'): + num = int(mask_type.split('first')[-1]) + mask_f = torch.cat([torch.zeros(1, num, 1, 1, 1, dtype=dtype, device=device), + torch.ones(1, f - num, 1, 1, 1, dtype=dtype, device=device)], dim=1) + mask = mask_f.expand(b, -1, c, h, w) + elif mask_type.startswith('all'): + mask = torch.ones(b, f, c, h, w, dtype=dtype, device=device) + elif mask_type.startswith('onelast'): + num = int(mask_type.split('onelast')[-1]) + mask_one = torch.zeros(1, 1, 1, 1, 1, dtype=dtype, device=device) + mask_mid = torch.ones(1, f - 2 * num, 1, 1, 1, dtype=dtype, device=device) + mask_last = torch.zeros_like(mask_one) + mask = torch.cat([mask_one] * num + [mask_mid] + [mask_last] * num, dim=1) + mask = mask.expand(b, -1, c, h, w) + else: + raise ValueError(f"Invalid mask type: {mask_type}") + return mask diff --git a/src/videogen_hub/pipelines/seine/with_mask_sample.py b/src/videogen_hub/pipelines/seine/with_mask_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..b7acf9ee505ebc3ab723941ba5cfbd3edcab3759 --- /dev/null +++ b/src/videogen_hub/pipelines/seine/with_mask_sample.py @@ -0,0 +1,167 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Sample new images from a pre-trained DiT. +""" +import os +import sys + +try: + import utils + from diffusion import create_diffusion +except: + # sys.path.append(os.getcwd()) + sys.path.append(os.path.split(sys.path[0])[0]) + # sys.path[0] + # os.path.split(sys.path[0]) + import utils + + from diffusion import create_diffusion + +import torch + +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True + +from einops import rearrange +from PIL import Image +import numpy as np +from torchvision import transforms + +sys.path.append("..") +from datasets_seine import video_transforms +from natsort import natsorted + + +def get_input(args): + input_path = args.input_path + transform_video = transforms.Compose([ + video_transforms.ToTensorVideo(), # TCHW + video_transforms.ResizeVideo((args.image_h, args.image_w)), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) + ]) + if input_path is not None: + print(f'loading video from {input_path}') + if os.path.isdir(input_path): + file_list = os.listdir(input_path) + video_frames = [] + if args.mask_type.startswith('onelast'): + num = int(args.mask_type.split('onelast')[-1]) + # get first and last frame + first_frame_path = os.path.join(input_path, natsorted(file_list)[0]) + last_frame_path = os.path.join(input_path, natsorted(file_list)[-1]) + first_frame = torch.as_tensor( + np.array(Image.open(first_frame_path), dtype=np.uint8, copy=True)).unsqueeze(0) + last_frame = torch.as_tensor( + np.array(Image.open(last_frame_path), dtype=np.uint8, copy=True)).unsqueeze(0) + for i in range(num): + video_frames.append(first_frame) + # add zeros to frames + num_zeros = args.num_frames - 2 * num + for i in range(num_zeros): + zeros = torch.zeros_like(first_frame) + video_frames.append(zeros) + for i in range(num): + video_frames.append(last_frame) + n = 0 + video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w + video_frames = transform_video(video_frames) + else: + for file in file_list: + if file.endswith('jpg') or file.endswith('png'): + image = torch.as_tensor(np.array(Image.open(file), dtype=np.uint8, copy=True)).unsqueeze(0) + video_frames.append(image) + else: + continue + n = 0 + video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w + video_frames = transform_video(video_frames) + return video_frames, n + elif os.path.isfile(input_path): + _, full_file_name = os.path.split(input_path) + file_name, extension = os.path.splitext(full_file_name) + if extension == '.jpg' or extension == '.png': + print("loading the input image") + video_frames = [] + num = int(args.mask_type.split('first')[-1]) + first_frame = torch.as_tensor(np.array(Image.open(input_path), dtype=np.uint8, copy=True)).unsqueeze(0) + for i in range(num): + video_frames.append(first_frame) + num_zeros = args.num_frames - num + for i in range(num_zeros): + zeros = torch.zeros_like(first_frame) + video_frames.append(zeros) + n = 0 + video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w + video_frames = transform_video(video_frames) + return video_frames, n + else: + raise TypeError(f'{extension} is not supported !!') + else: + raise ValueError('Please check your path input!!') + else: + raise ValueError('Need to give a video or some images') + + +def auto_inpainting(args, video_input, masked_video, mask, prompt, vae, text_encoder, diffusion, model, device, ): + b, f, c, h, w = video_input.shape + latent_h = args.image_size[0] // 8 + latent_w = args.image_size[1] // 8 + + # prepare inputs + if args.use_fp16: + z = torch.randn(1, 4, args.num_frames, args.latent_h, args.latent_w, dtype=torch.float16, + device=device) # b,c,f,h,w + masked_video = masked_video.to(dtype=torch.float16) + mask = mask.to(dtype=torch.float16) + else: + z = torch.randn(1, 4, args.num_frames, args.latent_h, args.latent_w, device=device) # b,c,f,h,w + + masked_video = rearrange(masked_video, 'b f c h w -> (b f) c h w').contiguous() + masked_video = vae.encode(masked_video).latent_dist.sample().mul_(0.18215) + masked_video = rearrange(masked_video, '(b f) c h w -> b c f h w', b=b).contiguous() + mask = torch.nn.functional.interpolate(mask[:, :, 0, :], size=(latent_h, latent_w)).unsqueeze(1) + + # classifier_free_guidance + if args.do_classifier_free_guidance: + masked_video = torch.cat([masked_video] * 2) + mask = torch.cat([mask] * 2) + z = torch.cat([z] * 2) + prompt_all = [prompt] + [args.negative_prompt] + + else: + masked_video = masked_video + mask = mask + z = z + prompt_all = [prompt] + + text_prompt = text_encoder(text_prompts=prompt_all, train=False) + model_kwargs = dict(encoder_hidden_states=text_prompt, + class_labels=None, + cfg_scale=args.cfg_scale, + use_fp16=args.use_fp16, ) # tav unet + + # Sample video: + if args.sample_method == 'ddim': + samples = diffusion.ddim_sample_loop( + model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, + device=device, \ + mask=mask, x_start=masked_video, use_concat=args.use_mask + ) + elif args.sample_method == 'ddpm': + samples = diffusion.p_sample_loop( + model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, + device=device, \ + mask=mask, x_start=masked_video, use_concat=args.use_mask + ) + samples, _ = samples.chunk(2, dim=0) # [1, 4, 16, 32, 32] + if args.use_fp16: + samples = samples.to(dtype=torch.float16) + + video_clip = samples[0].permute(1, 0, 2, 3).contiguous() # [16, 4, 32, 32] + video_clip = vae.decode(video_clip / 0.18215).sample # [16, 3, 256, 256] + return video_clip diff --git a/src/videogen_hub/pipelines/show_1/__init__.py b/src/videogen_hub/pipelines/show_1/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dfa0d19aab9230ba40d34d32573738e768fc4796 --- /dev/null +++ b/src/videogen_hub/pipelines/show_1/__init__.py @@ -0,0 +1,2 @@ +import sys +sys.path.insert(0, './src/videogen_hub/pipelines/show_1/') \ No newline at end of file diff --git a/src/videogen_hub/pipelines/show_1/run_inference.py b/src/videogen_hub/pipelines/show_1/run_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..43f60f0a7ec86a83bceafb4c427fdb74512d81af --- /dev/null +++ b/src/videogen_hub/pipelines/show_1/run_inference.py @@ -0,0 +1,208 @@ +import os +import imageio +from PIL import Image +from typing import List + +import torch +import torch.nn.functional as F + +from diffusers import IFSuperResolutionPipeline, VideoToVideoSDPipeline +from diffusers.utils.torch_utils import randn_tensor + + +class ShowOnePipeline(): + def __init__(self, base_path, interp_path, deepfloyd_path, sr1_path, sr2_path): + """ + Downloading the necessary models from huggingface and utilize them to load their pipelines, + https://github.com/showlab/Show-1 + """ + from .showone.pipelines import TextToVideoIFPipeline, TextToVideoIFInterpPipeline, \ + TextToVideoIFSuperResolutionPipeline + from .showone.pipelines.pipeline_t2v_base_pixel import tensor2vid + from .showone.pipelines.pipeline_t2v_sr_pixel_cond import TextToVideoIFSuperResolutionPipeline_Cond + + self.tensor2vid = tensor2vid + # Base Model + # When using "showlab/show-1-base-0.0", it's advisable to increase the number of inference steps (e.g., 100) + # and opt for a larger guidance scale (e.g., 12.0) to enhance visual quality. + + self.pipe_base = TextToVideoIFPipeline.from_pretrained( + base_path, + torch_dtype=torch.float16, + variant="fp16" + ) + self.pipe_base.enable_model_cpu_offload() + + # Interpolation Model + + self.pipe_interp_1 = TextToVideoIFInterpPipeline.from_pretrained( + interp_path, + torch_dtype=torch.float16, + variant="fp16" + ) + self.pipe_interp_1.enable_model_cpu_offload() + + # Super-Resolution Model 1 + # Image super-resolution model from DeepFloyd https://huggingface.co/DeepFloyd/IF-II-L-v1.0 + # pretrained_model_path = "./checkpoints/DeepFloyd/IF-II-L-v1.0" + + self.pipe_sr_1_image = IFSuperResolutionPipeline.from_pretrained( + deepfloyd_path, + text_encoder=None, + torch_dtype=torch.float16, + variant="fp16" + ) + self.pipe_sr_1_image.enable_model_cpu_offload() + + self.pipe_sr_1_cond = TextToVideoIFSuperResolutionPipeline_Cond.from_pretrained( + sr1_path, + torch_dtype=torch.float16 + ) + self.pipe_sr_1_cond.enable_model_cpu_offload() + + # Super-Resolution Model 2 + + self.pipe_sr_2 = VideoToVideoSDPipeline.from_pretrained( + sr2_path, + torch_dtype=torch.float16 + ) + self.pipe_sr_2.enable_model_cpu_offload() + self.pipe_sr_2.enable_vae_slicing() + + def inference(self, prompt: str = "", + negative_prompt: str = "", + output_size: List[int] = [240, 560], + initial_num_frames: int = 8, + scaling_factor: int = 4, + seed: int = 42): + """ + Generates a single video based on a textual prompt. The output is a tensor representing the video. + The initial_num_frames is set to be 8 as shown in paper. + https://github.com/showlab/Show-1 + + Args: + prompt (str, optional): The text prompt that guides the video generation. If not specified, the video generation will rely solely on the input image. Defaults to "". + negative_prompt (str, optional): The negative prompt that guided the video generation. Defaults to "". + output_size (list, optional): Specifies the resolution of the output video as [height, width]. Defaults to [240, 560]. + initial_num_frames: the number of frames generated using the base model. Defaults to 8 as proposed in the paper. + scaling_factor: The amount of scaling during the interpolation step. Defaults to 4 as proposed in the paper, which interpolates number of frames from 8 to 29. + seed (int, optional): A seed value for random number generation, ensuring reproducibility of the video generation process. Defaults to 42. + + Returns: + The generated video as a tensor with shape (num_frames, channels, height, width). + """ + # Inference + # Text embeds + prompt_embeds, negative_embeds = self.pipe_base.encode_prompt(prompt) + + # Keyframes generation (8x64x40, 2fps) + video_frames = self.pipe_base( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_embeds, + num_frames=initial_num_frames, + height=40, + width=64, + num_inference_steps=75, + guidance_scale=9.0, + generator=torch.manual_seed(seed), + output_type="pt" + ).frames + + # Frame interpolation (8x64x40, 2fps -> 29x64x40, 7.5fps) + + bsz, channel, num_frames_1, height, width = video_frames.shape + + k = scaling_factor + + new_num_frames = (k - 1) * (num_frames_1 - 1) + num_frames_1 + new_video_frames = torch.zeros((bsz, channel, new_num_frames, height, width), + dtype=video_frames.dtype, device=video_frames.device) + new_video_frames[:, :, torch.arange(0, new_num_frames, k), ...] = video_frames + init_noise = randn_tensor((bsz, channel, k + 1, height, width), dtype=video_frames.dtype, + device=video_frames.device, generator=torch.manual_seed(seed)) + + for i in range(num_frames_1 - 1): + batch_i = torch.zeros((bsz, channel, k + 1, height, width), dtype=video_frames.dtype, + device=video_frames.device) + batch_i[:, :, 0, ...] = video_frames[:, :, i, ...] + batch_i[:, :, -1, ...] = video_frames[:, :, i + 1, ...] + batch_i = self.pipe_interp_1( + pixel_values=batch_i, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_embeds, + num_frames=batch_i.shape[2], + height=40, + width=64, + num_inference_steps=75, + guidance_scale=4.0, + generator=torch.manual_seed(seed), + output_type="pt", + init_noise=init_noise, + cond_interpolation=True, + ).frames + + new_video_frames[:, :, i * k:i * k + k + 1, ...] = batch_i + + video_frames = new_video_frames + + # Super-resolution 1 (29x64x40 -> 29x256x160) + bsz, channel, num_frames_2, height, width = video_frames.shape + window_size, stride = 8, 7 + new_video_frames = torch.zeros( + (bsz, channel, num_frames_2, height * 4, width * 4), + dtype=video_frames.dtype, + device=video_frames.device) + for i in range(0, num_frames_2 - window_size + 1, stride): + batch_i = video_frames[:, :, i:i + window_size, ...] + all_frame_cond = None + + if i == 0: + first_frame_cond = self.pipe_sr_1_image( + image=video_frames[:, :, 0, ...], + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_embeds, + height=height * 4, + width=width * 4, + num_inference_steps=70, + guidance_scale=4.0, + noise_level=150, + generator=torch.manual_seed(seed), + output_type="pt" + ).images + first_frame_cond = first_frame_cond.unsqueeze(2) + else: + first_frame_cond = new_video_frames[:, :, i:i + 1, ...] + + batch_i = self.pipe_sr_1_cond( + image=batch_i, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_embeds, + first_frame_cond=first_frame_cond, + height=height * 4, + width=width * 4, + num_inference_steps=125, + guidance_scale=7.0, + noise_level=250, + generator=torch.manual_seed(seed), + output_type="pt" + ).frames + new_video_frames[:, :, i:i + window_size, ...] = batch_i + + video_frames = new_video_frames + + # Super-resolution 2 (29x256x160 -> 29x576x320) + video_frames = [Image.fromarray(frame).resize((output_size[1], output_size[0])) for frame in + self.tensor2vid(video_frames.clone())] + video_frames = self.pipe_sr_2( + prompt, + negative_prompt=negative_prompt, + video=video_frames, + strength=0.8, + num_inference_steps=50, + generator=torch.manual_seed(seed), + output_type="pt" + ).frames + + output = video_frames.squeeze() + + return output diff --git a/src/videogen_hub/pipelines/show_1/showone/__init__.py b/src/videogen_hub/pipelines/show_1/showone/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/show_1/showone/models/__init__.py b/src/videogen_hub/pipelines/show_1/showone/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..520074e727890df83633081279b21807000deacb --- /dev/null +++ b/src/videogen_hub/pipelines/show_1/showone/models/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .unet_3d_condition import UNet3DConditionModel \ No newline at end of file diff --git a/src/videogen_hub/pipelines/show_1/showone/models/transformer_temporal.py b/src/videogen_hub/pipelines/show_1/showone/models/transformer_temporal.py new file mode 100644 index 0000000000000000000000000000000000000000..06736ff273bf79a4ed972e5f47b8755b0437607e --- /dev/null +++ b/src/videogen_hub/pipelines/show_1/showone/models/transformer_temporal.py @@ -0,0 +1,179 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Optional + +import torch +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput +from diffusers.models.attention import BasicTransformerBlock +from diffusers.models.modeling_utils import ModelMixin + + +@dataclass +class TransformerTemporalModelOutput(BaseOutput): + """ + The output of [`TransformerTemporalModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. + """ + + sample: torch.FloatTensor + + +class TransformerTemporalModel(ModelMixin, ConfigMixin): + """ + A Transformer model for video-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlock` attention should contain a bias parameter. + double_self_attention (`bool`, *optional*): + Configure if each `TransformerBlock` should contain two self-attention layers. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + activation_fn: str = "geglu", + norm_elementwise_affine: bool = True, + double_self_attention: bool = True, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + self.proj_in = nn.Linear(in_channels, inner_dim) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + double_self_attention=double_self_attention, + norm_elementwise_affine=norm_elementwise_affine, + ) + for d in range(num_layers) + ] + ) + + self.proj_out = nn.Linear(inner_dim, in_channels) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + class_labels=None, + num_frames=1, + cross_attention_kwargs=None, + return_dict: bool = True, + ): + """ + The [`TransformerTemporal`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input hidden_states. + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`: + If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is + returned, otherwise a `tuple` where the first element is the sample tensor. + """ + # 1. Input + batch_frames, channel, height, width = hidden_states.shape + batch_size = batch_frames // num_frames + + residual = hidden_states + + hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width) + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel) + + hidden_states = self.proj_in(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states[None, None, :] + .reshape(batch_size, height, width, channel, num_frames) + .permute(0, 3, 4, 1, 2) + .contiguous() + ) + hidden_states = hidden_states.reshape(batch_frames, channel, height, width) + + output = hidden_states + residual + + if not return_dict: + return (output,) + + return TransformerTemporalModelOutput(sample=output) diff --git a/src/videogen_hub/pipelines/show_1/showone/models/unet_3d_blocks.py b/src/videogen_hub/pipelines/show_1/showone/models/unet_3d_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..da81435411b1f89ae6010b9221406df1f0fdc07f --- /dev/null +++ b/src/videogen_hub/pipelines/show_1/showone/models/unet_3d_blocks.py @@ -0,0 +1,1619 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.utils import is_torch_version, logging +from diffusers.models.normalization import AdaGroupNorm +from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 +from diffusers.models.resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D +from diffusers.models.transformer_2d import Transformer2DModel +from diffusers.models.transformer_temporal import TransformerTemporalModel + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + transformer_layers_per_block=1, + num_attention_heads=None, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + resnet_skip_time_act=False, + resnet_out_scale_factor=1.0, + cross_attention_norm=None, + attention_head_dim=None, + downsample_type=None, +): + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warn( + f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}." + ) + attention_head_dim = num_attention_heads + + if down_block_type == "DownBlock3D": + return DownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "CrossAttnDownBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") + return CrossAttnDownBlock3D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "SimpleCrossAttnDownBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock3D") + return SimpleCrossAttnDownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif down_block_type == "ResnetDownsampleBlock3D": + return ResnetDownsampleBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + transformer_layers_per_block=1, + num_attention_heads=None, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + resnet_skip_time_act=False, + resnet_out_scale_factor=1.0, + cross_attention_norm=None, + attention_head_dim=None, + upsample_type=None, +): + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warn( + f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}." + ) + attention_head_dim = num_attention_heads + + if up_block_type == "UpBlock3D": + return UpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "CrossAttnUpBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") + return CrossAttnUpBlock3D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "SimpleCrossAttnUpBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock3D") + return SimpleCrossAttnUpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif up_block_type == "ResnetUpsampleBlock3D": + return ResnetUpsampleBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock3DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + ): + super().__init__() + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + temp_convs = [ + TemporalConvLayer( + in_channels, + in_channels, + dropout=0.1, + ) + ] + attentions = [] + temp_attentions = [] + + for _ in range(num_layers): + attentions.append( + Transformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=1, #todo: transformer_layers_per_block? + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + in_channels, + in_channels, + dropout=0.1, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + num_frames: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + hidden_states = self.resnets[0](hidden_states, temb) + hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames) + for attn, temp_attn, resnet, temp_conv in zip( + self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:] + ): + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = temp_attn( + hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs + ).sample + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + return hidden_states + + +class UNetMidBlock3DSimpleCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + skip_time_act=False, + only_cross_attention=False, + cross_attention_norm=None, + ): + super().__init__() + + self.has_cross_attention = True + + self.attention_head_dim = attention_head_dim + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + self.num_heads = in_channels // self.attention_head_dim + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ] + temp_convs = [ + TemporalConvLayer( + in_channels, + in_channels, + dropout=0.1, + ) + ] + attentions = [] + temp_attentions = [] + + for _ in range(num_layers): + processor = ( + AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() + ) + + attentions.append( + Attention( + query_dim=in_channels, + cross_attention_dim=in_channels, + heads=self.num_heads, + dim_head=self.attention_head_dim, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + processor=processor, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + self.attention_head_dim, + in_channels // self.attention_head_dim, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + temp_convs.append( + TemporalConvLayer( + in_channels, + in_channels, + dropout=0.1, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + num_frames: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. + # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + + hidden_states = self.resnets[0](hidden_states, temb) + hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames) + for attn, temp_attn, resnet, temp_conv in zip( + self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:] + ): + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + hidden_states = temp_attn( + hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs + ).sample + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + return hidden_states + + +class CrossAttnDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + temp_attentions = [] + temp_convs = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + ) + ) + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + num_frames: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + output_states = () + + for resnet, temp_conv, attn, temp_attn in zip( + self.resnets, self.temp_convs, self.attentions, self.temp_attentions + ): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, **ckpt_kwargs,) + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(temp_conv), hidden_states, num_frames, **ckpt_kwargs,) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + **ckpt_kwargs, + )[0] + hidden_states = temp_attn( + hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, **ckpt_kwargs, + ).sample + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = temp_attn( + hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs + ).sample + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class DownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + temp_convs = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None, num_frames=1): + output_states = () + + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(temp_conv), hidden_states, num_frames, use_reentrant=False) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class ResnetDownsampleBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + skip_time_act=False, + ): + super().__init__() + resnets = [] + temp_convs = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + down=True, + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None, num_frames=1): + output_states = () + + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(temp_conv), hidden_states, num_frames, use_reentrant=False) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, temb) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class SimpleCrossAttnDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_downsample=True, + skip_time_act=False, + only_cross_attention=False, + cross_attention_norm=None, + ): + super().__init__() + + self.has_cross_attention = True + + resnets = [] + attentions = [] + temp_attentions = [] + temp_convs = [] + + self.attention_head_dim = attention_head_dim + self.num_heads = out_channels // self.attention_head_dim + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + ) + ) + processor = ( + AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() + ) + + attentions.append( + Attention( + query_dim=out_channels, + cross_attention_dim=out_channels, + heads=self.num_heads, + dim_head=attention_head_dim, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + processor=processor, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + attention_head_dim, + out_channels // attention_head_dim, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + down=True, + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + num_frames: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + output_states = () + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. + # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + + for resnet, temp_conv, attn, temp_attn in zip( + self.resnets, self.temp_convs, self.attentions, self.temp_attentions + ): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(temp_conv), hidden_states, num_frames) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + mask, + cross_attention_kwargs, + )[0] + hidden_states = temp_attn( + hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs + ).sample + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + hidden_states = temp_attn( + hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs + ).sample + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, temb) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + ): + super().__init__() + resnets = [] + temp_convs = [] + attentions = [] + temp_attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + ) + ) + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + num_frames: int = 1, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + for resnet, temp_conv, attn, temp_attn in zip( + self.resnets, self.temp_convs, self.attentions, self.temp_attentions + ): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, **ckpt_kwargs,) + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(temp_conv), hidden_states, num_frames, **ckpt_kwargs,) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + **ckpt_kwargs, + )[0] + hidden_states = temp_attn( + hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs + ).sample + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = temp_attn( + hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs + ).sample + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + temp_convs = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1): + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(temp_conv), hidden_states, num_frames, use_reentrant=False) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class ResnetUpsampleBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + skip_time_act=False, + ): + super().__init__() + resnets = [] + temp_convs = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + up=True, + ) + ] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1): + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(temp_conv), hidden_states, num_frames, use_reentrant=False) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, temb) + + return hidden_states + + +class SimpleCrossAttnUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + skip_time_act=False, + only_cross_attention=False, + cross_attention_norm=None, + ): + super().__init__() + resnets = [] + temp_convs = [] + attentions = [] + temp_attentions = [] + + self.has_cross_attention = True + self.attention_head_dim = attention_head_dim + + self.num_heads = out_channels // self.attention_head_dim + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + ) + ) + + processor = ( + AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() + ) + + attentions.append( + Attention( + query_dim=out_channels, + cross_attention_dim=out_channels, + heads=self.num_heads, + dim_head=self.attention_head_dim, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + processor=processor, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + attention_head_dim, + out_channels // attention_head_dim, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + up=True, + ) + ] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + num_frames: int = 1, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. + # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + + for resnet, temp_conv, attn, temp_attn in zip( + self.resnets, self.temp_convs, self.attentions, self.temp_attentions + ): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(temp_conv), hidden_states, num_frames) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + mask, + cross_attention_kwargs, + )[0] + hidden_states = temp_attn( + hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs + ).sample + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + hidden_states = temp_attn( + hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs + ).sample + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, temb) + + return hidden_states \ No newline at end of file diff --git a/src/videogen_hub/pipelines/show_1/showone/models/unet_3d_condition.py b/src/videogen_hub/pipelines/show_1/showone/models/unet_3d_condition.py new file mode 100644 index 0000000000000000000000000000000000000000..6b12948afb0caa96a131f097379ffcf0b81d01f0 --- /dev/null +++ b/src/videogen_hub/pipelines/show_1/showone/models/unet_3d_condition.py @@ -0,0 +1,985 @@ +# Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved. +# Copyright 2023 The ModelScope Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import UNet2DConditionLoadersMixin +from diffusers.utils import BaseOutput, logging +from diffusers.models.activations import get_activation +from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor +from diffusers.models.embeddings import ( + GaussianFourierProjection, + ImageHintTimeEmbedding, + ImageProjection, + ImageTimeEmbedding, + TextImageProjection, + TextImageTimeEmbedding, + TextTimeEmbedding, + TimestepEmbedding, + Timesteps, +) +from diffusers.models.modeling_utils import ModelMixin +# from diffusers.models.transformer_temporal import TransformerTemporalModel +from .transformer_temporal import TransformerTemporalModel +from .unet_3d_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UNetMidBlock3DSimpleCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNet3DConditionOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`): + Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.FloatTensor + + +class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): + r""" + UNet3DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep + and returns sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the models (such as downloading or saving, etc.) + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): + The tuple of upsample blocks to use. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, it will skip the normalization and activation layers in post-processing + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + mid_block_type: Optional[str] = "UNetMidBlock3DCrossAttn", + up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: Union[int, Tuple[int]] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: Union[int, Tuple[int]] = 1280, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: int = 1.0, + time_embedding_type: str = "positional", + time_embedding_dim: Optional[int] = None, + time_embedding_act_fn: Optional[str] = None, + timestep_post_act: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: Optional[int] = None, + class_embeddings_concat: bool = False, + mid_block_only_cross_attention: Optional[bool] = None, + cross_attention_norm: Optional[str] = None, + addition_embed_type_num_heads=64, + transfromer_in_opt: bool =False, + ): + super().__init__() + + self.sample_size = sample_size + self.transformer_in_opt = transfromer_in_opt + + if num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + + # input + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + if self.transformer_in_opt: + self.transformer_in = TransformerTemporalModel( + num_attention_heads=8, + attention_head_dim=64, + in_channels=block_out_channels[0], + num_layers=1, + ) + + + # time + if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." + ) + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 + self.encoder_hid_proj = ImageProjection( + image_embed_dim=encoder_hid_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + ) + else: + self.encoder_hid_proj = None + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif addition_embed_type == "image": + # Kandinsky 2.2 + self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type == "image_hint": + # Kandinsky 2.2 ControlNet + self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + + if time_embedding_act_fn is None: + self.time_embed_act = None + else: + self.time_embed_act = get_activation(time_embedding_act_fn) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = only_cross_attention + + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = False + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock3DCrossAttn": + self.mid_block = UNetMidBlock3DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + elif mid_block_type == "UNetMidBlock3DSimpleCrossAttn": + self.mid_block = UNetMidBlock3DSimpleCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim[-1], + attention_head_dim=attention_head_dim[-1], + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + only_cross_attention=mid_block_only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif mid_block_type is None: + self.mid_block = None + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + + self.conv_act = get_activation(act_fn) + + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + # count = len(self.attn_processors.keys()) + # ignore temporal attention + count = len({k: v for k, v in self.attn_processors.items() if "temp_" not in k}.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor") and "temp_" not in name: + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(AttnProcessor()) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet3DConditionOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, num_frames, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet3DConditionOutput`] instead of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Returns: + [`~models.unet_2d_condition.UNet3DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + num_frames = sample.shape[2] + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. + class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config.addition_embed_type == "text_image": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) + aug_emb = self.add_embedding(text_embs, image_embs) + elif self.config.addition_embed_type == "text_time": + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == "image": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + aug_emb = self.add_embedding(image_embs) + elif self.config.addition_embed_type == "image_hint": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + hint = added_cond_kwargs.get("hint") + aug_emb, hint = self.add_embedding(image_embs, hint) + sample = torch.cat([sample, hint], dim=1) + + emb = emb + aug_emb if aug_emb is not None else emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": + # Kadinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(image_embeds) + + emb = emb.repeat_interleave(repeats=num_frames, dim=0) + encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0) + + # 2. pre-process + sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:]) + sample = self.conv_in(sample) + + if self.transformer_in_opt: + + sample = self.transformer_in( + sample, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames) + + down_block_res_samples += res_samples + + if down_block_additional_residuals is not None: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + + if mid_block_additional_residual is not None: + sample = sample + mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + num_frames=num_frames, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + num_frames=num_frames, + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + # reshape to (batch, channel, framerate, width, height) + sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4) + + if not return_dict: + return (sample,) + + return UNet3DConditionOutput(sample=sample) + + @classmethod + def from_pretrained_2d(cls, pretrained_model_path, subfolder=None): + import os, json + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + + config_file = os.path.join(pretrained_model_path, 'config.json') + if not os.path.isfile(config_file): + raise RuntimeError(f"{config_file} does not exist") + with open(config_file, "r") as f: + config = json.load(f) + config["_class_name"] = cls.__name__ + + config["down_block_types"] = [x.replace("2D", "3D") for x in config["down_block_types"]] + if "mid_block_type" in config.keys(): + config["mid_block_type"] = config["mid_block_type"].replace("2D", "3D") + config["up_block_types"] = [x.replace("2D", "3D") for x in config["up_block_types"]] + + from diffusers.utils import WEIGHTS_NAME + model = cls.from_config(config) + model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) + if not os.path.isfile(model_file): + raise RuntimeError(f"{model_file} does not exist") + state_dict = torch.load(model_file, map_location="cpu") + for k, v in model.state_dict().items(): + if k not in state_dict: + + state_dict.update({k: v}) + model.load_state_dict(state_dict) + + return model diff --git a/src/videogen_hub/pipelines/show_1/showone/pipelines/__init__.py b/src/videogen_hub/pipelines/show_1/showone/pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..293b79099c1c29138dba96164a6c45bc09b1087f --- /dev/null +++ b/src/videogen_hub/pipelines/show_1/showone/pipelines/__init__.py @@ -0,0 +1,37 @@ +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import torch + +from diffusers.utils import BaseOutput, OptionalDependencyNotAvailable, is_torch_available, is_transformers_available + + +@dataclass +class TextToVideoPipelineOutput(BaseOutput): + """ + Output class for text to video pipelines. + + Args: + frames (`List[np.ndarray]` or `torch.FloatTensor`) + List of denoised frames (essentially images) as NumPy arrays of shape `(height, width, num_channels)` or as + a `torch` tensor. NumPy array present the denoised images of the diffusion pipeline. The length of the list + denotes the video length i.e., the number of frames. + """ + + frames: Union[List[np.ndarray], torch.FloatTensor] + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from diffusers.utils.dummy_torch_and_transformers_objects import * # noqa F403 +else: + # from .pipeline_t2v_base_latent import TextToVideoSDPipeline # noqa: F401 + # from .pipeline_t2v_base_latent_sdxl import TextToVideoSDXLPipeline + from .pipeline_t2v_base_pixel import TextToVideoIFPipeline + from .pipeline_t2v_interp_pixel import TextToVideoIFInterpPipeline + # from .pipeline_t2v_sr_latent import TextToVideoSDSuperResolutionPipeline + from .pipeline_t2v_sr_pixel import TextToVideoIFSuperResolutionPipeline + # from .pipeline_t2v_base_latent_controlnet import TextToVideoSDControlNetPipeline diff --git a/src/videogen_hub/pipelines/show_1/showone/pipelines/pipeline_t2v_base_pixel.py b/src/videogen_hub/pipelines/show_1/showone/pipelines/pipeline_t2v_base_pixel.py new file mode 100644 index 0000000000000000000000000000000000000000..4c9c5637fc9aee871359920b10b6a157e37c22b4 --- /dev/null +++ b/src/videogen_hub/pipelines/show_1/showone/pipelines/pipeline_t2v_base_pixel.py @@ -0,0 +1,775 @@ +import html +import inspect +import re +import urllib.parse as ul +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer + +from diffusers.loaders import LoraLoaderMixin +from diffusers.schedulers import DDPMScheduler +from diffusers.utils import ( + BACKENDS_MAPPING, + is_accelerate_available, + is_accelerate_version, + is_bs4_available, + is_ftfy_available, + logging, +) +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline + +from ..models import UNet3DConditionModel +from . import TextToVideoPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]: + # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 + # reshape to ncfhw + mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1) + std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1) + # unnormalize back to [0,1] + video = video.mul_(std).add_(mean) + video.clamp_(0, 1) + # prepare the final outputs + i, c, f, h, w = video.shape + images = video.permute(2, 3, 0, 4, 1).reshape( + f, h, i * w, c + ) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c) + images = images.unbind(dim=0) # prepare a list of indvidual (consecutive frames) + images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c + return images + + +class TextToVideoIFPipeline(DiffusionPipeline, LoraLoaderMixin): + tokenizer: T5Tokenizer + text_encoder: T5EncoderModel + + unet: UNet3DConditionModel + scheduler: DDPMScheduler + + feature_extractor: Optional[CLIPImageProcessor] + # safety_checker: Optional[IFSafetyChecker] + + # watermarker: Optional[IFWatermarker] + + bad_punct_regex = re.compile( + r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + unet: UNet3DConditionModel, + scheduler: DDPMScheduler, + feature_extractor: Optional[CLIPImageProcessor], + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + feature_extractor=feature_extractor, + ) + self.safety_checker = None + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's + models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only + when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + models = [ + self.text_encoder, + self.unet, + ] + for cpu_offloaded_model in models: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + self.unet.train() + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + + if self.text_encoder is not None: + _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook) + + # Accelerate will move the next model to the device _before_ calling the offload hook of the + # previous model. This will cause both models to be present on the device at the same time. + # IF uses T5 for its text encoder which is really large. We can manually call the offload + # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to + # the GPU. + self.text_encoder_offload_hook = hook + + _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook) + + # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet + self.unet_offload_hook = hook + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + def remove_all_hooks(self): + if is_accelerate_available(): + from accelerate.hooks import remove_hook_from_module + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + for model in [self.text_encoder, self.unet, self.safety_checker]: + if model is not None: + remove_hook_from_module(model, recurse=True) + + self.unet_offload_hook = None + self.text_encoder_offload_hook = None + self.final_offload_hook = None + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + @torch.no_grad() + def encode_prompt( + self, + prompt, + do_classifier_free_guidance=True, + num_images_per_prompt=1, + device=None, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + clean_caption: bool = False, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF + max_length = 77 + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + attention_mask = text_inputs.attention_mask.to(device) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.unet is not None: + dtype = self.unet.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + else: + negative_prompt_embeds = None + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_intermediate_images(self, batch_size, num_channels, num_frames, height, width, dtype, device, generator): + shape = (batch_size, num_channels, num_frames, height, width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + intermediate_images = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + intermediate_images = intermediate_images * self.scheduler.init_noise_sigma + return intermediate_images + + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warn("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warn("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + num_inference_steps: int = 100, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + height: Optional[int] = None, + width: Optional[int] = None, + num_frames: int = 16, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + clean_caption: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Examples: + + Returns: + [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images, and the second element is a list + of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) + or watermarked content, according to the `safety_checker`. + """ + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + # 2. Define call parameters + height = height or self.unet.config.sample_size + width = width or self.unet.config.sample_size + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + device=device, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clean_caption=clean_caption, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare intermediate images + intermediate_images = self.prepare_intermediate_images( + batch_size * num_images_per_prompt, + self.unet.config.in_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # HACK: see comment in `enable_model_cpu_offload` + if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: + self.text_encoder_offload_hook.offload() + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + model_input = ( + torch.cat([intermediate_images] * 2) if do_classifier_free_guidance else intermediate_images + ) + model_input = self.scheduler.scale_model_input(model_input, t) + + # predict the noise residual + noise_pred = self.unet( + model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + if self.scheduler.config.variance_type not in ["learned", "learned_range"]: + noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1) + + # reshape latents + bsz, channel, frames, height, width = intermediate_images.shape + intermediate_images = intermediate_images.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, height, width) + noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, -1, height, width) + + # compute the previous noisy sample x_t -> x_t-1 + intermediate_images = self.scheduler.step( + noise_pred, t, intermediate_images, **extra_step_kwargs + ).prev_sample + + # reshape latents back + intermediate_images = intermediate_images[None, :].reshape(bsz, frames, channel, height, width).permute(0, 2, 1, 3, 4) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, intermediate_images) + + video_tensor = intermediate_images + + if output_type == "pt": + video = video_tensor + else: + video = tensor2vid(video_tensor) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (video,) + + return TextToVideoPipelineOutput(frames=video) diff --git a/src/videogen_hub/pipelines/show_1/showone/pipelines/pipeline_t2v_interp_pixel.py b/src/videogen_hub/pipelines/show_1/showone/pipelines/pipeline_t2v_interp_pixel.py new file mode 100644 index 0000000000000000000000000000000000000000..8186ba8892bf611ace2255576030e5bfda3e43dd --- /dev/null +++ b/src/videogen_hub/pipelines/show_1/showone/pipelines/pipeline_t2v_interp_pixel.py @@ -0,0 +1,798 @@ +import html +import inspect +import re +import urllib.parse as ul +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer + +from diffusers.schedulers import DDPMScheduler +from diffusers.utils import ( + BACKENDS_MAPPING, + is_accelerate_available, + is_accelerate_version, + is_bs4_available, + is_ftfy_available, + logging, + replace_example_docstring, +) +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline + +from ..models import UNet3DConditionModel +from . import TextToVideoPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]: + # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 + # reshape to ncfhw + mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1) + std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1) + # unnormalize back to [0,1] + video = video.mul_(std).add_(mean) + video.clamp_(0, 1) + # prepare the final outputs + i, c, f, h, w = video.shape + images = video.permute(2, 3, 0, 4, 1).reshape( + f, h, i * w, c + ) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c) + images = images.unbind(dim=0) # prepare a list of indvidual (consecutive frames) + images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c + return images + + +class TextToVideoIFInterpPipeline(DiffusionPipeline): + tokenizer: T5Tokenizer + text_encoder: T5EncoderModel + + unet: UNet3DConditionModel + scheduler: DDPMScheduler + + feature_extractor: Optional[CLIPImageProcessor] + # safety_checker: Optional[IFSafetyChecker] + + # watermarker: Optional[IFWatermarker] + + bad_punct_regex = re.compile( + r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + unet: UNet3DConditionModel, + scheduler: DDPMScheduler, + feature_extractor: Optional[CLIPImageProcessor], + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + feature_extractor=feature_extractor, + ) + self.safety_checker = None + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's + models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only + when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + models = [ + self.text_encoder, + self.unet, + ] + for cpu_offloaded_model in models: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + + if self.text_encoder is not None: + _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook) + + # Accelerate will move the next model to the device _before_ calling the offload hook of the + # previous model. This will cause both models to be present on the device at the same time. + # IF uses T5 for its text encoder which is really large. We can manually call the offload + # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to + # the GPU. + self.text_encoder_offload_hook = hook + + _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook) + + # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet + self.unet_offload_hook = hook + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + def remove_all_hooks(self): + if is_accelerate_available(): + from accelerate.hooks import remove_hook_from_module + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + for model in [self.text_encoder, self.unet, self.safety_checker]: + if model is not None: + remove_hook_from_module(model, recurse=True) + + self.unet_offload_hook = None + self.text_encoder_offload_hook = None + self.final_offload_hook = None + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + @torch.no_grad() + def encode_prompt( + self, + prompt, + do_classifier_free_guidance=True, + num_images_per_prompt=1, + device=None, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + clean_caption: bool = False, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF + max_length = 77 + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + attention_mask = text_inputs.attention_mask.to(device) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.unet is not None: + dtype = self.unet.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + else: + negative_prompt_embeds = None + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_intermediate_images(self, batch_size, num_channels, num_frames, height, width, dtype, device, generator): + shape = (batch_size, num_channels, num_frames, height, width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + intermediate_images = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + intermediate_images = intermediate_images * self.scheduler.init_noise_sigma + return intermediate_images + + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warn("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warn("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + @torch.no_grad() + def __call__( + self, + pixel_values, + prompt: Union[str, List[str]] = None, + num_inference_steps: int = 100, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + height: Optional[int] = None, + width: Optional[int] = None, + num_frames: int = 16, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + clean_caption: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + init_noise = None, + cond_interpolation = False, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Examples: + + Returns: + [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images, and the second element is a list + of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) + or watermarked content, according to the `safety_checker`. + """ + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + # 2. Define call parameters + height = height or self.unet.config.sample_size + width = width or self.unet.config.sample_size + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + device=device, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clean_caption=clean_caption, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare intermediate images + pixel_values = pixel_values.to(device) + if init_noise is not None: + intermediate_images = init_noise + else: + intermediate_images = self.prepare_intermediate_images( + batch_size * num_images_per_prompt, + # self.unet.config.in_channels, # mask not noise. + pixel_values.shape[1], + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + ) + + bsz = intermediate_images.shape[0] + interp_mask = torch.zeros(bsz, 1, *intermediate_images.shape[2:], device=device, dtype=intermediate_images.dtype) + interp_mask[:, :, 0, :, :] = 1 + interp_mask[:, :, -1, :, :] = 1 + + if cond_interpolation: + import torch.nn.functional as F + pixel_values = F.interpolate(pixel_values[:, :, [0, -1], ...], pixel_values.shape[2:], + mode="trilinear", align_corners=True) + else: + raise Exception("apply mask to pixel_values") + + # intermediate_images[:, :, 0, :, :] = pixel_values[:, :, 0, :, :] + # intermediate_images[:, :, -1, :, :] = pixel_values[:, :, -1, :, :] + pixel_values_condition = torch.cat((pixel_values, interp_mask), dim=1) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # HACK: see comment in `enable_model_cpu_offload` + if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: + self.text_encoder_offload_hook.offload() + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + intermediate_images_input = torch.cat((intermediate_images, pixel_values_condition), dim=1) + model_input = ( + torch.cat([intermediate_images_input] * 2) if do_classifier_free_guidance else intermediate_images_input + ) + model_input = self.scheduler.scale_model_input(model_input, t) + + # predict the noise residual + noise_pred = self.unet( + model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(intermediate_images.shape[1], dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(intermediate_images.shape[1], dim=1) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + if self.scheduler.config.variance_type not in ["learned", "learned_range"]: + noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1) + + # reshape latents + bsz, channel, frames, width, height = intermediate_images.shape + intermediate_images = intermediate_images.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height) + noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, -1, width, height) + + # compute the previous noisy sample x_t -> x_t-1 + intermediate_images = self.scheduler.step( + noise_pred, t, intermediate_images, **extra_step_kwargs + ).prev_sample + + # reshape latents back + intermediate_images = intermediate_images[None, :].reshape(bsz, frames, channel, width, height).permute(0, 2, 1, 3, 4) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, intermediate_images) + + video_tensor = intermediate_images + + if output_type == "pt": + video = video_tensor + else: + video = tensor2vid(video_tensor) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (video,) + + return TextToVideoPipelineOutput(frames=video) \ No newline at end of file diff --git a/src/videogen_hub/pipelines/show_1/showone/pipelines/pipeline_t2v_sr_pixel.py b/src/videogen_hub/pipelines/show_1/showone/pipelines/pipeline_t2v_sr_pixel.py new file mode 100644 index 0000000000000000000000000000000000000000..bdbf6291dd86c7f1b9e665bf707dc8bd0df5eb8e --- /dev/null +++ b/src/videogen_hub/pipelines/show_1/showone/pipelines/pipeline_t2v_sr_pixel.py @@ -0,0 +1,877 @@ +import html +import inspect +import re +import urllib.parse as ul +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +from einops import rearrange +import PIL +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer + +from diffusers.loaders import LoraLoaderMixin +from diffusers.schedulers import DDPMScheduler +from diffusers.utils import ( + BACKENDS_MAPPING, + is_accelerate_available, + is_accelerate_version, + is_bs4_available, + is_ftfy_available, + logging, +) +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline + +from ..models import UNet3DConditionModel +from . import TextToVideoPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]: + # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 + # reshape to ncfhw + mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1) + std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1) + # unnormalize back to [0,1] + video = video.mul_(std).add_(mean) + video.clamp_(0, 1) + # prepare the final outputs + i, c, f, h, w = video.shape + images = video.permute(2, 3, 0, 4, 1).reshape( + f, h, i * w, c + ) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c) + images = images.unbind(dim=0) # prepare a list of indvidual (consecutive frames) + images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c + return images + + +class TextToVideoIFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin): + tokenizer: T5Tokenizer + text_encoder: T5EncoderModel + + unet: UNet3DConditionModel + scheduler: DDPMScheduler + image_noising_scheduler: DDPMScheduler + + feature_extractor: Optional[CLIPImageProcessor] + # safety_checker: Optional[IFSafetyChecker] + + # watermarker: Optional[IFWatermarker] + + bad_punct_regex = re.compile( + r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + unet: UNet3DConditionModel, + scheduler: DDPMScheduler, + image_noising_scheduler: DDPMScheduler, + feature_extractor: Optional[CLIPImageProcessor], + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + image_noising_scheduler=image_noising_scheduler, + feature_extractor=feature_extractor, + ) + self.safety_checker = None + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's + models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only + when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + models = [ + self.text_encoder, + self.unet, + ] + for cpu_offloaded_model in models: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + + if self.text_encoder is not None: + _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook) + + # Accelerate will move the next model to the device _before_ calling the offload hook of the + # previous model. This will cause both models to be present on the device at the same time. + # IF uses T5 for its text encoder which is really large. We can manually call the offload + # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to + # the GPU. + self.text_encoder_offload_hook = hook + + _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook) + + # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet + self.unet_offload_hook = hook + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + def remove_all_hooks(self): + if is_accelerate_available(): + from accelerate.hooks import remove_hook_from_module + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + for model in [self.text_encoder, self.unet, self.safety_checker]: + if model is not None: + remove_hook_from_module(model, recurse=True) + + self.unet_offload_hook = None + self.text_encoder_offload_hook = None + self.final_offload_hook = None + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + @torch.no_grad() + def encode_prompt( + self, + prompt, + do_classifier_free_guidance=True, + num_images_per_prompt=1, + device=None, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + clean_caption: bool = False, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF + max_length = 77 + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + attention_mask = text_inputs.attention_mask.to(device) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.unet is not None: + dtype = self.unet.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + else: + negative_prompt_embeds = None + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + batch_size, + noise_level, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps: + raise ValueError( + f"`noise_level`: {noise_level} must be a valid timestep in `self.noising_scheduler`, [0, {self.image_noising_scheduler.config.num_train_timesteps})" + ) + + if isinstance(image, list): + check_image_type = image[0] + else: + check_image_type = image + + if ( + not isinstance(check_image_type, torch.Tensor) + and not isinstance(check_image_type, PIL.Image.Image) + and not isinstance(check_image_type, np.ndarray) + ): + raise ValueError( + "`image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is" + f" {type(check_image_type)}" + ) + + if isinstance(image, list): + image_batch_size = len(image) + elif isinstance(image, torch.Tensor): + image_batch_size = image.shape[0] + elif isinstance(image, PIL.Image.Image): + image_batch_size = 1 + elif isinstance(image, np.ndarray): + image_batch_size = image.shape[0] + else: + assert False + + if batch_size != image_batch_size: + raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}") + + def prepare_intermediate_images(self, batch_size, num_channels, num_frames, height, width, dtype, device, generator): + shape = (batch_size, num_channels, num_frames, height, width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + intermediate_images = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + intermediate_images = intermediate_images * self.scheduler.init_noise_sigma + return intermediate_images + + def preprocess_image(self, image, num_images_per_prompt, device): + if not isinstance(image, torch.Tensor) and not isinstance(image, list): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + image = [np.array(i).astype(np.float32) / 255.0 for i in image] + + image = np.stack(image, axis=0) # to np + torch.from_numpy(image.transpose(0, 3, 1, 2)) + elif isinstance(image[0], np.ndarray): + image = np.stack(image, axis=0) # to np + if image.ndim == 5: + image = image[0] + + image = torch.from_numpy(image.transpose(0, 3, 1, 2)) + elif isinstance(image, list) and isinstance(image[0], torch.Tensor): + dims = image[0].ndim + + if dims == 3: + image = torch.stack(image, dim=0) + elif dims == 4: + image = torch.concat(image, dim=0) + else: + raise ValueError(f"Image must have 3 or 4 dimensions, instead got {dims}") + + image = image.to(device=device, dtype=self.unet.dtype) + + image = image.repeat_interleave(num_images_per_prompt, dim=0) + + return image + + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warn("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warn("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + image: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + guidance_scale: float = 4.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + noise_level: int = 20, + clean_caption: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + image (`PIL.Image.Image`, `np.ndarray`, `torch.FloatTensor`): + The image to be upscaled. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + noise_level (`int`, *optional*, defaults to 250): + The amount of noise to add to the upscaled image. Must be in the range `[0, 1000)` + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images, and the second element is a list + of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) + or watermarked content, according to the `safety_checker`. + """ + # 1. Check inputs. Raise error if not correct + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + self.check_inputs( + prompt, + image, + batch_size, + noise_level, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + + height = height or self.unet.config.sample_size + width = width or self.unet.config.sample_size + assert isinstance(image, torch.Tensor), f"{type(image)} is not supported." + num_frames = image.shape[2] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + device=device, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clean_caption=clean_caption, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare intermediate images + num_channels = self.unet.config.in_channels // 2 + intermediate_images = self.prepare_intermediate_images( + batch_size * num_images_per_prompt, + num_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare upscaled image and noise level + image = self.preprocess_image(image, num_images_per_prompt, device) + upscaled = rearrange(image, "b c f h w -> (b f) c h w") + upscaled = F.interpolate(upscaled, (height, width), mode="bilinear", align_corners=True) + upscaled = rearrange(upscaled, "(b f) c h w -> b c f h w", f=image.shape[2]) + + noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device) + noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype) + upscaled = self.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level) + + if do_classifier_free_guidance: + noise_level = torch.cat([noise_level] * 2) + + # HACK: see comment in `enable_model_cpu_offload` + if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: + self.text_encoder_offload_hook.offload() + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + model_input = torch.cat([intermediate_images, upscaled], dim=1) + + model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input + model_input = self.scheduler.scale_model_input(model_input, t) + + # predict the noise residual + noise_pred = self.unet( + model_input, + t, + encoder_hidden_states=prompt_embeds, + class_labels=noise_level, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + # reshape latents + bsz, channel, frames, height, width = intermediate_images.shape + intermediate_images = intermediate_images.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, height, width) + noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, -1, height, width) + + # compute the previous noisy sample x_t -> x_t-1 + intermediate_images = self.scheduler.step( + noise_pred, t, intermediate_images, **extra_step_kwargs + ).prev_sample + + # reshape latents back + intermediate_images = intermediate_images[None, :].reshape(bsz, frames, channel, height, width).permute(0, 2, 1, 3, 4) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, intermediate_images) + + video_tensor = intermediate_images + + if output_type == "pt": + video = video_tensor + else: + video = tensor2vid(video_tensor) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (video,) + + return TextToVideoPipelineOutput(frames=video) diff --git a/src/videogen_hub/pipelines/show_1/showone/pipelines/pipeline_t2v_sr_pixel_cond.py b/src/videogen_hub/pipelines/show_1/showone/pipelines/pipeline_t2v_sr_pixel_cond.py new file mode 100644 index 0000000000000000000000000000000000000000..3122ff6bc1d4796689167759021903b6f4eec7ef --- /dev/null +++ b/src/videogen_hub/pipelines/show_1/showone/pipelines/pipeline_t2v_sr_pixel_cond.py @@ -0,0 +1,890 @@ +import html +import inspect +import re +import urllib.parse as ul +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer +from einops import rearrange + +from diffusers.loaders import LoraLoaderMixin +from diffusers.schedulers import DDPMScheduler +from diffusers.utils import ( + BACKENDS_MAPPING, + is_accelerate_available, + is_accelerate_version, + is_bs4_available, + is_ftfy_available, + logging, + replace_example_docstring, +) +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline + +from ..models import UNet3DConditionModel +from . import TextToVideoPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]: + # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 + # reshape to ncfhw + mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1) + std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1) + # unnormalize back to [0,1] + video = video.mul_(std).add_(mean) + video.clamp_(0, 1) + # prepare the final outputs + i, c, f, h, w = video.shape + images = video.permute(2, 3, 0, 4, 1).reshape( + f, h, i * w, c + ) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c) + images = images.unbind(dim=0) # prepare a list of indvidual (consecutive frames) + images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c + return images + + +class TextToVideoIFSuperResolutionPipeline_Cond(DiffusionPipeline, LoraLoaderMixin): + tokenizer: T5Tokenizer + text_encoder: T5EncoderModel + + unet: UNet3DConditionModel + scheduler: DDPMScheduler + image_noising_scheduler: DDPMScheduler + + feature_extractor: Optional[CLIPImageProcessor] + # safety_checker: Optional[IFSafetyChecker] + + # watermarker: Optional[IFWatermarker] + + bad_punct_regex = re.compile( + r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + unet: UNet3DConditionModel, + scheduler: DDPMScheduler, + image_noising_scheduler: DDPMScheduler, + feature_extractor: Optional[CLIPImageProcessor], + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + image_noising_scheduler=image_noising_scheduler, + feature_extractor=feature_extractor, + ) + self.safety_checker = None + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's + models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only + when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + models = [ + self.text_encoder, + self.unet, + ] + for cpu_offloaded_model in models: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + + if self.text_encoder is not None: + _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook) + + # Accelerate will move the next model to the device _before_ calling the offload hook of the + # previous model. This will cause both models to be present on the device at the same time. + # IF uses T5 for its text encoder which is really large. We can manually call the offload + # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to + # the GPU. + self.text_encoder_offload_hook = hook + + _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook) + + # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet + self.unet_offload_hook = hook + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + def remove_all_hooks(self): + if is_accelerate_available(): + from accelerate.hooks import remove_hook_from_module + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + for model in [self.text_encoder, self.unet, self.safety_checker]: + if model is not None: + remove_hook_from_module(model, recurse=True) + + self.unet_offload_hook = None + self.text_encoder_offload_hook = None + self.final_offload_hook = None + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + @torch.no_grad() + def encode_prompt( + self, + prompt, + do_classifier_free_guidance=True, + num_images_per_prompt=1, + device=None, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + clean_caption: bool = False, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF + max_length = 77 + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + attention_mask = text_inputs.attention_mask.to(device) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.unet is not None: + dtype = self.unet.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + else: + negative_prompt_embeds = None + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + batch_size, + noise_level, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps: + raise ValueError( + f"`noise_level`: {noise_level} must be a valid timestep in `self.noising_scheduler`, [0, {self.image_noising_scheduler.config.num_train_timesteps})" + ) + + if isinstance(image, list): + check_image_type = image[0] + else: + check_image_type = image + + if ( + not isinstance(check_image_type, torch.Tensor) + and not isinstance(check_image_type, PIL.Image.Image) + and not isinstance(check_image_type, np.ndarray) + ): + raise ValueError( + "`image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is" + f" {type(check_image_type)}" + ) + + if isinstance(image, list): + image_batch_size = len(image) + elif isinstance(image, torch.Tensor): + image_batch_size = image.shape[0] + elif isinstance(image, PIL.Image.Image): + image_batch_size = 1 + elif isinstance(image, np.ndarray): + image_batch_size = image.shape[0] + else: + assert False + + if batch_size != image_batch_size: + raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}") + + def prepare_intermediate_images(self, batch_size, num_channels, num_frames, height, width, dtype, device, generator): + shape = (batch_size, num_channels, num_frames, height, width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + intermediate_images = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + intermediate_images = intermediate_images * self.scheduler.init_noise_sigma + return intermediate_images + + def preprocess_image(self, image, num_images_per_prompt, device): + if not isinstance(image, torch.Tensor) and not isinstance(image, list): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + image = [np.array(i).astype(np.float32) / 255.0 for i in image] + + image = np.stack(image, axis=0) # to np + torch.from_numpy(image.transpose(0, 3, 1, 2)) + elif isinstance(image[0], np.ndarray): + image = np.stack(image, axis=0) # to np + if image.ndim == 5: + image = image[0] + + image = torch.from_numpy(image.transpose(0, 3, 1, 2)) + elif isinstance(image, list) and isinstance(image[0], torch.Tensor): + dims = image[0].ndim + + if dims == 3: + image = torch.stack(image, dim=0) + elif dims == 4: + image = torch.concat(image, dim=0) + else: + raise ValueError(f"Image must have 3 or 4 dimensions, instead got {dims}") + + image = image.to(device=device, dtype=self.unet.dtype) + + image = image.repeat_interleave(num_images_per_prompt, dim=0) + + return image + + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warn("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warn("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + image: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor] = None, + first_frame_cond: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor] = None, + all_frame_cond: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + guidance_scale: float = 4.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + noise_level: int = 250, + clean_caption: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + image (`PIL.Image.Image`, `np.ndarray`, `torch.FloatTensor`): + The image to be upscaled. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + noise_level (`int`, *optional*, defaults to 250): + The amount of noise to add to the upscaled image. Must be in the range `[0, 1000)` + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images, and the second element is a list + of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) + or watermarked content, according to the `safety_checker`. + """ + # 1. Check inputs. Raise error if not correct + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + self.check_inputs( + prompt, + image, + batch_size, + noise_level, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + + height = height or self.unet.config.sample_size + width = width or self.unet.config.sample_size + assert isinstance(image, torch.Tensor), f"{type(image)} is not supported." + num_frames = image.shape[2] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + device=device, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clean_caption=clean_caption, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare intermediate images + num_channels = self.unet.config.in_channels // 2 + intermediate_images = self.prepare_intermediate_images( + batch_size * num_images_per_prompt, + num_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare upscaled image and noise level + image = self.preprocess_image(image, num_images_per_prompt, device) + # upscaled = F.interpolate(image, (num_frames, height, width), mode="trilinear", align_corners=True) + if all_frame_cond is not None: + upscaled = all_frame_cond + else: + upscaled = rearrange(image, "b c f h w -> (b f) c h w") + upscaled = F.interpolate(upscaled, (height, width), mode="bilinear", align_corners=True) + upscaled = rearrange(upscaled, "(b f) c h w -> b c f h w", f=image.shape[2]) + + noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device) + noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype) + upscaled = self.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level) + if first_frame_cond is not None: + first_frame_cond = first_frame_cond.to(device=device, dtype=self.unet.dtype) + upscaled[:,:,:1,:,:] = first_frame_cond + + if do_classifier_free_guidance: + noise_level = torch.cat([noise_level] * 2) + + # HACK: see comment in `enable_model_cpu_offload` + if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: + self.text_encoder_offload_hook.offload() + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + model_input = torch.cat([intermediate_images, upscaled], dim=1) + + model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input + model_input = self.scheduler.scale_model_input(model_input, t) + + # predict the noise residual + noise_pred = self.unet( + model_input, + t, + encoder_hidden_states=prompt_embeds, + class_labels=noise_level, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + if self.scheduler.config.variance_type not in ["learned", "learned_range"]: + noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1) + + # reshape latents + bsz, channel, frames, height, width = intermediate_images.shape + intermediate_images = intermediate_images.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, height, width) + noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, -1, height, width) + + # compute the previous noisy sample x_t -> x_t-1 + intermediate_images = self.scheduler.step( + noise_pred, t, intermediate_images, **extra_step_kwargs + ).prev_sample + + # reshape latents back + intermediate_images = intermediate_images[None, :].reshape(bsz, frames, channel, height, width).permute(0, 2, 1, 3, 4) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, intermediate_images) + + video_tensor = intermediate_images + + if output_type == "pt": + video = video_tensor + else: + video = tensor2vid(video_tensor) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (video,) + + return TextToVideoPipelineOutput(frames=video) diff --git a/src/videogen_hub/pipelines/streamingt2v/__init__.py b/src/videogen_hub/pipelines/streamingt2v/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0db7212eb5715c46830e735abc29da17bcdcf8bd --- /dev/null +++ b/src/videogen_hub/pipelines/streamingt2v/__init__.py @@ -0,0 +1,8 @@ +from pathlib import Path + + +WORK_DIR = Path(__file__).resolve().parent + +import sys + +sys.path.insert(0, "./src/videogen_hub/pipelines/streamingt2v") diff --git a/src/videogen_hub/pipelines/streamingt2v/configs/__init__.py b/src/videogen_hub/pipelines/streamingt2v/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/streamingt2v/configs/inference/__init__.py b/src/videogen_hub/pipelines/streamingt2v/configs/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/streamingt2v/configs/inference/inference_long_video.yaml b/src/videogen_hub/pipelines/streamingt2v/configs/inference/inference_long_video.yaml new file mode 100644 index 0000000000000000000000000000000000000000..618c5d6f2225e82a00f7cf5699c5515c4af08cf5 --- /dev/null +++ b/src/videogen_hub/pipelines/streamingt2v/configs/inference/inference_long_video.yaml @@ -0,0 +1,37 @@ +trainer: + devices: '1' + num_nodes: 1 +model: + inference_params: + class_path: t2v_enhanced.model.pl_module_params_controlnet.InferenceParams + init_args: + num_inference_steps: 50 # number of inference steps + frame_rate: 8 + eta: 1.0 # eta used for DDIM sampler + guidance_scale: 7.5 # classifier free guidance scale + conditioning_type: fixed + start_from_real_input: false + n_autoregressive_generations: 6 # how many autoregressive generations + scheduler_cls: '' # we can load other models + unet_params: + class_path: t2v_enhanced.model.pl_module_params_controlnet.UNetParams + init_args: + use_standard_attention_processor: False + opt_params: + class_path: t2v_enhanced.model.pl_module_params_controlnet.OptimizerParams + init_args: + noise_generator: + class_path: t2v_enhanced.model.video_noise_generator.NoiseGenerator + init_args: + mode: vanilla # can be 'vanilla','mixed_noise', 'consistI2V' or 'mixed_noise_consistI2V' + alpha: 1.0 + shared_noise_across_chunks: True # if true, shared noise between all chunks of a video + forward_steps: 850 # number of DDPM forward steps + radius: [2,2,2] # radius for time, width and height +n_predictions: 300 +data: + class_path: t2v_enhanced.model.datasets.prompt_reader.PromptReader + init_args: + prompt_cfg: + type: file + content: /home/roberto.henschel/T2V-Enhanced/repo/training_code/t2v_enhanced/evaluation_prompts/prompts_long_eval.txt diff --git a/src/videogen_hub/pipelines/streamingt2v/configs/text_to_video/__init__.py b/src/videogen_hub/pipelines/streamingt2v/configs/text_to_video/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/streamingt2v/configs/text_to_video/config.yaml b/src/videogen_hub/pipelines/streamingt2v/configs/text_to_video/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..61902c194cced8f537d46193a153a91ba5b07a04 --- /dev/null +++ b/src/videogen_hub/pipelines/streamingt2v/configs/text_to_video/config.yaml @@ -0,0 +1,227 @@ +# pytorch_lightning==2.0.9 +seed_everything: 33 +trainer: + accelerator: auto + strategy: auto + devices: '8' + num_nodes: 1 + # precision: 16-mixed + logger: null + callbacks: + - class_path: pytorch_lightning.callbacks.RichModelSummary + init_args: + max_depth: 1 + - class_path: pytorch_lightning.callbacks.RichProgressBar + init_args: + refresh_rate: 1 + leave: false + theme: + description: white + progress_bar: '#6206E0' + progress_bar_finished: '#6206E0' + progress_bar_pulse: '#6206E0' + batch_progress: white + time: grey54 + processing_speed: grey70 + metrics: white + console_kwargs: null + fast_dev_run: false + max_epochs: 5000 + min_epochs: null + max_steps: 2020000 + min_steps: null + max_time: null + limit_train_batches: null + limit_val_batches: 512 + limit_test_batches: null + limit_predict_batches: null + overfit_batches: 0.0 + val_check_interval: 8000 + check_val_every_n_epoch: 1 + num_sanity_val_steps: null + log_every_n_steps: 10 + enable_checkpointing: null + enable_progress_bar: null + enable_model_summary: null + accumulate_grad_batches: 8 + gradient_clip_val: 1 + gradient_clip_algorithm: norm + deterministic: null + benchmark: null + inference_mode: true + use_distributed_sampler: true + profiler: null + detect_anomaly: false + barebones: false + plugins: null + sync_batchnorm: false + reload_dataloaders_every_n_epochs: 0 + default_root_dir: null +model: + inference_params: + class_path: t2v_enhanced.model.pl_module_params_controlnet.InferenceParams + init_args: + width: 256 + height: 256 + video_length: 16 + guidance_scale: 7.5 + use_dec_scaling: true + frame_rate: 8 + num_inference_steps: 50 + eta: 1.0 + n_autoregressive_generations: 1 + mode: long_video + start_from_real_input: true + eval_loss_metrics: false + scheduler_cls: '' + negative_prompt: '' + conditioning_from_all_past: false + validation_samples: 80 + conditioning_type: last_chunk + result_formats: + - eval_gif + - gif + - mp4 + concat_video: true + opt_params: + class_path: t2v_enhanced.model.pl_module_params_controlnet.OptimizerParams + init_args: + learning_rate: 5.0e-05 + layers_config: + class_path: t2v_enhanced.model.requires_grad_setter.LayerConfig + init_args: + gradient_setup: + - - false + - - vae + - - false + - - text_encoder + - - false + - - image_encoder + - - true + - - resampler + - - true + - - unet + - - true + - - base_model + - - false + - - base_model + - transformer_in + - - false + - - base_model + - temp_attentions + - - false + - - base_model + - temp_convs + layers_config_base: null + use_warmup: false + warmup_steps: 10000 + warmup_start_factor: 1.0e-05 + learning_rate_spatial: 0.0 + use_8_bit_adam: false + noise_generator: null + noise_decomposition: null + perceptual_loss: false + noise_offset: 0.0 + split_opt_by_node: false + reset_prediction_type_to_eps: false + train_val_sampler_may_differ: true + measure_similarity: false + similarity_loss: false + similarity_loss_weight: 1.0 + loss_conditional_weight: 0.0 + loss_conditional_weight_convex: false + loss_conditional_change_after_step: 0 + mask_conditional_frames: false + sample_from_noise: true + mask_alternating: false + uncondition_freq: -1 + no_text_condition_control: false + inject_image_into_input: false + inject_at_T: false + resampling_steps: 1 + control_freq_in_resample: 1 + resample_to_T: false + adaptive_loss_reweight: false + load_resampler_from_ckpt: '' + skip_controlnet_branch: false + use_fps_conditioning: false + num_frame_embeddings_range: 16 + start_frame_training: 16 + start_frame_ctrl: 16 + load_trained_base_model_and_resampler_from_ckpt: '' + load_trained_controlnet_from_ckpt: '' + unet_params: + class_path: t2v_enhanced.model.pl_module_params_controlnet.UNetParams + init_args: + conditioning_embedding_out_channels: + - 32 + - 96 + - 256 + - 512 + ckpt_spatial_layers: '' + pipeline_repo: damo-vilab/text-to-video-ms-1.7b + unet_from_diffusers: true + spatial_latent_input: false + num_frame_conditioning: 1 + pipeline_class: t2v_enhanced.model.model.controlnet.pipeline_text_to_video_w_controlnet_synth.TextToVideoSDPipeline + frame_expansion: none + downsample_controlnet_cond: true + num_frames: 16 + pre_transformer_in_cond: false + num_tranformers: 1 + zero_conv_3d: false + merging_mode: addition + compute_only_conditioned_frames: false + condition_encoder: '' + zero_conv_mode: Identity + clean_model: true + merging_mode_base: attention_cross_attention + attention_mask_params: null + attention_mask_params_base: null + modelscope_input_format: true + temporal_self_attention_only_on_conditioning: false + temporal_self_attention_mask_included_itself: false + use_post_merger_zero_conv: false + weight_control_sample: 1.0 + use_controlnet_mask: false + random_mask_shift: false + random_mask: false + use_resampler: true + unet_from_pipe: false + unet_operates_on_2d: false + image_encoder: CLIP + use_standard_attention_processor: false + num_frames_before_chunk: 0 + resampler_type: single_frame + resampler_cls: t2v_enhanced.model.diffusers_conditional.models.controlnet.image_embedder.ImgEmbContextResampler + resampler_merging_layers: 4 + image_encoder_obj: + class_path: t2v_enhanced.model.diffusers_conditional.models.controlnet.image_embedder.FrozenOpenCLIPImageEmbedder + init_args: + arch: ViT-H-14 + version: laion2b_s32b_b79k + device: cuda + max_length: 77 + freeze: true + antialias: true + ucg_rate: 0.0 + unsqueeze_dim: false + repeat_to_max_len: false + num_image_crops: 0 + output_tokens: false + cfg_text_image: false + aggregation: last_out + resampler_random_shift: true + img_cond_alpha_per_frame: false + num_control_input_frames: 8 + use_image_encoder_normalization: false + use_of: false + ema_param: -1.0 + concat: false + use_image_tokens_main: true + use_image_tokens_ctrl: false +result_fol: results +exp_name: my_exp_name +run_name: my_run_name +scale_lr: false +matmul_precision: high diff --git a/src/videogen_hub/pipelines/streamingt2v/inference.py b/src/videogen_hub/pipelines/streamingt2v/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..69cdbc60933befb2da043f5db86f68feda73cc2d --- /dev/null +++ b/src/videogen_hub/pipelines/streamingt2v/inference.py @@ -0,0 +1,214 @@ +# General +import os +from os.path import join as opj +import argparse +import datetime +from pathlib import Path +import torch +import gradio as gr +import tempfile +import yaml +from model.video_ldm import VideoLDM +from typing import List, Optional +from model.callbacks import SaveConfigCallback +from PIL.Image import Image, fromarray + +from einops import rearrange, repeat + +import sys + +from videogen_hub import MODEL_PATH + +sys.path.append("thirdparty") +from modelscope.pipelines import pipeline +from modelscope.outputs import OutputKeys +import imageio +import pathlib +import numpy as np + +# Utilities +from inference_utils import * +from model_init import * +from model_func import * + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--prompt", + type=str, + default="A cat running on the street", + help="The prompt to guide video generation.", + ) + parser.add_argument( + "--image", type=str, default="", help="Path to image conditioning." + ) + # parser.add_argument('--video', type=str, default="", help="Path to video conditioning.") + parser.add_argument( + "--base_model", + type=str, + default="ModelscopeT2V", + help="Base model to generate first chunk from", + choices=["ModelscopeT2V", "AnimateDiff", "SVD"], + ) + parser.add_argument( + "--num_frames", + type=int, + default=24, + help="The number of video frames to generate.", + ) + parser.add_argument( + "--negative_prompt", + type=str, + default="", + help="The prompt to guide what to not include in video generation.", + ) + parser.add_argument( + "--negative_prompt_enhancer", + type=str, + default=None, + help="The prompt to guide what to not include in video enhancement. " + "By default is the same as --negative_prompt", + ) + parser.add_argument( + "--num_steps", type=int, default=50, help="The number of denoising steps." + ) + parser.add_argument( + "--image_guidance", type=float, default=9.0, help="The guidance scale." + ) + + parser.add_argument( + "--output_dir", + type=str, + default="results", + help="Path where to save the generated videos.", + ) + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--seed", type=int, default=33, help="Random seed") + + parser.add_argument( + "--chunk", type=int, default=24, help="chunk_size for randomized blending" + ) + parser.add_argument( + "--overlap", type=int, default=8, help="overlap_size for randomized blending" + ) + + parser.add_argument( + "--offload_models", + action="store_true", + help="Load/Offload models to gpu/cpu before and after inference", + ) + args = parser.parse_args() + + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + result_fol = Path(args.output_dir).absolute() + device = args.device + + # -------------------------- + # ----- Configurations ----- + # -------------------------- + ckpt_file_streaming_t2v = os.path.join(MODEL_PATH, "streaming_t2v.ckpt") + cfg_v2v = { + "downscale": 1, + "upscale_size": (1280, 720), + "model_id": "damo/Video-to-Video", + "pad": True, + } + + # -------------------------- + # ----- Initialization ----- + # -------------------------- + if args.base_model == "ModelscopeT2V": + if args.offload_models: + model = init_modelscope("cpu") + else: + model = init_modelscope(device) + elif args.base_model == "AnimateDiff": + if args.offload_models: + model = init_animatediff("cpu") + else: + model = init_animatediff(device) + elif args.base_model == "SVD": + if args.offload_models: + model = init_svd("cpu") + sdxl_model = init_sdxl("cpu") + else: + model = init_svd(device) + sdxl_model = init_sdxl(device) + + if args.offload_models: + msxl_model = init_v2v_model(cfg_v2v, "cpu") + else: + msxl_model = init_v2v_model(cfg_v2v, device) + + stream_cli, stream_model = init_streamingt2v_model( + ckpt_file_streaming_t2v, result_fol, "cuda" + ) + if args.offload_models: + stream_model = st2v_to_device(stream_model, "cpu") + inference_generator = torch.Generator(device="cuda") + + # ------------------ + # ----- Inputs ----- + # ------------------ + now = datetime.datetime.now() + name = ( + args.prompt[:100].replace(" ", "_") + + "_" + + str(now.time()).replace(":", "_").replace(".", "_") + ) + + inference_generator = torch.Generator(device="cuda") + inference_generator.manual_seed(args.seed) + + if args.offload_models: + model = model.to(device) + if args.base_model == "ModelscopeT2V": + short_video = ms_short_gen(args.prompt, model, inference_generator) + elif args.base_model == "AnimateDiff": + short_video = ad_short_gen(args.prompt, model, inference_generator) + elif args.base_model == "SVD": + if args.offload_models: + sdxl_model = sdxl_model.to(device) + short_video = svd_short_gen( + args.image, args.prompt, model, sdxl_model, inference_generator + ) + if args.offload_models: + sdxl_model = sdxl_model.to("cpu") + if args.offload_models: + model = model.to("cpu") + + n_autoreg_gen = (args.num_frames - 8) // 8 + stream_long_gen( + args.prompt, + short_video, + n_autoreg_gen, + args.negative_prompt, + args.seed, + args.num_steps, + args.image_guidance, + name, + stream_cli, + stream_model, + ) + if args.offload_models: + stream_model = st2v_to_device(stream_model, "cpu") + + args.negative_prompt_enhancer = ( + args.negative_prompt_enhancer + if args.negative_prompt_enhancer is not None + else args.negative_prompt + ) + if args.offload_models: + msxl_model = v2v_to_device(msxl_model, device) + video2video_randomized( + args.prompt, + opj(result_fol, name + ".mp4"), + result_fol, + cfg_v2v, + msxl_model, + chunk_size=args.chunk, + overlap_size=args.overlap, + negative_prompt=args.negative_prompt_enhancer, + ) + if args.offload_models: + msxl_model = v2v_to_device(msxl_model, "cpu") diff --git a/src/videogen_hub/pipelines/streamingt2v/inference_utils.py b/src/videogen_hub/pipelines/streamingt2v/inference_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..55eeff679a928f73721bd3ba1c7814330d7cefc8 --- /dev/null +++ b/src/videogen_hub/pipelines/streamingt2v/inference_utils.py @@ -0,0 +1,125 @@ +# import argparse +import sys +from pathlib import Path +from pytorch_lightning.cli import LightningCLI +from PIL import Image + +# For streaming +import yaml +from copy import deepcopy +from typing import List, Optional +from jsonargparse.typing import restricted_string_type + + +# -------------------------------------- +# ----------- For Streaming ------------ +# -------------------------------------- +class CustomCLI(LightningCLI): + def add_arguments_to_parser(self, parser): + parser.add_argument("--result_fol", type=Path, + help="Set the path to the result folder", default="results") + parser.add_argument("--exp_name", type=str, help="Experiment name") + parser.add_argument("--run_name", type=str, + help="Current run name") + parser.add_argument("--prompts", type=Optional[List[str]]) + parser.add_argument("--scale_lr", type=bool, + help="Scale lr", default=False) + CodeType = restricted_string_type( + 'CodeType', '(medium)|(high)|(highest)') + parser.add_argument("--matmul_precision", type=CodeType) + parser.add_argument("--ckpt", type=Path,) + parser.add_argument("--n_predictions", type=int) + return parser + +def remove_value(dictionary, x): + for key, value in list(dictionary.items()): + if key == x: + del dictionary[key] + elif isinstance(value, dict): + remove_value(value, x) + return dictionary + +def legacy_transformation(cfg: yaml): + cfg = deepcopy(cfg) + cfg["trainer"]["devices"] = "1" + cfg["trainer"]['num_nodes'] = 1 + + if not "class_path" in cfg["model"]["inference_params"]: + cfg["model"]["inference_params"] = { + "class_path": "t2v_enhanced.model.pl_module_params.InferenceParams", "init_args": cfg["model"]["inference_params"]} + return cfg + + +# --------------------------------------------- +# ----------- For enhancement ----------- +# --------------------------------------------- +def add_margin(pil_img, top, right, bottom, left, color): + width, height = pil_img.size + new_width = width + right + left + new_height = height + top + bottom + result = Image.new(pil_img.mode, (new_width, new_height), color) + result.paste(pil_img, (left, top)) + return result + +def resize_to_fit(image, size): + W, H = size + w, h = image.size + if H / h > W / w: + H_ = int(h * W / w) + W_ = W + else: + W_ = int(w * H / h) + H_ = H + return image.resize((W_, H_)) + +def pad_to_fit(image, size): + W, H = size + w, h = image.size + pad_h = (H - h) // 2 + pad_w = (W - w) // 2 + return add_margin(image, pad_h, pad_w, pad_h, pad_w, (0, 0, 0)) + +def resize_and_keep(pil_img): + myheight = 576 + hpercent = (myheight/float(pil_img.size[1])) + wsize = int((float(pil_img.size[0])*float(hpercent))) + pil_img = pil_img.resize((wsize, myheight)) + return pil_img + +def center_crop(pil_img): + width, height = pil_img.size + new_width = 576 + new_height = 576 + + left = (width - new_width)/2 + top = (height - new_height)/2 + right = (width + new_width)/2 + bottom = (height + new_height)/2 + + # Crop the center of the image + pil_img = pil_img.crop((left, top, right, bottom)) + return pil_img + + +def v2v_to_device(pipe_enhance, device): + pipe_enhance.device = device + + pipe_enhance.model = pipe_enhance.model.to(device) + pipe_enhance.model.device = device + + pipe_enhance.model.clip_encoder.model = pipe_enhance.model.clip_encoder.model.to(device) + pipe_enhance.model.clip_encoder.device = device + + pipe_enhance.model.autoencoder = pipe_enhance.model.autoencoder.to(device) + pipe_enhance.model.generator = pipe_enhance.model.generator.to(device) + if device.startswith("cuda"): + pipe_enhance.model.generator = pipe_enhance.model.generator.half() + pipe_enhance.model.negative_y = pipe_enhance.model.negative_y.to(device) + return pipe_enhance + +def st2v_to_device(stream_model, device): + stream_model = stream_model.to(device) + stream_model.inference_pipeline.unet = stream_model.inference_pipeline.unet.to(device) + stream_model.inference_pipeline.vae = stream_model.inference_pipeline.vae.to(device) + stream_model.inference_pipeline = stream_model.inference_pipeline.to(device) + return stream_model \ No newline at end of file diff --git a/src/videogen_hub/pipelines/streamingt2v/model/__init__.py b/src/videogen_hub/pipelines/streamingt2v/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/streamingt2v/model/callbacks.py b/src/videogen_hub/pipelines/streamingt2v/model/callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..85f4814114dab8bcfde8afefa2b127cd570dd080 --- /dev/null +++ b/src/videogen_hub/pipelines/streamingt2v/model/callbacks.py @@ -0,0 +1,102 @@ + +from pathlib import Path +from pytorch_lightning import Callback +import os +import torch +from lightning_fabric.utilities.cloud_io import get_filesystem +from pytorch_lightning.cli import LightningArgumentParser +from pytorch_lightning import LightningModule, Trainer +from lightning_utilities.core.imports import RequirementCache +from omegaconf import OmegaConf + +_JSONARGPARSE_SIGNATURES_AVAILABLE = RequirementCache( + "jsonargparse[signatures]>=4.17.0") + +if _JSONARGPARSE_SIGNATURES_AVAILABLE: + import docstring_parser + from jsonargparse import ( + ActionConfigFile, + ArgumentParser, + class_from_function, + Namespace, + register_unresolvable_import_paths, + set_config_read_mode, + ) + + # Required until fix https://github.com/pytorch/pytorch/issues/74483 + register_unresolvable_import_paths(torch) + set_config_read_mode(fsspec_enabled=True) +else: + locals()["ArgumentParser"] = object + locals()["Namespace"] = object + + +class SaveConfigCallback(Callback): + """Saves a LightningCLI config to the log_dir when training starts. + + Args: + parser: The parser object used to parse the configuration. + config: The parsed configuration that will be saved. + config_filename: Filename for the config file. + overwrite: Whether to overwrite an existing config file. + multifile: When input is multiple config files, saved config preserves this structure. + + Raises: + RuntimeError: If the config file already exists in the directory to avoid overwriting a previous run + """ + + def __init__( + self, + parser: LightningArgumentParser, + config: Namespace, + log_dir: str, + config_filename: str = "config.yaml", + overwrite: bool = False, + multifile: bool = False, + + ) -> None: + self.parser = parser + self.config = config + self.config_filename = config_filename + self.overwrite = overwrite + self.multifile = multifile + self.already_saved = False + self.log_dir = log_dir + + def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None: + if self.already_saved: + return + + log_dir = self.log_dir + assert log_dir is not None + config_path = os.path.join(log_dir, self.config_filename) + fs = get_filesystem(log_dir) + + if not self.overwrite: + # check if the file exists on rank 0 + file_exists = fs.isfile( + config_path) if trainer.is_global_zero else False + # broadcast whether to fail to all ranks + file_exists = trainer.strategy.broadcast(file_exists) + if file_exists: + raise RuntimeError( + f"{self.__class__.__name__} expected {config_path} to NOT exist. Aborting to avoid overwriting" + " results of a previous run. You can delete the previous config file," + " set `LightningCLI(save_config_callback=None)` to disable config saving," + ' or set `LightningCLI(save_config_kwargs={"overwrite": True})` to overwrite the config file.' + ) + + # save the file on rank 0 + if trainer.is_global_zero: + # save only on rank zero to avoid race conditions. + # the `log_dir` needs to be created as we rely on the logger to do it usually + # but it hasn't logged anything at this point + fs.makedirs(log_dir, exist_ok=True) + self.parser.save( + self.config, config_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile + ) + self.already_saved = True + trainer.logger.log_hyperparams(OmegaConf.load(config_path)) + + # broadcast so that all ranks are in sync on future calls to .setup() + self.already_saved = trainer.strategy.broadcast(self.already_saved) diff --git a/src/videogen_hub/pipelines/streamingt2v/model/datasets/__init__.py b/src/videogen_hub/pipelines/streamingt2v/model/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/streamingt2v/model/datasets/prompt_reader.py b/src/videogen_hub/pipelines/streamingt2v/model/datasets/prompt_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..4891617bbe121b9da52528c8f5c56224e1ca4cbf --- /dev/null +++ b/src/videogen_hub/pipelines/streamingt2v/model/datasets/prompt_reader.py @@ -0,0 +1,80 @@ +from pathlib import Path +from typing import Dict, List, Optional + +import numpy as np +import pytorch_lightning as pl +import torch +from pytorch_lightning.utilities.types import EVAL_DATALOADERS + +from t2v_enhanced.model.datasets.video_dataset import Annotations +import json + + +class ConcatDataset(torch.utils.data.Dataset): + def __init__(self, datasets): + self.datasets = datasets + self.model_id = datasets["reconstruction_dataset"].model_id + + def __getitem__(self, idx): + sample = {ds: self.datasets[ds].__getitem__( + idx) for ds in self.datasets} + return sample + + def __len__(self): + return min(len(self.datasets[d]) for d in self.datasets) + + +class CustomPromptsDataset(torch.utils.data.Dataset): + + def __init__(self, prompt_cfg: Dict[str, str]): + super().__init__() + + if prompt_cfg["type"] == "prompt": + self.prompts = [prompt_cfg["content"]] + elif prompt_cfg["type"] == "file": + file = Path(prompt_cfg["content"]) + if file.suffix == ".npy": + self.prompts = np.load(file.as_posix()) + elif file.suffix == ".txt": + with open(prompt_cfg["content"]) as f: + lines = [line.rstrip() for line in f] + self.prompts = lines + elif file.suffix == ".json": + with open(prompt_cfg["content"],"r") as file: + metadata = json.load(file) + if "videos_root" in prompt_cfg: + videos_root = Path(prompt_cfg["videos_root"]) + video_path = [str(videos_root / sample["page_dir"] / + f"{sample['videoid']}.mp4") for sample in metadata] + else: + video_path = [str(sample["page_dir"] / + f"{sample['videoid']}.mp4") for sample in metadata] + self.prompts = [sample["prompt"] for sample in metadata] + self.video_path = video_path + + + + + transformed_prompts = [] + for prompt in self.prompts: + transformed_prompts.append( + Annotations.clean_prompt(prompt)) + self.prompts = transformed_prompts + + def __len__(self): + return len(self.prompts) + + def __getitem__(self, index): + output = {"prompt": self.prompts[index]} + if hasattr(self,"video_path"): + output["video"] = self.video_path[index] + return output + + +class PromptReader(pl.LightningDataModule): + def __init__(self, prompt_cfg: Dict[str, str]): + super().__init__() + self.predict_dataset = CustomPromptsDataset(prompt_cfg) + + def predict_dataloader(self) -> EVAL_DATALOADERS: + return torch.utils.data.DataLoader(self.predict_dataset, batch_size=1, pin_memory=False, shuffle=False, drop_last=False) diff --git a/src/videogen_hub/pipelines/streamingt2v/model/datasets/video_dataset.py b/src/videogen_hub/pipelines/streamingt2v/model/datasets/video_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..21c752460db8c8eebba2d4d89cf5f6fd8228b456 --- /dev/null +++ b/src/videogen_hub/pipelines/streamingt2v/model/datasets/video_dataset.py @@ -0,0 +1,57 @@ +from tqdm import tqdm +from einops import repeat +from diffusers import DiffusionPipeline +from decord import VideoReader, cpu +import torchvision +import torch +import numpy as np +import decord +import albumentations as album +import math +import random +from abc import abstractmethod +from copy import deepcopy +from pathlib import Path +from typing import Any, Dict, List, Union +from PIL import Image +import json +Image.MAX_IMAGE_PIXELS = None + +decord.bridge.set_bridge("torch") + +class Annotations(): + + def __init__(self, + annotation_cfg: Dict) -> None: + self.annotation_cfg = annotation_cfg + + # TODO find all special characters + + @staticmethod + def process_string(string): + for special_char in [".", ",", ":"]: + result = "" + i = 0 + while i < len(string): + if string[i] == special_char: + if i > 0 and i < len(string) - 1 and string[i-1].isalpha() and string[i+1].isalpha(): + result += special_char+" " + else: + result += special_char + else: + result += string[i] + i += 1 + string = result + string = result + return result + + @staticmethod + def clean_prompt(prompt): + prompt = " ".join(prompt.split()) + prompt = prompt.replace(" , ", ", ") + prompt = prompt.replace(" . ", ". ") + prompt = prompt.replace(" : ", ": ") + prompt = Annotations.process_string(prompt) + return prompt + # return " ".join(prompt.split()) + diff --git a/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/__init__.py b/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/__init__.py b/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/__init__.py b/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/attention.py b/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..6b7d9686b97b35b35348b13094a11a08cf52c485 --- /dev/null +++ b/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/attention.py @@ -0,0 +1,300 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import Callable, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.utils.import_utils import is_xformers_available + +# from diffusers.models.attention_processor import Attention +# from t2v_enhanced.model.diffusers_conditional.models.controlnet.attention import Attention +from .attention_processor import Attention +from diffusers.models.embeddings import CombinedTimestepLabelEmbeddings + +# from t2v_enhanced.model.diffusers_conditional.models.controlnet.attention_processor import Attention + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + is_spatial_attention: bool = False, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + final_dropout: bool = False, + use_image_embedding: bool = False, + unet_params=None, + ): + super().__init__() + + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm_zero = ( + num_embeds_ada_norm is not None + ) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = ( + num_embeds_ada_norm is not None + ) and norm_type == "ada_norm" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + is_spatial_attention=is_spatial_attention, + use_image_embedding=use_image_embedding, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=( + cross_attention_dim if not double_self_attention else None + ), + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + is_spatial_attention=is_spatial_attention, + use_image_embedding=use_image_embedding, + unet_params=unet_params, + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + ) + + def forward( + self, + hidden_states, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + timestep=None, + cross_attention_kwargs=None, + class_labels=None, + ): + # Notice that normalization is always applied before the real computation in the following blocks. + # 1. Self-Attention + + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + else: + norm_hidden_states = self.norm1(hidden_states) + + cross_attention_kwargs = ( + cross_attention_kwargs if cross_attention_kwargs is not None else {} + ) + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=( + encoder_hidden_states if self.only_cross_attention else None + ), + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + + # 2. Cross-Attention + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) + if self.use_ada_layer_norm + else self.norm2(hidden_states) + ) + # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly + # prepare attention mask here + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = ( + norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ) + + ff_output = self.ff(norm_hidden_states) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = ff_output + hidden_states + + return hidden_states + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim) + if activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh") + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim) + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(nn.Linear(inner_dim, dim_out)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states): + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +class GEGLU(nn.Module): + r""" + A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def gelu(self, gate): + if gate.device.type != "mps": + return F.gelu(gate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) + return hidden_states * self.gelu(gate) diff --git a/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/attention_processor.py b/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/attention_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..758e3025d773b952b0089bfefcbe2a7b23250851 --- /dev/null +++ b/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/attention_processor.py @@ -0,0 +1,444 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from einops import repeat +from typing import Callable, Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.utils import deprecate, logging +from diffusers.utils.import_utils import is_xformers_available + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +class Attention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + """ + + def __init__( + self, + query_dim: int, + is_spatial_attention: bool, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + processor: Optional["AttnProcessor"] = None, + use_image_embedding: bool = False, + unet_params=None, + ): + super().__init__() + inner_dim = dim_head * heads + self.cross_attention_mode = cross_attention_dim is not None + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.is_spatial_attention = is_spatial_attention + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.train_image_cond_weight = use_image_embedding + self.use_image_embedding = use_image_embedding + + self.scale = dim_head**-0.5 if scale_qk else 1.0 + + self.heads = heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + + self.added_kv_proj_dim = added_kv_proj_dim + self.only_cross_attention = only_cross_attention + + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." + ) + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm( + num_channels=query_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) + else: + self.group_norm = None + + if cross_attention_norm is None: + self.norm_cross = None + elif cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(cross_attention_dim) + elif cross_attention_norm == "group_norm": + if self.added_kv_proj_dim is not None: + # The given `encoder_hidden_states` are initially of shape + # (batch_size, seq_len, added_kv_proj_dim) before being projected + # to (batch_size, seq_len, cross_attention_dim). The norm is applied + # before the projection, so we need to use `added_kv_proj_dim` as + # the number of channels for the group norm. + norm_cross_num_channels = added_kv_proj_dim + else: + norm_cross_num_channels = cross_attention_dim + + self.norm_cross = nn.GroupNorm( + num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True + ) + else: + raise ValueError( + f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" + ) + + self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) + + if not self.only_cross_attention: + # only relevant for the `AddedKVProcessor` classes + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + else: + self.to_k = None + self.to_v = None + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim) + self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + + embed_dim = 93 + if self.cross_attention_mode and self.is_spatial_attention and self.use_image_embedding: + self.conv = torch.nn.Conv1d(embed_dim, 77, kernel_size=3, padding="same") + self.conv_ln = nn.LayerNorm(1024) + self.register_parameter("alpha", nn.Parameter(torch.tensor(0.))) + + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + if processor is None: + processor = ( + AttnProcessor2_0() if hasattr( + F, "scaled_dot_product_attention") and scale_qk else AttnProcessor() + ) + self.set_processor(processor) + + def set_use_memory_efficient_attention_xformers( + self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None + ): + is_lora = hasattr(self, "processor") and isinstance( + self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor) + ) + + if use_memory_efficient_attention_xformers: + if self.added_kv_proj_dim is not None: + # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP + # which uses this type of cross attention ONLY because the attention mask of format + # [0, ..., -10.000, ..., 0, ...,] is not supported + raise NotImplementedError( + "Memory efficient attention with `xformers` is currently not supported when" + " `self.added_kv_proj_dim` is defined." + ) + elif not is_xformers_available(): + raise ModuleNotFoundError( + ( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers" + ), + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" + " only available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + + if is_lora: + processor = LoRAXFormersAttnProcessor( + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + rank=self.processor.rank, + attention_op=attention_op, + ) + processor.load_state_dict(self.processor.state_dict()) + processor.to(self.processor.to_q_lora.up.weight.device) + else: + processor = XFormersAttnProcessor(attention_op=attention_op) + else: + if is_lora: + processor = LoRAAttnProcessor( + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + rank=self.processor.rank, + ) + processor.load_state_dict(self.processor.state_dict()) + processor.to(self.processor.to_q_lora.up.weight.device) + else: + processor = AttnProcessor() + + self.set_processor(processor) + + def set_attention_slice(self, slice_size): + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError( + f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + if slice_size is not None and self.added_kv_proj_dim is not None: + processor = SlicedAttnAddedKVProcessor(slice_size) + elif slice_size is not None: + processor = SlicedAttnProcessor(slice_size) + elif self.added_kv_proj_dim is not None: + processor = AttnAddedKVProcessor() + else: + processor = AttnProcessor() + + self.set_processor(processor) + + def set_processor(self, processor: "AttnProcessor"): + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info( + f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs): + # The `Attention` class can call different attention processors / attention functions + # here we simply pass along all tensors to the selected processor class + # For standard processors that are defined here, `**cross_attention_kwargs` is empty + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + def batch_to_head_dim(self, tensor): + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size // head_size, + head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape( + batch_size // head_size, seq_len, dim * head_size) + return tensor + + def head_to_batch_dim(self, tensor, out_dim=3): + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size, seq_len, + head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3) + + if out_dim == 3: + tensor = tensor.reshape( + batch_size * head_size, seq_len, dim // head_size) + + return tensor + + def get_attention_scores(self, query, key, attention_mask=None): + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + if attention_mask is None: + baddbmm_input = torch.empty( + query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device + ) + beta = 0 + else: + baddbmm_input = attention_mask + beta = 1 + + attention_scores = torch.baddbmm( + baddbmm_input, + query, + key.transpose(-1, -2), + beta=beta, + alpha=self.scale, + ) + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + attention_probs = attention_probs.to(dtype) + + return attention_probs + + def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3): + if batch_size is None: + deprecate( + "batch_size=None", + "0.0.15", + ( + "Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect" + " attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to" + " `prepare_attention_mask` when preparing the attention_mask." + ), + ) + batch_size = 1 + + head_size = self.heads + if attention_mask is None: + return attention_mask + + if attention_mask.shape[-1] != target_length: + if attention_mask.device.type == "mps": + # HACK: MPS: Does not support padding by greater than dimension of input tensor. + # Instead, we can manually construct the padding tensor. + padding_shape = ( + attention_mask.shape[0], attention_mask.shape[1], target_length) + padding = torch.zeros( + padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) + attention_mask = torch.cat([attention_mask, padding], dim=2) + else: + attention_mask = F.pad( + attention_mask, (0, target_length), value=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave( + head_size, dim=0) + elif out_dim == 4: + attention_mask = attention_mask.unsqueeze(1) + attention_mask = attention_mask.repeat_interleave(head_size, dim=1) + + return attention_mask + + def norm_encoder_hidden_states(self, encoder_hidden_states): + assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states" + + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + # Group norm norms along the channels dimension and expects + # input to be in the shape of (N, C, *). In this case, we want + # to norm along the hidden dimension, so we need to move + # (batch_size, sequence_length, hidden_size) -> + # (batch_size, hidden_size, sequence_length) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + else: + assert False + + return encoder_hidden_states + + + + +class AttnProcessor2_0: + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + inner_dim = hidden_states.shape[-1] + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view( + batch_size, attn.heads, -1, attention_mask.shape[-1]) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, + head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, + head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + + +AttentionProcessor = Union[ + AttnProcessor2_0, +] diff --git a/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/conditioning.py b/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/conditioning.py new file mode 100644 index 0000000000000000000000000000000000000000..3ceda8475fb60dc71e03e33bfae7d7d4ceba334a --- /dev/null +++ b/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/conditioning.py @@ -0,0 +1,133 @@ +import torch.nn as nn + +try: + from diffusers.models.transformer_temporal import ( + TransformerTemporalModel, + TransformerTemporalModelOutput, + ) +except: + from diffusers.models.transformers.transformer_temporal import TransformerTemporalModelOutput + from diffusers.models import TransformerTemporalModel + +from einops import rearrange +from diffusers.models.attention_processor import Attention + +# from t2v_enhanced.model.diffusers_conditional.models.controlnet.attention_processor import Attention +from .transformer_temporal_crossattention import ( + TransformerTemporalModel as TransformerTemporalModelCrossAttn, +) +import torch + + +class CrossAttention(nn.Module): + + def __init__(self, input_channels, attention_head_dim, norm_num_groups=32): + super().__init__() + self.attention = Attention( + query_dim=input_channels, + cross_attention_dim=input_channels, + heads=input_channels // attention_head_dim, + dim_head=attention_head_dim, + bias=False, + upcast_attention=False, + ) + self.norm = torch.nn.GroupNorm( + num_groups=norm_num_groups, + num_channels=input_channels, + eps=1e-6, + affine=True, + ) + self.proj_in = nn.Linear(input_channels, input_channels) + self.proj_out = nn.Linear(input_channels, input_channels) + + def forward(self, hidden_state, encoder_hidden_states, num_frames): + h, w = hidden_state.shape[2], hidden_state.shape[3] + hidden_state_norm = rearrange( + hidden_state, "(B F) C H W -> B C F H W", F=num_frames + ) + hidden_state_norm = self.norm(hidden_state_norm) + hidden_state_norm = rearrange(hidden_state_norm, "B C F H W -> (B H W) F C") + hidden_state_norm = self.proj_in(hidden_state_norm) + attn = self.attention( + hidden_state_norm, + encoder_hidden_states=encoder_hidden_states, + attention_mask=None, + ) + # proj_out + + residual = self.proj_out(attn) + + residual = rearrange(residual, "(B H W) F C -> (B F) C H W", H=h, W=w) + output = hidden_state + residual + return TransformerTemporalModelOutput(sample=output) + + +class ConditionalModel(nn.Module): + + def __init__(self, input_channels, conditional_model: str, attention_head_dim=64): + super().__init__() + num_layers = 1 + if "_layers_" in conditional_model: + config = conditional_model.split("_layers_") + conditional_model = config[0] + num_layers = int(config[1]) + + if conditional_model == "self_cross_transformer": + self.temporal_transformer = TransformerTemporalModel( + num_attention_heads=input_channels // attention_head_dim, + attention_head_dim=attention_head_dim, + in_channels=input_channels, + double_self_attention=False, + cross_attention_dim=input_channels, + ) + elif conditional_model == "cross_transformer": + self.temporal_transformer = TransformerTemporalModelCrossAttn( + num_attention_heads=input_channels // attention_head_dim, + attention_head_dim=attention_head_dim, + in_channels=input_channels, + double_self_attention=False, + cross_attention_dim=input_channels, + num_layers=num_layers, + ) + elif conditional_model == "cross_attention": + self.temporal_transformer = CrossAttention( + input_channels=input_channels, attention_head_dim=attention_head_dim + ) + elif conditional_model == "test_conv": + self.temporal_transformer = nn.Conv2d( + input_channels, input_channels, kernel_size=1 + ) + else: + raise NotImplementedError(f"mode {conditional_model} not implemented") + if conditional_model != "test_conv": + nn.init.zeros_(self.temporal_transformer.proj_out.weight) + nn.init.zeros_(self.temporal_transformer.proj_out.bias) + else: + nn.init.zeros_(self.temporal_transformer.weight) + nn.init.zeros_(self.temporal_transformer.bias) + self.conditional_model = conditional_model + + def forward(self, sample, conditioning, num_frames=None): + + assert conditioning.ndim == 5 + assert sample.ndim == 5 + if self.conditional_model != "test_conv": + conditioning = rearrange(conditioning, "B F C H W -> (B H W) F C") + + num_frames = sample.shape[1] + + sample = rearrange(sample, "B F C H W -> (B F) C H W") + + sample = self.temporal_transformer( + sample, encoder_hidden_states=conditioning, num_frames=num_frames + ).sample + + sample = rearrange(sample, "(B F) C H W -> B F C H W", F=num_frames) + else: + + conditioning = rearrange(conditioning, "B F C H W -> (B F) C H W") + f = sample.shape[1] + sample = rearrange(sample, "B F C H W -> (B F) C H W") + sample = sample + self.temporal_transformer(conditioning) + sample = rearrange(sample, "(B F) C H W -> B F C H W", F=f) + return sample diff --git a/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/controlnet.py b/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..c49bd9494e47728875f3ffafbb0d99efb86b49d7 --- /dev/null +++ b/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/controlnet.py @@ -0,0 +1,960 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import functional as F +from einops import rearrange, repeat + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput, logging +from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor + +# from diffusers.models.transformer_temporal import TransformerTemporalModel +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.models.modeling_utils import ModelMixin +from .unet_3d_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, + transformer_g_c, +) + +# from diffusers.models.unet_3d_condition import UNet3DConditionModel +from .unet_3d_condition import ( + UNet3DConditionModel, +) +from .transformer_temporal import ( + TransformerTemporalModel, +) +from videogen_hub.pipelines.streamingt2v.model.layers.conv_channel_extension import ( + Conv2D_SubChannels, +) + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class ControlNetOutput(BaseOutput): + down_block_res_samples: Tuple[torch.Tensor] + mid_block_res_sample: torch.Tensor + + +class Merger(nn.Module): + def __init__( + self, + n_frames_condition: int = 8, + n_frames_sample: int = 16, + merge_mode: str = "addition", + input_channels=0, + frame_expansion="last_frame", + ) -> None: + super().__init__() + self.merge_mode = merge_mode + self.n_frames_condition = n_frames_condition + self.n_frames_sample = n_frames_sample + self.frame_expansion = frame_expansion + + if merge_mode.startswith("attention"): + self.attention = ConditionalModel( + input_channels=input_channels, + conditional_model=merge_mode.split("attention_")[1], + ) + + def forward(self, x, condition_signal): + x = rearrange(x, "(B F) C H W -> B F C H W", F=self.n_frames_sample) + + condition_signal = rearrange( + condition_signal, "(B F) C H W -> B F C H W", B=x.shape[0] + ) + + if x.shape[1] - condition_signal.shape[1] > 0: + if self.frame_expansion == "last_frame": + fillup_latent = repeat( + condition_signal[:, -1], + "B C H W -> B F C H W", + F=x.shape[1] - condition_signal.shape[1], + ) + elif self.frame_expansion == "zero": + fillup_latent = torch.zeros( + ( + x.shape[0], + self.n_frames_sample - self.n_frames_condition, + *x.shape[2:], + ), + device=x.device, + dtype=x.dtype, + ) + + if self.frame_expansion != "none": + condition_signal = torch.cat([condition_signal, fillup_latent], dim=1) + + if self.merge_mode == "addition": + out = x + condition_signal + elif self.merge_mode.startswith("attention"): + out = self.attention(x, condition_signal) + out = rearrange(out, "B F C H W -> (B F) C H W") + return out + + +class ZeroConv(nn.Module): + def __init__( + self, channels: int, mode: str = "2d", num_frames: int = 8, zero_init=True + ): + super().__init__() + mode_parts = mode.split("_") + if len(mode_parts) > 1 and mode_parts[1] == "noinit": + zero_init = False + + if mode.startswith("2d"): + model = nn.Conv2d(channels, channels, kernel_size=1) + model = zero_module(model, reset=zero_init) + elif mode.startswith("3d"): + model = ZeroConv3D( + num_frames=num_frames, channels=channels, zero_init=zero_init + ) + elif mode == "Identity": + model = nn.Identity() + self.model = model + + def forward(self, x): + return self.model(x) + + +class ControlNetConditioningEmbedding(nn.Module): + """ + Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN + [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized + training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the + convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides + (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full + model) to encode image-space conditions ... into feature maps ..." + """ + + # TODO why not GAUSSIAN used? + # TODO why not 4x4 kernel? + # TODO why not 2 x2 stride? + + def __init__( + self, + conditioning_embedding_channels: int, + conditioning_channels: int = 3, + block_out_channels: Tuple[int] = (16, 32, 96, 256), + downsample: bool = True, + final_3d_conv: bool = False, + num_frame_conditioning: int = 8, + num_frames: int = 16, + zero_init: bool = True, + use_controlnet_mask: bool = False, + use_normalization: bool = False, + ): + super().__init__() + self.num_frame_conditioning = num_frame_conditioning + self.num_frames = num_frames + self.final_3d_conv = final_3d_conv + self.conv_in = nn.Conv2d( + conditioning_channels, block_out_channels[0], kernel_size=3, padding=1 + ) + if final_3d_conv: + print("USING 3D CONV in ControlNET") + + self.blocks = nn.ModuleList([]) + if use_normalization: + self.norms = nn.ModuleList([]) + self.use_normalization = use_normalization + + stride = 2 if downsample else 1 + if use_normalization: + res = 256 # HARD-CODED Resolution! + + for i in range(len(block_out_channels) - 1): + channel_in = block_out_channels[i] + channel_out = block_out_channels[i + 1] + self.blocks.append( + nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1) + ) + if use_normalization: + self.norms.append(nn.LayerNorm((channel_in, res, res))) + self.blocks.append( + nn.Conv2d( + channel_in, channel_out, kernel_size=3, padding=1, stride=stride + ) + ) + if use_normalization: + res = res // 2 + self.norms.append(nn.LayerNorm((channel_out, res, res))) + + if not final_3d_conv: + self.conv_out = zero_module( + nn.Conv2d( + block_out_channels[-1] + int(use_controlnet_mask), + conditioning_embedding_channels, + kernel_size=3, + padding=1, + ), + reset=zero_init, + ) + else: + self.conv_temp = zero_module( + TemporalConvLayer_Custom( + num_frame_conditioning, num_frames, dropout=0.0 + ), + reset=zero_init, + ) + self.conv_out = nn.Conv2d( + block_out_channels[-1] + int(use_controlnet_mask), + conditioning_embedding_channels, + kernel_size=3, + padding=1, + ) + # self.conv_temp = zero_module(nn.Conv3d( + # num_frame_conditioning, num_frames, kernel_size=3, padding=1) + # ) + + def forward(self, conditioning, vq_gan=None, controlnet_mask=None): + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + + if self.use_normalization: + for block, norm in zip(self.blocks, self.norms): + embedding = block(embedding) + embedding = norm(embedding) + embedding = F.silu(embedding) + else: + for block in self.blocks: + embedding = block(embedding) + embedding = F.silu(embedding) + + if controlnet_mask is not None: + embedding = rearrange( + embedding, "(B F) C H W -> F B C H W", F=self.num_frames + ) + controlnet_mask_expanded = controlnet_mask[:, :, None, None, None] + controlnet_mask_expanded = rearrange( + controlnet_mask_expanded, "B F C W H -> F B C W H" + ) + masked_embedding = controlnet_mask_expanded * embedding + embedding = rearrange(masked_embedding, "F B C H W -> (B F) C H W") + controlnet_mask_expanded = rearrange( + controlnet_mask_expanded, "F B C H W -> (B F) C H W" + ) + # controlnet_mask_expanded = repeat(controlnet_mask_expanded,"B C W H -> B (C x) W H",x=embedding.shape[1]) + controlnet_mask_expanded = repeat( + controlnet_mask_expanded, "B C W H -> B C (W y) H", y=embedding.shape[2] + ) + controlnet_mask_expanded = repeat( + controlnet_mask_expanded, "B C W H -> B C W (H z)", z=embedding.shape[3] + ) + + embedding = torch.cat([embedding, controlnet_mask_expanded], dim=1) + + embedding = self.conv_out(embedding) + if self.final_3d_conv: + # embedding = F.silu(embedding) + embedding = rearrange( + embedding, "(b f) c h w -> b f c h w", f=self.num_frame_conditioning + ) + embedding = self.conv_temp(embedding) + embedding = rearrange(embedding, "b f c h w -> (b f) c h w") + + return embedding + + +class ControlNetModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = False + + @register_to_config + def __init__( + self, + in_channels: int = 4, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: Union[int, Tuple[int]] = 8, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + projection_class_embeddings_input_dim: Optional[int] = None, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), + global_pool_conditions: bool = False, + downsample_controlnet_cond: bool = True, + frame_expansion: str = "zero", + condition_encoder: str = "", + num_frames: int = 16, + num_frame_conditioning: int = 8, + num_tranformers: int = 1, + vae=None, + merging_mode: str = "addition", + zero_conv_mode: str = "2d", + use_controlnet_mask: bool = False, + use_image_embedding: bool = False, + use_image_encoder_normalization: bool = False, + unet_params=None, + ): + super().__init__() + self.gradient_checkpointing = False + # Check inputs + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len( + only_cross_attention + ) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len( + down_block_types + ): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + self.use_image_tokens = unet_params.use_image_tokens_ctrl + self.image_encoder_name = type(unet_params.image_encoder).__name__ + + # input + conv_in_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + """Conv2D_SubChannels + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + """ + self.conv_in = Conv2D_SubChannels( + in_channels, + block_out_channels[0], + kernel_size=conv_in_kernel, + padding=conv_in_padding, + ) + # time + time_embed_dim = block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + ) + + self.transformer_in = TransformerTemporalModel( + num_attention_heads=8, + attention_head_dim=attention_head_dim, + in_channels=block_out_channels[0], + num_layers=1, + ) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding( + projection_class_embeddings_input_dim, time_embed_dim + ) + else: + self.class_embedding = None + conditioning_channels = 3 if downsample_controlnet_cond else 4 + # control net conditioning embedding + + if condition_encoder == "temp_conv_vq": + controlnet_cond_embedding = ControlNetConditioningEmbeddingVQ( + conditioning_embedding_channels=block_out_channels[0], + conditioning_channels=4, + block_out_channels=conditioning_embedding_out_channels, + downsample=False, + num_frame_conditioning=num_frame_conditioning, + num_frames=num_frames, + num_tranformers=num_tranformers, + # zero_init=not merging_mode.startswith("attention"), + ) + elif condition_encoder == "vq": + controlnet_cond_embedding = ControlNetConditioningOptVQ( + vq=vae, + conditioning_embedding_channels=block_out_channels[0], + conditioning_channels=4, + block_out_channels=conditioning_embedding_out_channels, + num_frame_conditioning=num_frame_conditioning, + num_frames=num_frames, + ) + + else: + controlnet_cond_embedding = ControlNetConditioningEmbedding( + conditioning_embedding_channels=block_out_channels[0], + conditioning_channels=conditioning_channels, + block_out_channels=conditioning_embedding_out_channels, + downsample=downsample_controlnet_cond, + final_3d_conv=condition_encoder.endswith("3DConv"), + num_frame_conditioning=num_frame_conditioning, + num_frames=num_frames, + # zero_init=not merging_mode.startswith("attention") + use_controlnet_mask=use_controlnet_mask, + use_normalization=use_image_encoder_normalization, + ) + self.use_controlnet_mask = use_controlnet_mask + self.down_blocks = nn.ModuleList([]) + self.controlnet_down_blocks = nn.ModuleList([]) + + # conv_in + self.merger = Merger( + n_frames_sample=num_frames, + n_frames_condition=num_frame_conditioning, + merge_mode=merging_mode, + input_channels=block_out_channels[0], + frame_expansion=frame_expansion, + ) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + self.controlnet_down_blocks.append( + ZeroConv( + channels=output_channel, mode=zero_conv_mode, num_frames=num_frames + ) + ) + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=False, + use_image_embedding=use_image_embedding, + unet_params=unet_params, + ) + self.down_blocks.append(down_block) + + for _ in range(layers_per_block): + self.controlnet_down_blocks.append( + ZeroConv( + channels=output_channel, + mode=zero_conv_mode, + num_frames=num_frames, + ) + ) + + if not is_final_block: + self.controlnet_down_blocks.append( + ZeroConv( + channels=output_channel, + mode=zero_conv_mode, + num_frames=num_frames, + ) + ) + + # mid + mid_block_channel = block_out_channels[-1] + + self.controlnet_mid_block = ZeroConv( + channels=mid_block_channel, mode=zero_conv_mode, num_frames=num_frames + ) + + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=False, + use_image_embedding=use_image_embedding, + unet_params=unet_params, + ) + self.controlnet_cond_embedding = controlnet_cond_embedding + self.num_frames = num_frames + self.num_frame_conditioning = num_frame_conditioning + + @classmethod + def from_unet( + cls, + unet: UNet3DConditionModel, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), + load_weights_from_unet: bool = True, + downsample_controlnet_cond: bool = True, + num_frames: int = 16, + num_frame_conditioning: int = 8, + frame_expansion: str = "zero", + num_tranformers: int = 1, + vae=None, + zero_conv_mode: str = "2d", + merging_mode: str = "addition", + # [spatial,spatial_3DConv,temp_conv_vq] + condition_encoder: str = "spatial_3DConv", + use_controlnet_mask: bool = False, + use_image_embedding: bool = False, + use_image_encoder_normalization: bool = False, + unet_params=None, + **kwargs, + ): + r""" + Instantiate Controlnet class from UNet3DConditionModel. + + Parameters: + unet (`UNet3DConditionModel`): + UNet model which weights are copied to the ControlNet. Note that all configuration options are also + copied where applicable. + """ + controlnet = cls( + in_channels=unet.config.in_channels, + down_block_types=unet.config.down_block_types, + block_out_channels=unet.config.block_out_channels, + layers_per_block=unet.config.layers_per_block, + act_fn=unet.config.act_fn, + norm_num_groups=unet.config.norm_num_groups, + norm_eps=unet.config.norm_eps, + cross_attention_dim=unet.config.cross_attention_dim, + attention_head_dim=unet.config.attention_head_dim, + conditioning_embedding_out_channels=conditioning_embedding_out_channels, + downsample_controlnet_cond=downsample_controlnet_cond, + num_frame_conditioning=num_frame_conditioning, + num_frames=num_frames, + frame_expansion=frame_expansion, + num_tranformers=num_tranformers, + vae=vae, + zero_conv_mode=zero_conv_mode, + merging_mode=merging_mode, + condition_encoder=condition_encoder, + use_controlnet_mask=use_controlnet_mask, + use_image_embedding=use_image_embedding, + use_image_encoder_normalization=use_image_encoder_normalization, + unet_params=unet_params, + ) + + if load_weights_from_unet: + controlnet.conv_in.load_state_dict(unet.conv_in.state_dict()) + controlnet.time_proj.load_state_dict(unet.time_proj.state_dict()) + controlnet.transformer_in.load_state_dict(unet.transformer_in.state_dict()) + controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict()) + + if controlnet.class_embedding: + controlnet.class_embedding.load_state_dict( + unet.class_embedding.state_dict() + ) + + controlnet.down_blocks.load_state_dict( + unet.down_blocks.state_dict(), strict=False + ) # can be that the controlnet model does not use image clip encoding + controlnet.mid_block.load_state_dict( + unet.mid_block.state_dict(), strict=False + ) + + return controlnet + + @property + # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors( + name: str, + module: torch.nn.Module, + processors: Dict[str, AttentionProcessor], + ): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.set_attn_processor + def set_attn_processor( + self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]] + ): + r""" + Parameters: + `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + of **all** `Attention` layers. + In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.: + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(AttnProcessor()) + + # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.set_attention_slice + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = ( + num_sliceable_layers * [slice_size] + if not isinstance(slice_size, list) + else slice_size + ) + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice( + module: torch.nn.Module, slice_size: List[int] + ): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D)): + module.gradient_checkpointing = value + + # TODO ADD WEIGHT CONTROL + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + controlnet_cond: torch.FloatTensor, + conditioning_scale: float = 1.0, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guess_mode: bool = False, + return_dict: bool = True, + weight_control: float = 1.0, + weight_control_sample: float = 1.0, + controlnet_mask: Optional[torch.Tensor] = None, + vq_gan=None, + ) -> Union[ControlNetOutput, Tuple]: + # check channel order + # TODO SET ATTENTION MASK And WEIGHT CONTROL as in CONTROLNET.PY + """ + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + """ + # assert controlnet_mask is None, "Controlnet Mask not implemented yet for clean model" + # 1. time + + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + sample = sample[:, :, : self.num_frames] + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + num_frames = sample.shape[2] + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + emb = emb.repeat_interleave(repeats=num_frames, dim=0) + + if not self.use_image_tokens and encoder_hidden_states.shape[1] > 77: + encoder_hidden_states = encoder_hidden_states[:, :77] + + if encoder_hidden_states.shape[1] > 77: + # assert ( + # encoder_hidden_states.shape[1]-77) % num_frames == 0, f"Encoder shape {encoder_hidden_states.shape}. Num frames = {num_frames}" + context_text, context_img = ( + encoder_hidden_states[:, :77, :], + encoder_hidden_states[:, 77:, :], + ) + context_text = context_text.repeat_interleave(repeats=num_frames, dim=0) + + if self.image_encoder_name == "FrozenOpenCLIPImageEmbedder": + context_img = context_img.repeat_interleave(repeats=num_frames, dim=0) + else: + context_img = rearrange( + context_img, "b (t l) c -> (b t) l c", t=num_frames + ) + + encoder_hidden_states = torch.cat([context_text, context_img], dim=1) + else: + encoder_hidden_states = encoder_hidden_states.repeat_interleave( + repeats=num_frames, dim=0 + ) + + # print(f"ctrl with tokens = {encoder_hidden_states.shape[1]}") + """ + encoder_hidden_states = encoder_hidden_states.repeat_interleave( + repeats=num_frames, dim=0) + """ + + # 2. pre-process + sample = sample.permute(0, 2, 1, 3, 4).reshape( + (sample.shape[0] * num_frames, -1) + sample.shape[3:] + ) + sample = self.conv_in(sample) + + controlnet_cond = self.controlnet_cond_embedding( + controlnet_cond, vq_gan=vq_gan, controlnet_mask=controlnet_mask + ) + + if num_frames > 1: + if self.gradient_checkpointing: + sample = transformer_g_c(self.transformer_in, sample, num_frames) + else: + sample = self.transformer_in( + sample, num_frames=num_frames, attention_mask=attention_mask + ).sample + + sample = self.merger( + sample * weight_control_sample, weight_control * controlnet_cond + ) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if ( + hasattr(downsample_block, "has_cross_attention") + and downsample_block.has_cross_attention + ): + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block( + hidden_states=sample, temb=emb, num_frames=num_frames + ) + + down_block_res_samples += res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + + # 5. Control net blocks + + controlnet_down_block_res_samples = () + + for down_block_res_sample, controlnet_block in zip( + down_block_res_samples, self.controlnet_down_blocks + ): + down_block_res_sample = controlnet_block(down_block_res_sample) + controlnet_down_block_res_samples = controlnet_down_block_res_samples + ( + down_block_res_sample, + ) + + down_block_res_samples = controlnet_down_block_res_samples + + mid_block_res_sample = self.controlnet_mid_block(sample) + + # 6. scaling + if guess_mode and not self.config.global_pool_conditions: + # 0.1 to 1.0 + scales = torch.logspace( + -1, 0, len(down_block_res_samples) + 1, device=sample.device + ) + + scales = scales * conditioning_scale + down_block_res_samples = [ + sample * scale for sample, scale in zip(down_block_res_samples, scales) + ] + mid_block_res_sample = mid_block_res_sample * scales[-1] # last one + else: + down_block_res_samples = [ + sample * conditioning_scale for sample in down_block_res_samples + ] + mid_block_res_sample = mid_block_res_sample * conditioning_scale + + if self.config.global_pool_conditions: + down_block_res_samples = [ + torch.mean(sample, dim=(2, 3), keepdim=True) + for sample in down_block_res_samples + ] + mid_block_res_sample = torch.mean( + mid_block_res_sample, dim=(2, 3), keepdim=True + ) + + if not return_dict: + return (down_block_res_samples, mid_block_res_sample) + + return ControlNetOutput( + down_block_res_samples=down_block_res_samples, + mid_block_res_sample=mid_block_res_sample, + ) + + +def zero_module(module, reset=True): + if reset: + for p in module.parameters(): + nn.init.zeros_(p) + return module diff --git a/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/cross_attention.py b/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/cross_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..06389a35dd71dbc4679b4efc4e304799932e92b2 --- /dev/null +++ b/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/cross_attention.py @@ -0,0 +1,30 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import Callable, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.utils.import_utils import is_xformers_available +# from t2v_enhanced.model.diffusers_conditional.models.controlnet.attention_processor import Attention + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + diff --git a/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/image_embedder.py b/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/image_embedder.py new file mode 100644 index 0000000000000000000000000000000000000000..6d231067e6d61b2694d3a8507fad9023ff1af407 --- /dev/null +++ b/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/image_embedder.py @@ -0,0 +1,210 @@ +import math +from typing import Any, Mapping +import torch +import torch.nn as nn +import kornia + +# import open_clip +from transformers import AutoImageProcessor, AutoModel +from transformers.models.bit.image_processing_bit import BitImageProcessor +from einops import rearrange, repeat + +# FFN +# from mamba_ssm import Mamba + + +class ImgEmbContextResampler(nn.Module): + + def __init__( + self, + inner_dim=1280, + cross_attention_dim=1024, + expansion_factor=16, + **kwargs, + ): + super().__init__() + self.context_embedding = nn.Sequential( + nn.Linear(cross_attention_dim, inner_dim), + nn.SiLU(), + nn.Linear(inner_dim, cross_attention_dim * expansion_factor), + ) + self.expansion_factor = expansion_factor + self.cross_attention_dim = cross_attention_dim + + def forward(self, x, batch_size=0): + if x.ndim == 2: + x = rearrange(x, "(B F) C -> B F C", B=batch_size) + assert x.ndim == 3 + x = torch.mean(x, dim=1, keepdim=True) + x = self.context_embedding(x) + x = x.view(-1, self.expansion_factor, self.cross_attention_dim) + return x + + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + self.embedding_dim = -1 + self.num_tokens = -1 + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class FrozenOpenCLIPImageEmbedder(AbstractEncoder): + """ + Uses the OpenCLIP vision transformer encoder for images + """ + + def __init__( + self, + arch="ViT-H-14", + version="laion2b_s32b_b79k", + device="cuda", + max_length=77, + freeze=True, + antialias=True, + ucg_rate=0.0, + unsqueeze_dim=False, + repeat_to_max_len=False, + num_image_crops=0, + output_tokens=False, + ): + super().__init__() + model, _, _ = open_clip.create_model_and_transforms( + arch, + device=torch.device("cpu"), + pretrained=version, + ) + del model.transformer + self.model = model + self.max_crops = num_image_crops + self.pad_to_max_len = self.max_crops > 0 + self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len) + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + + self.antialias = antialias + + self.register_buffer( + "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False + ) + self.register_buffer( + "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False + ) + self.ucg_rate = ucg_rate + self.unsqueeze_dim = unsqueeze_dim + self.stored_batch = None + self.model.visual.output_tokens = output_tokens + self.output_tokens = output_tokens + + def preprocess(self, x): + # normalize to [0,1] + x = kornia.geometry.resize( + x, + (224, 224), + interpolation="bicubic", + align_corners=True, + antialias=self.antialias, + ) + x = (x + 1.0) / 2.0 + # renormalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, image, no_dropout=False): + z = self.encode_with_vision_transformer(image) + tokens = None + if self.output_tokens: + z, tokens = z[0], z[1] + z = z.to(image.dtype) + if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0): + z = ( + torch.bernoulli( + (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device) + )[:, None] + * z + ) + if tokens is not None: + tokens = ( + expand_dims_like( + torch.bernoulli( + (1.0 - self.ucg_rate) + * torch.ones(tokens.shape[0], device=tokens.device) + ), + tokens, + ) + * tokens + ) + if self.unsqueeze_dim: + z = z[:, None, :] + if self.output_tokens: + assert not self.repeat_to_max_len + assert not self.pad_to_max_len + return tokens, z + if self.repeat_to_max_len: + if z.dim() == 2: + z_ = z[:, None, :] + else: + z_ = z + return repeat(z_, "b 1 d -> b n d", n=self.max_length), z + elif self.pad_to_max_len: + assert z.dim() == 3 + z_pad = torch.cat( + ( + z, + torch.zeros( + z.shape[0], + self.max_length - z.shape[1], + z.shape[2], + device=z.device, + ), + ), + 1, + ) + return z_pad, z_pad[:, 0, ...] + return z + + def encode_with_vision_transformer(self, img): + # if self.max_crops > 0: + # img = self.preprocess_by_cropping(img) + if img.dim() == 5: + assert self.max_crops == img.shape[1] + img = rearrange(img, "b n c h w -> (b n) c h w") + img = self.preprocess(img) + if not self.output_tokens: + assert not self.model.visual.output_tokens + x = self.model.visual(img) + tokens = None + else: + assert self.model.visual.output_tokens + x, tokens = self.model.visual(img) + if self.max_crops > 0: + x = rearrange(x, "(b n) d -> b n d", n=self.max_crops) + # drop out between 0 and all along the sequence axis + x = ( + torch.bernoulli( + (1.0 - self.ucg_rate) + * torch.ones(x.shape[0], x.shape[1], 1, device=x.device) + ) + * x + ) + if tokens is not None: + tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops) + print( + f"You are running very experimental token-concat in {self.__class__.__name__}. " + f"Check what you are doing, and then remove this message." + ) + if self.output_tokens: + return x, tokens + return x + + def encode(self, text): + return self(text) diff --git a/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/mask_generator.py b/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/mask_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..9e6e7244791f57eb3daa9d9cc97c3641df1d1979 --- /dev/null +++ b/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/mask_generator.py @@ -0,0 +1,36 @@ +from videogen_hub.pipelines.streamingt2v.model.pl_module_params_controlnet import ( + AttentionMaskParams, +) +import torch + + +class MaskGenerator: + + def __init__(self, params: AttentionMaskParams, num_frame_conditioning, num_frames): + self.params = params + self.num_frame_conditioning = num_frame_conditioning + self.num_frames = num_frames + + def get_mask(self, precision, device): + + params = self.params + if params.temporal_self_attention_only_on_conditioning: + with torch.no_grad(): + attention_mask = torch.zeros( + (1, self.num_frames, self.num_frames), + dtype=( + torch.float16 if precision.startswith("16") else torch.float32 + ), + device=device, + ) + for frame in range(self.num_frame_conditioning, self.num_frames): + attention_mask[:, frame, self.num_frame_conditioning :] = float( + "-inf" + ) + if params.temporal_self_attention_mask_included_itself: + attention_mask[:, frame, frame] = 0 + if params.temp_attend_on_uncond_include_past: + attention_mask[:, frame, :frame] = 0 + else: + attention_mask = None + return attention_mask diff --git a/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/pipeline_text_to_video_w_controlnet_synth.py b/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/pipeline_text_to_video_w_controlnet_synth.py new file mode 100644 index 0000000000000000000000000000000000000000..3a107321676a50f2cb9ae5bf185f9490d1e97e8f --- /dev/null +++ b/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/pipeline_text_to_video_w_controlnet_synth.py @@ -0,0 +1,925 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import PIL.Image +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer + +from diffusers.loaders import TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, UNet3DConditionModel +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + PIL_INTERPOLATION, + is_accelerate_available, + is_accelerate_version, + logging, + replace_example_docstring, +) +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.text_to_video_synthesis import TextToVideoSDPipelineOutput +from einops import rearrange + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import TextToVideoSDPipeline + >>> from diffusers.utils import export_to_video + + >>> pipe = TextToVideoSDPipeline.from_pretrained( + ... "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16" + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "Spiderman is surfing" + >>> video_frames = pipe(prompt).frames + >>> video_path = export_to_video(video_frames) + >>> video_path + ``` +""" + + +def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], output_type="list") -> List[np.ndarray]: + # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 + # reshape to ncfhw + mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1) + std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1) + # unnormalize back to [0,1] + video = video.mul_(std).add_(mean) + video.clamp_(0, 1) + # prepare the final outputs + i, c, f, h, w = video.shape + images = video.permute(2, 3, 0, 4, 1).reshape( + f, h, i * w, c + ) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c) + if output_type == "list": + # prepare a list of indvidual (consecutive frames) + images = images.unbind(dim=0) + images = [(image.cpu().numpy() * 255).astype("uint8") + for image in images] # f h w c + elif output_type == "pt": + pass + return images + + +class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin): + r""" + Pipeline for text-to-video generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Same as Stable Diffusion 2. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet3DConditionModel`]): Conditional U-Net architecture to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet3DConditionModel, + controlnet, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** ( + len(self.vae.config.block_out_channels) - 1) + + def prepare_image( + self, image, width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance, cfg_text_image=False, + ): + if not isinstance(image, torch.Tensor): + if isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + images = [] + + for image_ in image: + image_ = image_.convert("RGB") + image_ = image_.resize( + (width, height), resample=PIL_INTERPOLATION["lanczos"]) + image_ = np.array(image_) + image_ = image_[None, :] + images.append(image_) + + image = images + + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + image_vq_enc = self.vae.encode(rearrange( + image, "B F C W H -> (B F) C W H")).latent_dist.sample() * self.vae.config.scaling_factor + image_vq_enc = rearrange( + image_vq_enc, "(B F) C W H -> B F C W H", B=image_batch_size) + if do_classifier_free_guidance: + if cfg_text_image: + image = torch.cat([torch.zeros_like(image), image], dim=0) + else: + image = torch.cat([image] * 2) + # image_vq_enc = torch.cat([image_vq_enc] * 2) + + return image, image_vq_enc + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded + to GPU only when their specific submodule has its `forward` method called. Note that offloading happens on a + submodule basis. Memory savings are higher than with `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError( + "`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + # otherwise we don't see the memory savings (but they probably exist) + torch.cuda.empty_cache() + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + cpu_offload(cpu_offloaded_model, device) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError( + "`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + # otherwise we don't see the memory savings (but they probably exist) + torch.cuda.empty_cache() + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook( + cpu_offloaded_model, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + img_cond: Optional[torch.FloatTensor] = None, + img_cond_unc: Optional[torch.FloatTensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer( + prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1: -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to( + dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view( + bs_embed * num_images_per_prompt, seq_len, -1) + max_length = prompt_embeds.shape[1] + if img_cond is not None: + if img_cond.ndim == 2: + img_cond = img_cond.unsqueeze(1) + prompt_embeds = torch.cat([prompt_embeds, img_cond], dim=1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt( + uncond_tokens, self.tokenizer) + + # max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to( + dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat( + 1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1) + + if img_cond_unc is not None: + if img_cond_unc.ndim == 2: + img_cond_unc = img_cond_unc.unsqueeze(1) + negative_prompt_embeds = torch.cat( + [negative_prompt_embeds, img_cond_unc], dim=1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape( + batch_size * num_frames, channels, height, width) + + image = self.vae.decode(latents).sample + video = ( + image[None, :] + .reshape( + ( + batch_size, + num_frames, + -1, + ) + + image.shape[2:] + ) + .permute(0, 2, 1, 3, 4) + ) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature( + self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set( + inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError( + f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance( + callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError( + f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if hasattr(self, "noise_generator"): + latents = self.noise_generator.sample_noise( + shape=shape, generator=generator, device=device, dtype=dtype) + elif latents is None: + latents = randn_tensor( + shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def set_noise_generator(self, noise_generator): + if noise_generator is not None and noise_generator.mode != "vanilla": + self.noise_generator = noise_generator + + def reset_noise_generator_state(self): + if hasattr(self, "noise_generator") and hasattr(self.noise_generator, "reset_noise"): + self.noise_generator.reset_noise_generator_state() + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + # the image input for the controlnet branch + image: Optional[torch.FloatTensor] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_frames: int = 16, + num_inference_steps: int = 50, + guidance_scale: float = 9.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, + List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + callback: Optional[Callable[[ + int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + precision: str = "16", + mask_generator=None, + no_text_condition_control: bool = False, + weight_control_sample: float = 1.0, + use_controlnet_mask: bool = False, + skip_controlnet_branch: bool = False, + img_cond_resampler=None, + img_cond_encoder=None, + input_frames_conditioning=None, + cfg_text_image: bool = False, + use_of: bool = False, + ** kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated video. + num_frames (`int`, *optional*, defaults to 16): + The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds + amounts to 2 seconds of video. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality videos at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`, + usually at the expense of lower video quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the video generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. Latents should be of shape + `(batch_size, num_channel, num_frames, height, width)`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generate video. Choose between `torch.FloatTensor` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Examples: + + Returns: + [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated frames. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + num_images_per_prompt = 1 + controlnet_mask = None + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + # import pdb + # pdb.set_trace() + + if img_cond_resampler is not None and image is not None: + bsz = image.shape[0] + image_for_conditioniong = rearrange( + input_frames_conditioning, "B F C W H -> (B F) C W H") + image_enc = img_cond_encoder(image_for_conditioniong) + img_cond = img_cond_resampler(image_enc, batch_size=bsz) + image_enc_unc = img_cond_encoder( + torch.zeros_like(image_for_conditioniong)) + img_cond_unc = img_cond_resampler(image_enc_unc, batch_size=bsz) + else: + img_cond = None + img_cond_unc = None + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + img_cond=img_cond, + img_cond_unc=img_cond_unc + ) + skip_conditioning = image is None or skip_controlnet_branch + # import pdb + # pdb.set_trace() + if not skip_conditioning: + num_condition_frames = image.shape[1] + image, image_vq_enc = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + cfg_text_image=cfg_text_image, + ) + if len(image.shape) == 5: + image = rearrange(image, "B F C H W -> (B F) C H W") + if use_controlnet_mask: + # num_condition_frames = all possible frames, e.g. 16 + assert num_condition_frames == num_frames + image = rearrange( + image, "(B F) C H W -> B F C H W", F=num_condition_frames) + # image = torch.cat([image, image], dim=1) + controlnet_mask = torch.zeros( + (image.shape[0], num_frames), device=image.device, dtype=image.dtype) + # TODO HARDCODED number of frames! + controlnet_mask[:, :8] = 1.0 + image = rearrange(image, "B F C H W -> (B F) C H W") + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + of_channels = 2 if use_of else 0 + num_channels_ctrl = self.unet.config.in_channels + num_channels_latents = num_channels_ctrl + of_channels + if not skip_conditioning: + image_vq_enc = rearrange( + image_vq_enc, "B F C H W -> B C F H W ", F=num_condition_frames) + + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + if self.unet.concat: + image_latents = self.vae.encode(rearrange( + image, "B F C W H -> (B F) C W H")).latent_dist.sample() * self.vae.config.scaling_factor + image_latents = rearrange( + image_latents, "(B F) C W H -> B C F W H", B=latents.shape[0]) + image_shape = image_latents.shape + image_shape = [ax_dim for ax_dim in image_shape] + image_shape[2] = 16-image_shape[2] + image_latents = torch.cat([image_latents, torch.zeros( + image_shape, dtype=image_latents.dtype, device=image_latents.device)], dim=2) + controlnet_mask = torch.zeros( + image_latents.shape, device=image_latents.device, dtype=image_latents.dtype) + controlnet_mask[:, :, :8] = 1.0 + image_latents = image_latents * controlnet_mask + # torch.cat([latents, image_latents, controlnet_mask[:, :1]], dim=1) + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - \ + num_inference_steps * self.scheduler.order + + if mask_generator is not None: + attention_mask = mask_generator.get_mask( + device=latents.device, precision=precision) + else: + attention_mask = None + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat( + [latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t) + + if self.unet.concat: + latent_model_input = torch.cat([latent_model_input, image_latents.repeat( + 2, 1, 1, 1, 1), controlnet_mask[:, :1].repeat(2, 1, 1, 1, 1)], dim=1) + if not skip_conditioning: + down_block_res_samples, mid_block_res_sample = self.controlnet( + latent_model_input[:, :num_channels_ctrl], + t, + encoder_hidden_states=prompt_embeds if (not no_text_condition_control) else torch.stack([ + prompt_embeds[0], prompt_embeds[0]]), + controlnet_cond=image, + attention_mask=attention_mask, + vq_gan=self.vae, + weight_control_sample=weight_control_sample, + return_dict=False, + controlnet_mask=controlnet_mask, + ) + else: + down_block_res_samples = None + mid_block_res_sample = None + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + down_block_additional_residuals=[ + sample.to(dtype=latent_model_input.dtype) for sample in down_block_res_samples + ] if down_block_res_samples is not None else None, + mid_block_additional_residual=mid_block_res_sample.to( + dtype=latent_model_input.dtype) if mid_block_res_sample is not None else None, + fps=None, + + ).sample + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk( + 2) + noise_pred = noise_pred_uncond + guidance_scale * \ + (noise_pred_text - noise_pred_uncond) + + # reshape latents + bsz, channel, frames, width, height = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape( + bsz * frames, channel, width, height) + noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape( + bsz * frames, channel, width, height) + + # compute the previous noisy sample x_t -> x_t-1 + scheduler_step = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs) + latents = scheduler_step.prev_sample + + # reshape latents back + latents = latents[None, :].reshape( + bsz, frames, channel, width, height).permute(0, 2, 1, 3, 4) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + latents_video = latents[:, :num_channels_ctrl] + if of_channels > 0: + latents_of = latents[:, num_channels_ctrl:] + latents_of = rearrange(latents_of, "B C F W H -> (B F) C W H") + video_tensor = self.decode_latents(latents_video) + + if output_type == "pt": + video = video_tensor + elif output_type == "pt_t2v": + video = tensor2vid(video_tensor, output_type="pt") + video = rearrange(video, "f h w c -> f c h w") + elif output_type == "concat_image": + image_video = image.unsqueeze(2)[0:1].repeat([1, 1, 24, 1, 1]) + video_tensor_concat = torch.concat( + [image_video, video_tensor], dim=4) + video = tensor2vid(video_tensor_concat) + else: + video = tensor2vid(video_tensor) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + if of_channels == 0: + return video + else: + return video, latents_of + + return TextToVideoSDPipelineOutput(frames=video) diff --git a/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/processor.py b/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/processor.py new file mode 100644 index 0000000000000000000000000000000000000000..d9e53c39ee8aa19c50dd91e352d248c1964d32bf --- /dev/null +++ b/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/processor.py @@ -0,0 +1,332 @@ +from einops import repeat, rearrange +from typing import Callable, Optional, Union +from .attention_processor import Attention + +# from t2v_enhanced.model.diffusers_conditional.controldiffusers.models.attention import Attention +from diffusers.utils.import_utils import is_xformers_available +from videogen_hub.pipelines.streamingt2v.model.pl_module_params_controlnet import ( + AttentionMaskParams, +) +import torch +import torch.nn.functional as F + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +def set_use_memory_efficient_attention_xformers( + model, + num_frame_conditioning: int, + num_frames: int, + attention_mask_params: AttentionMaskParams, + valid: bool = True, + attention_op: Optional[Callable] = None, +) -> None: + # Recursively walk through all the children. + # Any children which exposes the set_use_memory_efficient_attention_xformers method + # gets the message + def fn_recursive_set_mem_eff(module: torch.nn.Module): + if hasattr(module, "set_processor"): + + module.set_processor( + XFormersAttnProcessor( + attention_op=attention_op, + num_frame_conditioning=num_frame_conditioning, + num_frames=num_frames, + attention_mask_params=attention_mask_params, + ) + ) + + for child in module.children(): + fn_recursive_set_mem_eff(child) + + for module in model.children(): + if isinstance(module, torch.nn.Module): + fn_recursive_set_mem_eff(module) + + +class XFormersAttnProcessor: + def __init__( + self, + attention_mask_params: AttentionMaskParams, + attention_op: Optional[Callable] = None, + num_frame_conditioning: int = None, + num_frames: int = None, + use_image_embedding: bool = False, + ): + self.attention_op = attention_op + self.num_frame_conditioning = num_frame_conditioning + self.num_frames = num_frames + self.temp_attend_on_neighborhood_of_condition_frames = ( + attention_mask_params.temp_attend_on_neighborhood_of_condition_frames + ) + self.spatial_attend_on_condition_frames = ( + attention_mask_params.spatial_attend_on_condition_frames + ) + self.use_image_embedding = use_image_embedding + + def __call__( + self, + attn: Attention, + hidden_states, + hidden_state_height=None, + hidden_state_width=None, + encoder_hidden_states=None, + attention_mask=None, + ): + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + + key_img = None + value_img = None + hidden_states_img = None + if attention_mask is not None: + attention_mask = repeat(attention_mask, "1 F D -> B F D", B=batch_size) + + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + + query = attn.to_q(hidden_states) + + is_cross_attention = encoder_hidden_states is not None + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states + ) + default_attention = not hasattr(attn, "is_spatial_attention") + if default_attention: + assert ( + not self.temp_attend_on_neighborhood_of_condition_frames + ), "special attention must be implemented with new interface" + assert ( + not self.spatial_attend_on_condition_frames + ), "special attention must be implemented with new interface" + is_spatial_attention = ( + attn.is_spatial_attention + if hasattr(attn, "is_spatial_attention") + else False + ) + use_image_embedding = ( + attn.use_image_embedding if hasattr(attn, "use_image_embedding") else False + ) + + if is_spatial_attention and use_image_embedding and attn.cross_attention_mode: + assert ( + not self.spatial_attend_on_condition_frames + ), "Not implemented together with image embedding" + + alpha = attn.alpha + encoder_hidden_states_txt = encoder_hidden_states[:, :77, :] + + encoder_hidden_states_mixed = attn.conv(encoder_hidden_states) + encoder_hidden_states_mixed = attn.conv_ln(encoder_hidden_states_mixed) + encoder_hidden_states = ( + encoder_hidden_states_txt + encoder_hidden_states_mixed * F.silu(alpha) + ) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + else: + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + if ( + not default_attention + and not is_spatial_attention + and self.temp_attend_on_neighborhood_of_condition_frames + and not attn.cross_attention_mode + ): + # normal attention + query_condition = query[:, : self.num_frame_conditioning] + query_condition = attn.head_to_batch_dim(query_condition).contiguous() + key_condition = key + value_condition = value + key_condition = attn.head_to_batch_dim(key_condition).contiguous() + value_condition = attn.head_to_batch_dim(value_condition).contiguous() + hidden_states_condition = xformers.ops.memory_efficient_attention( + query_condition, + key_condition, + value_condition, + attn_bias=None, + op=self.attention_op, + scale=attn.scale, + ) + hidden_states_condition = hidden_states_condition.to(query.dtype) + hidden_states_condition = attn.batch_to_head_dim(hidden_states_condition) + # + query_uncondition = query[:, self.num_frame_conditioning :] + + key = key[:, : self.num_frame_conditioning] + value = value[:, : self.num_frame_conditioning] + key = rearrange( + key, + "(B W H) F C -> B W H F C", + H=hidden_state_height, + W=hidden_state_width, + ) + value = rearrange( + value, + "(B W H) F C -> B W H F C", + H=hidden_state_height, + W=hidden_state_width, + ) + + keys = [] + values = [] + for shifts_width in [-1, 0, 1]: + for shifts_height in [-1, 0, 1]: + keys.append( + torch.roll( + key, shifts=(shifts_width, shifts_height), dims=(1, 2) + ) + ) + values.append( + torch.roll( + value, shifts=(shifts_width, shifts_height), dims=(1, 2) + ) + ) + key = rearrange(torch.cat(keys, dim=3), "B W H F C -> (B W H) F C") + value = rearrange(torch.cat(values, dim=3), "B W H F C -> (B W H) F C") + + query = attn.head_to_batch_dim(query_uncondition).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + hidden_states = xformers.ops.memory_efficient_attention( + query, + key, + value, + attn_bias=attention_mask, + op=self.attention_op, + scale=attn.scale, + ) + hidden_states = hidden_states.to(query.dtype) + hidden_states = attn.batch_to_head_dim(hidden_states) + hidden_states = torch.cat([hidden_states_condition, hidden_states], dim=1) + elif ( + not default_attention + and is_spatial_attention + and self.spatial_attend_on_condition_frames + and not attn.cross_attention_mode + ): + # (B F) W H C -> B F W H C + query_condition = rearrange( + query, "(B F) S C -> B F S C", F=self.num_frames + ) + query_condition = query_condition[:, : self.num_frame_conditioning] + query_condition = rearrange(query_condition, "B F S C -> (B F) S C") + query_condition = attn.head_to_batch_dim(query_condition).contiguous() + + key_condition = rearrange(key, "(B F) S C -> B F S C", F=self.num_frames) + key_condition = key_condition[:, : self.num_frame_conditioning] + key_condition = rearrange(key_condition, "B F S C -> (B F) S C") + + value_condition = rearrange( + value, "(B F) S C -> B F S C", F=self.num_frames + ) + value_condition = value_condition[:, : self.num_frame_conditioning] + value_condition = rearrange(value_condition, "B F S C -> (B F) S C") + + key_condition = attn.head_to_batch_dim(key_condition).contiguous() + value_condition = attn.head_to_batch_dim(value_condition).contiguous() + hidden_states_condition = xformers.ops.memory_efficient_attention( + query_condition, + key_condition, + value_condition, + attn_bias=None, + op=self.attention_op, + scale=attn.scale, + ) + hidden_states_condition = hidden_states_condition.to(query.dtype) + hidden_states_condition = attn.batch_to_head_dim(hidden_states_condition) + + query_uncondition = rearrange( + query, "(B F) S C -> B F S C", F=self.num_frames + ) + query_uncondition = query_uncondition[:, self.num_frame_conditioning :] + key_uncondition = rearrange(key, "(B F) S C -> B F S C", F=self.num_frames) + value_uncondition = rearrange( + value, "(B F) S C -> B F S C", F=self.num_frames + ) + key_uncondition = key_uncondition[:, self.num_frame_conditioning - 1, None] + value_uncondition = value_uncondition[ + :, self.num_frame_conditioning - 1, None + ] + # if self.trainer.training: + # import pdb + # pdb.set_trace() + # print("now") + query_uncondition = rearrange(query_uncondition, "B F S C -> (B F) S C") + key_uncondition = repeat( + rearrange(key_uncondition, "B F S C -> B (F S) C"), + "B T C -> (B F) T C", + F=self.num_frames - self.num_frame_conditioning, + ) + value_uncondition = repeat( + rearrange(value_uncondition, "B F S C -> B (F S) C"), + "B T C -> (B F) T C", + F=self.num_frames - self.num_frame_conditioning, + ) + query_uncondition = attn.head_to_batch_dim(query_uncondition).contiguous() + key_uncondition = attn.head_to_batch_dim(key_uncondition).contiguous() + value_uncondition = attn.head_to_batch_dim(value_uncondition).contiguous() + hidden_states_uncondition = xformers.ops.memory_efficient_attention( + query_uncondition, + key_uncondition, + value_uncondition, + attn_bias=None, + op=self.attention_op, + scale=attn.scale, + ) + hidden_states_uncondition = hidden_states_uncondition.to(query.dtype) + hidden_states_uncondition = attn.batch_to_head_dim( + hidden_states_uncondition + ) + hidden_states = torch.cat( + [ + rearrange( + hidden_states_condition, + "(B F) S C -> B F S C", + F=self.num_frame_conditioning, + ), + rearrange( + hidden_states_uncondition, + "(B F) S C -> B F S C", + F=self.num_frames - self.num_frame_conditioning, + ), + ], + dim=1, + ) + hidden_states = rearrange(hidden_states, "B F S C -> (B F) S C") + else: + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + hidden_states = xformers.ops.memory_efficient_attention( + query, + key, + value, + attn_bias=attention_mask, + op=self.attention_op, + scale=attn.scale, + ) + + hidden_states = hidden_states.to(query.dtype) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + return hidden_states diff --git a/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/transformer_2d.py b/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/transformer_2d.py new file mode 100644 index 0000000000000000000000000000000000000000..832d2b5224234530cf057083e1636b29fcafb1de --- /dev/null +++ b/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/transformer_2d.py @@ -0,0 +1,373 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.embeddings import ImagePositionalEmbeddings +from diffusers.utils import BaseOutput, deprecate +from .attention import BasicTransformerBlock +from diffusers.models.embeddings import PatchEmbed +from diffusers.models.modeling_utils import ModelMixin + + +@dataclass +class Transformer2DModelOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions + for the unnoised latent pixels. + """ + + sample: torch.FloatTensor + + +class Transformer2DModel(ModelMixin, ConfigMixin): + """ + Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual + embeddings) inputs. + + When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard + transformer action. Finally, reshape to image. + + When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional + embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict + classes of unnoised image. + + Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised + image do not contain a prediction for the masked pixel as the unnoised image cannot be masked. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input and output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + num_vector_embeds (`int`, *optional*): + Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + The number of diffusion steps used during training. Note that this is fixed at training time as it is used + to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for + up to but not more than steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, + use_image_embedding: bool = False, + unet_params=None, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = (in_channels is not None) and (patch_size is None) + self.is_input_vectorized = num_vector_embeds is not None + self.is_input_patches = in_channels is not None and patch_size is not None + + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = ( + f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" + " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." + " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" + " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" + " would be very nice if you could open a Pull request for the `transformer/config.json` file" + ) + deprecate( + "norm_type!=num_embeds_ada_norm", + "1.0.0", + deprecation_message, + standard_warn=False, + ) + norm_type = "ada_norm" + + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is None." + ) + elif self.is_input_vectorized and self.is_input_patches: + raise ValueError( + f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" + " sure that either `num_vector_embeds` or `num_patches` is None." + ) + elif ( + not self.is_input_continuous + and not self.is_input_vectorized + and not self.is_input_patches + ): + raise ValueError( + f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" + f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." + ) + + # 2. Define input layers + if self.is_input_continuous: + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm( + num_groups=norm_num_groups, + num_channels=in_channels, + eps=1e-6, + affine=True, + ) + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0 + ) + elif self.is_input_vectorized: + assert ( + sample_size is not None + ), "Transformer2DModel over discrete input must provide sample_size" + assert ( + num_vector_embeds is not None + ), "Transformer2DModel over discrete input must provide num_embed" + + self.height = sample_size + self.width = sample_size + self.num_vector_embeds = num_vector_embeds + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=num_vector_embeds, + embed_dim=inner_dim, + height=self.height, + width=self.width, + ) + elif self.is_input_patches: + assert ( + sample_size is not None + ), "Transformer2DModel over patched input must provide sample_size" + + self.height = sample_size + self.width = sample_size + + self.patch_size = patch_size + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + ) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + is_spatial_attention=True, + use_image_embedding=use_image_embedding, + unet_params=unet_params, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + if self.is_input_continuous: + # TODO: should use out_channels for continuous projections + if use_linear_projection: + self.proj_out = nn.Linear(inner_dim, in_channels) + else: + self.proj_out = nn.Conv2d( + inner_dim, in_channels, kernel_size=1, stride=1, padding=0 + ) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + elif self.is_input_patches: + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) + self.proj_out_2 = nn.Linear( + inner_dim, patch_size * patch_size * self.out_channels + ) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + class_labels=None, + cross_attention_kwargs=None, + return_dict: bool = True, + ): + """ + Args: + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input + hidden_states + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels + conditioning. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: + [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # 1. Input + if self.is_input_continuous: + batch, _, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( + batch, height * width, inner_dim + ) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( + batch, height * width, inner_dim + ) + hidden_states = self.proj_in(hidden_states) + elif self.is_input_vectorized: + hidden_states = self.latent_image_embedding(hidden_states) + elif self.is_input_patches: + hidden_states = self.pos_embed(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + if self.is_input_continuous: + if not self.use_linear_projection: + hidden_states = ( + hidden_states.reshape(batch, height, width, inner_dim) + .permute(0, 3, 1, 2) + .contiguous() + ) + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch, height, width, inner_dim) + .permute(0, 3, 1, 2) + .contiguous() + ) + + output = hidden_states + residual + elif self.is_input_vectorized: + hidden_states = self.norm_out(hidden_states) + logits = self.out(hidden_states) + # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) + logits = logits.permute(0, 2, 1) + + # log(p(x_0)) + output = F.log_softmax(logits.double(), dim=1).float() + elif self.is_input_patches: + # TODO: cleanup! + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = ( + self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + ) + hidden_states = self.proj_out_2(hidden_states) + + # unpatchify + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=( + -1, + height, + width, + self.patch_size, + self.patch_size, + self.out_channels, + ) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=( + -1, + self.out_channels, + height * self.patch_size, + width * self.patch_size, + ) + ) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/transformer_temporal.py b/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/transformer_temporal.py new file mode 100644 index 0000000000000000000000000000000000000000..fa4cb58285bbf073cb28a18569fa9d4cc73b89ba --- /dev/null +++ b/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/transformer_temporal.py @@ -0,0 +1,193 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Optional + +import torch +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput + +# from diffusers.models.attention import BasicTransformerBlock +# from t2v_enhanced.model.diffusers_conditional.models.attention import BasicTransformerBlock +from diffusers.models.modeling_utils import ModelMixin +from .attention import BasicTransformerBlock + + +@dataclass +class TransformerTemporalModelOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`) + Hidden states conditioned on `encoder_hidden_states` input. + """ + + sample: torch.FloatTensor + + +class TransformerTemporalModel(ModelMixin, ConfigMixin): + """ + Transformer model for video-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input and output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + double_self_attention (`bool`, *optional*): + Configure if each TransformerBlock should contain two self-attention layers + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + activation_fn: str = "geglu", + norm_elementwise_affine: bool = True, + double_self_attention: bool = True, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm( + num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) + self.proj_in = nn.Linear(in_channels, inner_dim) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + double_self_attention=double_self_attention, + norm_elementwise_affine=norm_elementwise_affine, + is_spatial_attention=False, + ) + for d in range(num_layers) + ] + ) + + self.proj_out = nn.Linear(inner_dim, in_channels) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + class_labels=None, + num_frames=1, + cross_attention_kwargs=None, + return_dict: bool = True, + attention_mask=None, + ): + """ + Args: + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input + hidden_states + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels + conditioning. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.transformer_2d.TransformerTemporalModelOutput`] or `tuple`: + [`~models.transformer_2d.TransformerTemporalModelOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is the sample tensor. + """ + # 1. Input + batch_frames, channel, height, width = hidden_states.shape + batch_size = batch_frames // num_frames + + residual = hidden_states + + hidden_states = hidden_states[None, :].reshape( + batch_size, num_frames, channel, height, width + ) + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape( + batch_size * height * width, num_frames, channel + ) + + hidden_states = self.proj_in(hidden_states) + if cross_attention_kwargs is None: + cross_attention_kwargs = {} + cross_attention_kwargs["hidden_state_height"] = height + cross_attention_kwargs["hidden_state_width"] = width + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + attention_mask=attention_mask, + encoder_attention_mask=attention_mask, + ) + + # 3. Output + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states[None, None, :] + .reshape(batch_size, height, width, channel, num_frames) + .permute(0, 3, 4, 1, 2) + .contiguous() + ) + hidden_states = hidden_states.reshape(batch_frames, channel, height, width) + + output = hidden_states + residual + + if not return_dict: + return (output,) + + return TransformerTemporalModelOutput(sample=output) diff --git a/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/transformer_temporal_crossattention.py b/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/transformer_temporal_crossattention.py new file mode 100644 index 0000000000000000000000000000000000000000..53361f056cf4c2e110792b355e4512bf2dce15cf --- /dev/null +++ b/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/transformer_temporal_crossattention.py @@ -0,0 +1,182 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Optional + +import torch +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput + +from diffusers.models.modeling_utils import ModelMixin + + +@dataclass +class TransformerTemporalModelOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`) + Hidden states conditioned on `encoder_hidden_states` input. + """ + + sample: torch.FloatTensor + + +class TransformerTemporalModel(ModelMixin, ConfigMixin): + """ + Transformer model for video-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input and output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + double_self_attention (`bool`, *optional*): + Configure if each TransformerBlock should contain two self-attention layers + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + activation_fn: str = "geglu", + norm_elementwise_affine: bool = True, + double_self_attention: bool = True, + ): + super().__init__() + + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm( + num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + self.proj_in = nn.Linear(in_channels, inner_dim) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + double_self_attention=double_self_attention, + norm_elementwise_affine=norm_elementwise_affine, + only_cross_attention=True, + ) + for d in range(num_layers) + ] + ) + + self.proj_out = nn.Linear(inner_dim, in_channels) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + class_labels=None, + num_frames=1, + cross_attention_kwargs=None, + return_dict: bool = True, + ): + """ + Args: + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input + hidden_states + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels + conditioning. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.transformer_2d.TransformerTemporalModelOutput`] or `tuple`: + [`~models.transformer_2d.TransformerTemporalModelOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is the sample tensor. + """ + # 1. Input + batch_frames, channel, height, width = hidden_states.shape + batch_size = batch_frames // num_frames + + residual = hidden_states + + hidden_states = hidden_states[None, :].reshape( + batch_size, num_frames, channel, height, width) + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape( + batch_size * height * width, num_frames, channel) + + hidden_states = self.proj_in(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states[None, None, :] + .reshape(batch_size, height, width, channel, num_frames) + .permute(0, 3, 4, 1, 2) + .contiguous() + ) + hidden_states = hidden_states.reshape( + batch_frames, channel, height, width) + + output = hidden_states + residual + + if not return_dict: + return (output,) + + return TransformerTemporalModelOutput(sample=output) diff --git a/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/unet_3d_blocks.py b/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/unet_3d_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..20a61fa34e6340efe4e72b339ef4933ece01da06 --- /dev/null +++ b/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/unet_3d_blocks.py @@ -0,0 +1,929 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.utils.checkpoint as checkpoint +from torch import nn +from diffusers.models.resnet import ( + Downsample2D, + ResnetBlock2D, + TemporalConvLayer, + Upsample2D, +) + +# from diffusers.models.transformer_2d import Transformer2DModel +from .transformer_2d import Transformer2DModel + +# from diffusers.models.transformer_temporal import TransformerTemporalModel +from .transformer_temporal import TransformerTemporalModel + + +# Assign gradient checkpoint function to simple variable for readability. +g_c = checkpoint.checkpoint + + +def is_video(num_frames, only_video=True): + if num_frames == 1 and not only_video: + return False + return num_frames > 1 + + +def custom_checkpoint(module, mode=None): + if mode == None: + raise ValueError("Mode for gradient checkpointing cannot be none.") + + custom_forward = None + + if mode == "resnet": + + def custom_forward(hidden_states, temb): + inputs = module(hidden_states, temb) + return inputs + + if mode == "attn": + + def custom_forward( + hidden_states, + encoder_hidden_states=None, + cross_attention_kwargs=None, + attention_mask=None, + ): + inputs = module( + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + attention_mask, + ) + return inputs.sample + + if mode == "temp": + # If inputs are not None, we can assume that this was a single image. + # Otherwise, do temporal convolutions / attention. + def custom_forward(hidden_states, num_frames=None): + if not is_video(num_frames): + return hidden_states + else: + inputs = module(hidden_states, num_frames=num_frames) + if isinstance(module, TransformerTemporalModel): + return inputs.sample + else: + return inputs + + return custom_forward + + +def transformer_g_c(transformer, sample, num_frames): + sample = g_c( + custom_checkpoint(transformer, mode="temp"), + sample, + num_frames, + use_reentrant=False, + ) + return sample + + +def cross_attn_g_c( + attn, + temp_attn, + resnet, + temp_conv, + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + temb, + num_frames, + inverse_temp=False, + attention_mask=None, +): + + def ordered_g_c(idx): + + # Self and CrossAttention + if idx == 0: + return g_c( + custom_checkpoint(attn, mode="attn"), + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + attention_mask, + use_reentrant=False, + ) + + # Temporal Self and CrossAttention + if idx == 1: + return g_c( + custom_checkpoint(temp_attn, mode="temp"), + hidden_states, + num_frames, + use_reentrant=False, + ) + + # Resnets + if idx == 2: + return g_c( + custom_checkpoint(resnet, mode="resnet"), + hidden_states, + temb, + use_reentrant=False, + ) + + # Temporal Convolutions + if idx == 3: + return g_c( + custom_checkpoint(temp_conv, mode="temp"), + hidden_states, + num_frames, + use_reentrant=False, + ) + + # Here we call the function depending on the order in which they are called. + # For some layers, the orders are different, so we access the appropriate one by index. + + if not inverse_temp: + for idx in [0, 1, 2, 3]: + hidden_states = ordered_g_c(idx) + else: + for idx in [2, 3, 0, 1]: + hidden_states = ordered_g_c(idx) + + return hidden_states + + +def up_down_g_c(resnet, temp_conv, hidden_states, temb, num_frames): + hidden_states = g_c( + custom_checkpoint(resnet, mode="resnet"), + hidden_states, + temb, + use_reentrant=False, + ) + hidden_states = g_c( + custom_checkpoint(temp_conv, mode="temp"), + hidden_states, + num_frames, + use_reentrant=False, + ) + return hidden_states + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=True, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + use_image_embedding=False, + unet_params=None, +): + if down_block_type == "DownBlock3D": + return DownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "CrossAttnDownBlock3D": + if cross_attention_dim is None: + raise ValueError( + "cross_attention_dim must be specified for CrossAttnDownBlock3D" + ) + return CrossAttnDownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_image_embedding=use_image_embedding, + unet_params=unet_params, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=True, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + use_image_embedding=False, + unet_params=None, +): + if up_block_type == "UpBlock3D": + return UpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "CrossAttnUpBlock3D": + if cross_attention_dim is None: + raise ValueError( + "cross_attention_dim must be specified for CrossAttnUpBlock3D" + ) + return CrossAttnUpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_image_embedding=use_image_embedding, + unet_params=unet_params, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock3DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=True, + upcast_attention=False, + use_image_embedding=False, + unet_params=None, + ): + super().__init__() + self.gradient_checkpointing = False + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = ( + resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + ) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + temp_convs = [TemporalConvLayer(in_channels, in_channels, dropout=0.1)] + attentions = [] + temp_attentions = [] + + for _ in range(num_layers): + attentions.append( + Transformer2DModel( + in_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + use_image_embedding=use_image_embedding, + unet_params=unet_params, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + in_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append(TemporalConvLayer(in_channels, in_channels, dropout=0.1)) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + num_frames=1, + cross_attention_kwargs=None, + ): + if self.gradient_checkpointing: + hidden_states = up_down_g_c( + self.resnets[0], self.temp_convs[0], hidden_states, temb, num_frames + ) + else: + hidden_states = self.resnets[0](hidden_states, temb) + hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames) + + for attn, temp_attn, resnet, temp_conv in zip( + self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:] + ): + if self.gradient_checkpointing: + hidden_states = cross_attn_g_c( + attn, + temp_attn, + resnet, + temp_conv, + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + temb, + num_frames, + ) + else: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + if num_frames > 1: + hidden_states = temp_attn( + hidden_states, + num_frames=num_frames, + attention_mask=attention_mask, + ).sample + + hidden_states = resnet(hidden_states, temb) + + if num_frames > 1: + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + return hidden_states + + +class CrossAttnDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + use_image_embedding=False, + unet_params=None, + ): + super().__init__() + resnets = [] + attentions = [] + temp_attentions = [] + temp_convs = [] + + self.gradient_checkpointing = False + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer(out_channels, out_channels, dropout=0.1) + ) + attentions.append( + Transformer2DModel( + out_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_image_embedding=use_image_embedding, + unet_params=unet_params, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + out_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + num_frames=1, + cross_attention_kwargs=None, + ): + # TODO(Patrick, William) - attention mask is not used + output_states = () + layer_idx = 0 + + for resnet, temp_conv, attn, temp_attn in zip( + self.resnets, self.temp_convs, self.attentions, self.temp_attentions + ): + if self.gradient_checkpointing: + hidden_states = cross_attn_g_c( + attn, + temp_attn, + resnet, + temp_conv, + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + temb, + num_frames, + inverse_temp=True, + ) + else: + hidden_states = resnet(hidden_states, temb) + if num_frames > 1: + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + if num_frames > 1: + hidden_states = temp_attn( + hidden_states, + num_frames=num_frames, + attention_mask=attention_mask, + ).sample + layer_idx += 1 + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + temp_convs = [] + + self.gradient_checkpointing = False + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer(out_channels, out_channels, dropout=0.1) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states, temb=None, num_frames=1): + output_states = () + + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + if self.gradient_checkpointing: + hidden_states = up_down_g_c( + resnet, temp_conv, hidden_states, temb, num_frames + ) + else: + hidden_states = resnet(hidden_states, temb) + + if num_frames > 1: + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + use_image_embedding=False, + unet_params=None, + ): + super().__init__() + resnets = [] + temp_convs = [] + attentions = [] + temp_attentions = [] + + self.gradient_checkpointing = False + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer(out_channels, out_channels, dropout=0.1) + ) + attentions.append( + Transformer2DModel( + out_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_image_embedding=use_image_embedding, + unet_params=unet_params, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + out_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + self.upsamplers = None + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + attention_mask=None, + num_frames=1, + cross_attention_kwargs=None, + ): + # TODO(Patrick, William) - attention mask is not used + output_states = () + for resnet, temp_conv, attn, temp_attn in zip( + self.resnets, self.temp_convs, self.attentions, self.temp_attentions + ): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.gradient_checkpointing: + hidden_states = cross_attn_g_c( + attn, + temp_attn, + resnet, + temp_conv, + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + temb, + num_frames, + inverse_temp=True, + ) + else: + hidden_states = resnet(hidden_states, temb) + + if num_frames > 1: + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + if num_frames > 1: + hidden_states = temp_attn( + hidden_states, + num_frames=num_frames, + attention_mask=attention_mask, + ).sample + output_states += (hidden_states,) + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + output_states += (hidden_states,) + + return hidden_states, output_states + + +class UpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + temp_convs = [] + self.gradient_checkpointing = False + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer(out_channels, out_channels, dropout=0.1) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + self.upsamplers = None + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + upsample_size=None, + num_frames=1, + ): + output_states = () + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.gradient_checkpointing: + hidden_states = up_down_g_c( + resnet, temp_conv, hidden_states, temb, num_frames + ) + else: + hidden_states = resnet(hidden_states, temb) + + if num_frames > 1: + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + output_states += (hidden_states,) + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + output_states += (hidden_states,) + + return hidden_states, output_states diff --git a/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/unet_3d_condition.py b/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/unet_3d_condition.py new file mode 100644 index 0000000000000000000000000000000000000000..a8261fba9e73a5a7bcbe58c259fe14d806c43f58 --- /dev/null +++ b/src/videogen_hub/pipelines/streamingt2v/model/diffusers_conditional/models/controlnet/unet_3d_condition.py @@ -0,0 +1,708 @@ +# Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved. +# Copyright 2023 The ModelScope Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput, logging +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.models.modeling_utils import ModelMixin + +# from diffusers.models.transformer_temporal import TransformerTemporalModel +from .transformer_temporal import TransformerTemporalModel +from .unet_3d_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, + transformer_g_c, +) +from .conditioning import ConditionalModel +from einops import rearrange +from videogen_hub.pipelines.streamingt2v.model.layers.conv_channel_extension import ( + Conv2D_ExtendedChannels, +) + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNet3DConditionOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`): + Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.FloatTensor + + +class UNet3DConditionModel(ModelMixin, ConfigMixin): + r""" + UNet3DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep + and returns sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the models (such as downloading or saving, etc.) + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): + The tuple of upsample blocks to use. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, it will skip the normalization and activation layers in post-processing + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + up_block_types: Tuple[str] = ( + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + ), + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1024, + attention_head_dim: Union[int, Tuple[int]] = 64, + merging_mode: str = "addition", + use_image_embedding: bool = False, + use_fps_conditioning: bool = False, + unet_params=None, + ): + super().__init__() + channel_expansion = unet_params.use_of + self.concat = unet_params.concat + self.use_image_tokens = unet_params.use_image_tokens_main + self.image_encoder_name = type(unet_params.image_encoder).__name__ + self.use_image_embedding = use_image_embedding + self.sample_size = sample_size + self.gradient_checkpointing = False + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len( + down_block_types + ): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + # input + conv_in_kernel = 3 + conv_out_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + """ + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + """ + self.conv_in = Conv2D_ExtendedChannels( + in_channels, + block_out_channels[0], + kernel_size=conv_in_kernel, + padding=conv_in_padding, + in_channel_extension=5 if self.concat else 0, + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps(block_out_channels[0], True, 0) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + ) + self.use_fps_conditioning = use_fps_conditioning + if use_fps_conditioning: + fps_embed_dim = block_out_channels[0] * 4 + fps_input_dim = block_out_channels[0] + self.fps_embedding = TimestepEmbedding( + fps_input_dim, fps_embed_dim, act_fn=act_fn + ) + self.fps_proj = Timesteps(block_out_channels[0], True, 0) + + self.transformer_in = TransformerTemporalModel( + num_attention_heads=8, + attention_head_dim=attention_head_dim, + in_channels=block_out_channels[0], + num_layers=1, + ) + + # class embedding + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + self.merging_mode = merging_mode + print("self.merging_mode", self.merging_mode) + if self.merging_mode.startswith("attention"): + self.cross_attention_merger_down_blocks = nn.ModuleList([]) + self.cross_attention_merger_mid_block = nn.ModuleList([]) + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=False, + use_image_embedding=use_image_embedding, + unet_params=unet_params, + ) + self.down_blocks.append(down_block) + + if self.merging_mode.startswith("attention"): + for idx in range(3): + self.cross_attention_merger_down_blocks.append( + ConditionalModel( + input_channels=( + input_channel if idx == 0 else output_channel + ), + conditional_model=self.merging_mode.split("attention_")[1], + ) + ) + + # mid + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=False, + use_image_embedding=use_image_embedding, + unet_params=unet_params, + ) + if self.merging_mode.startswith("attention"): + self.cross_attention_merger_mid_block = ConditionalModel( + input_channels=block_out_channels[-1], + conditional_model=self.merging_mode.split("attention_")[1], + ) + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[ + min(i + 1, len(block_out_channels) - 1) + ] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=False, + use_image_embedding=use_image_embedding, + unet_params=unet_params, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], + num_groups=norm_num_groups, + eps=norm_eps, + ) + self.conv_act = nn.SiLU() + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + """ + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + """ + self.conv_out = Conv2D_ExtendedChannels( + block_out_channels[0], + out_channels, + kernel_size=conv_out_kernel, + padding=conv_out_padding, + out_channel_extension=2 if channel_expansion else 0, + ) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + + num_slicable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_slicable_layers * [1] + + slice_size = ( + num_slicable_layers * [slice_size] + if not isinstance(slice_size, list) + else slice_size + ) + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice( + module: torch.nn.Module, slice_size: List[int] + ): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, value=False): + self.gradient_checkpointing = value + self.mid_block.gradient_checkpointing = value + for module in self.down_blocks + self.up_blocks: + if isinstance( + module, + (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D), + ): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + fps: Optional[torch.Tensor] = None, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet3DConditionOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, num_frames, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet3DConditionOutput`] instead of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Returns: + [`~models.unet_2d_condition.UNet3DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + """ + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + """ + debug = False + if self.use_fps_conditioning: + + if torch.is_tensor(fps): + assert (fps > -1).all(), "FPS not set" + if len(fps.shape) == 0: + fps = fps[None].to(sample.device) + else: + assert fps > -1, "FPS not set" + is_mps = sample.device.type == "mps" + if isinstance(fps, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + fps = torch.tensor([fps], dtype=dtype, device=sample.device) + fps = fps.expand(sample.shape[0]) + fps_proj = self.fps_proj(fps) + fps_proj = fps_proj.to(dtype=self.dtype) + fps_emb = self.fps_embedding(fps_proj) + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + num_frames = sample.shape[2] + timesteps = timesteps.expand(sample.shape[0]) + batch_size = sample.shape[0] + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + + emb = emb.repeat_interleave(repeats=num_frames, dim=0) + if self.use_fps_conditioning: + fps_emb = fps_emb.repeat_interleave(repeats=num_frames, dim=0) + emb = emb + fps_emb + + if not self.use_image_tokens and encoder_hidden_states.shape[1] > 77: + encoder_hidden_states = encoder_hidden_states[:, :77] + # print(f"MAIN with tokens = {encoder_hidden_states.shape[1]}") + if encoder_hidden_states.shape[1] > 77: + # assert ( + # encoder_hidden_states.shape[1]-77) % num_frames == 0, f"Encoder shape {encoder_hidden_states.shape}. Num frames = {num_frames}" + context_text, context_img = ( + encoder_hidden_states[:, :77, :], + encoder_hidden_states[:, 77:, :], + ) + context_text = context_text.repeat_interleave(repeats=num_frames, dim=0) + + if self.image_encoder_name == "FrozenOpenCLIPImageEmbedder": + context_img = context_img.repeat_interleave(repeats=num_frames, dim=0) + else: + context_img = rearrange( + context_img, "b (t l) c -> (b t) l c", t=num_frames + ) + + encoder_hidden_states = torch.cat([context_text, context_img], dim=1) + else: + encoder_hidden_states = encoder_hidden_states.repeat_interleave( + repeats=num_frames, dim=0 + ) + + # 2. pre-process + sample = sample.permute(0, 2, 1, 3, 4).reshape( + (sample.shape[0] * num_frames, -1) + sample.shape[3:] + ) + sample = self.conv_in(sample) + + if num_frames > 1: + if self.gradient_checkpointing: + sample = transformer_g_c(self.transformer_in, sample, num_frames) + else: + sample = self.transformer_in( + sample, num_frames=num_frames, attention_mask=attention_mask + ).sample + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if ( + hasattr(downsample_block, "has_cross_attention") + and downsample_block.has_cross_attention + ): + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block( + hidden_states=sample, temb=emb, num_frames=num_frames + ) + + down_block_res_samples += res_samples + + if down_block_additional_residuals is not None: + new_down_block_res_samples = () + + if self.merging_mode == "addition": + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = ( + down_block_res_sample + down_block_additional_residual + ) + new_down_block_res_samples += (down_block_res_sample,) + elif self.merging_mode.startswith("attention"): + for ( + down_block_res_sample, + down_block_additional_residual, + merger, + ) in zip( + down_block_res_samples, + down_block_additional_residuals, + self.cross_attention_merger_down_blocks, + ): + + down_block_res_sample = merger( + rearrange( + down_block_res_sample, + "(B F) C H W -> B F C H W", + B=batch_size, + ), + rearrange( + down_block_additional_residual, + "(B F) C H W -> B F C H W", + B=batch_size, + ), + ) + down_block_res_sample = rearrange( + down_block_res_sample, "B F C H W -> (B F) C H W" + ) + new_down_block_res_samples += (down_block_res_sample,) + elif self.merging_mode == "overwrite": + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_additional_residual + new_down_block_res_samples += (down_block_res_sample,) + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + + if mid_block_additional_residual is not None: + if self.merging_mode == "addition": + sample = sample + mid_block_additional_residual + elif self.merging_mode == "overwrite": + sample = sample + mid_block_additional_residual + elif self.merging_mode.startswith("attention"): + sample = self.cross_attention_merger_mid_block( + rearrange(sample, "(B F) C H W -> B F C H W", B=batch_size), + rearrange( + mid_block_additional_residual, + "(B F) C H W -> B F C H W", + B=batch_size, + ), + ) + sample = rearrange(sample, "B F C H W -> (B F) C H W") + + if debug: + upblockout = (sample,) + # 5. up + # import pdb + # pdb.set_trace() + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[ + : -len(upsample_block.resnets) + ] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if ( + hasattr(upsample_block, "has_cross_attention") + and upsample_block.has_cross_attention + ): + sample, output_states = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, output_states = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + num_frames=num_frames, + ) + if debug: + upblockout += output_states + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + + sample = self.conv_out(sample) + + # reshape to (batch, channel, framerate, width, height) + sample = ( + sample[None, :] + .reshape((-1, num_frames) + sample.shape[1:]) + .permute(0, 2, 1, 3, 4) + ) + + if not return_dict: + return (sample,) + + return UNet3DConditionOutput(sample=sample) diff --git a/src/videogen_hub/pipelines/streamingt2v/model/flags.py b/src/videogen_hub/pipelines/streamingt2v/model/flags.py new file mode 100644 index 0000000000000000000000000000000000000000..ff4edac51898752e3a1069eb93ad23ea8b7c7b0a --- /dev/null +++ b/src/videogen_hub/pipelines/streamingt2v/model/flags.py @@ -0,0 +1 @@ +TORCH_DISTRIBUTED_DEBUG = DETAIL diff --git a/src/videogen_hub/pipelines/streamingt2v/model/layers/__init__.py b/src/videogen_hub/pipelines/streamingt2v/model/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/streamingt2v/model/layers/conv_channel_extension.py b/src/videogen_hub/pipelines/streamingt2v/model/layers/conv_channel_extension.py new file mode 100644 index 0000000000000000000000000000000000000000..bed009c21f5958758d93af9af397754b9fab7a6f --- /dev/null +++ b/src/videogen_hub/pipelines/streamingt2v/model/layers/conv_channel_extension.py @@ -0,0 +1,143 @@ +import torch +import torch.nn as nn +from typing import Union +from torch.nn.common_types import _size_2_t + + +class Conv2D_SubChannels(nn.Conv2d): + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: Union[str, _size_2_t] = 0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', + device=None, + dtype=None, + ) -> None: + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, device, dtype) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + + if prefix+"weight" in state_dict and ((state_dict[prefix+"weight"].shape[0] > self.out_channels) or (state_dict[prefix+"weight"].shape[1] > self.in_channels)): + print( + f"Model checkpoint has too many channels. Excluding channels of convolution {prefix}.") + if self.bias is not None: + bias = state_dict[prefix+"bias"][:self.out_channels] + state_dict[prefix+"bias"] = bias + del bias + + weight = state_dict[prefix+"weight"] + state_dict[prefix+"weight"] = weight[:self.out_channels, + :self.in_channels] + del weight + + return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + +class Conv2D_ExtendedChannels(nn.Conv2d): + + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: Union[str, _size_2_t] = 0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', + device=None, + dtype=None, + in_channel_extension: int = 0, + out_channel_extension: int = 0, + ) -> None: + super().__init__(in_channels+in_channel_extension, out_channels+out_channel_extension, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, device, dtype) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + print(f"Call extend channel loader with {prefix}") + if prefix+"weight" in state_dict and (state_dict[prefix+"weight"].shape[0] < self.out_channels or state_dict[prefix+"weight"].shape[1] < self.in_channels): + print( + f"Model checkpoint has insufficient channels. Extending channels of convolution {prefix} by adding zeros.") + if self.bias is not None: + bias = state_dict[prefix+"bias"] + state_dict[prefix+"bias"] = torch.cat( + [bias, torch.zeros(self.out_channels-len(bias), dtype=bias.dtype, layout=bias.layout, device=bias.device)]) + del bias + + weight = state_dict[prefix+"weight"] + extended_weight = torch.zeros(self.out_channels, self.in_channels, + weight.shape[2], weight.shape[3], device=weight.device, dtype=weight.dtype, layout=weight.layout) + extended_weight[:weight.shape[0], :weight.shape[1]] = weight + state_dict[prefix+"weight"] = extended_weight + del extended_weight + del weight + + return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + +if __name__ == "__main__": + class MyModel(nn.Module): + + def __init__(self, conv_type: str, c_in, c_out, in_extension, out_extension) -> None: + super().__init__() + + if not conv_type == "normal": + + self.conv1 = Conv2D_ExtendedChannels( + c_in, c_out, 3, padding=1, in_channel_extension=in_extension, out_channel_extension=out_extension, bias=True) + + else: + self.conv1 = nn.Conv2d(c_in, c_out, 3, padding=1, bias=True) + + def forward(self, x): + return self.conv1(x) + + c_in = 9 + c_out = 12 + c_in_ext = 0 + c_out_ext = 3 + model = MyModel("normal", c_in, c_out, c_in_ext, c_out_ext) + + input = torch.randn((4, c_in+c_in_ext, 128, 128)) + out_normal = model(input[:, :c_in]) + torch.save(model.state_dict(), "model_dummy.py") + + model_2 = MyModel("special", c_in, c_out, c_in_ext, c_out_ext) + model_2.load_state_dict(torch.load("model_dummy.py")) + out_model_2 = model_2(input) + out_special = out_model_2[:, :c_out] + + out_new = out_model_2[:, c_out:] + model_3 = MyModel("special", c_in, c_out, c_in_ext, c_out_ext) + model_3.load_state_dict(model_2.state_dict()) + # out_model_2 = model_2(input) + # out_special = out_model_2[:, :c_out] + + print( + f"Difference: Forward pass with extended convolution minus initial convolution: {(out_normal-out_special).abs().max()}") + + print(f"Compared tensors with shape: ", + out_normal.shape, out_special.shape) + + if model_3.conv1.bias is not None: + criterion = nn.MSELoss() + + before_opt = model_3.conv1.bias.detach().clone() + target = torch.ones_like(out_model_2) + optimizer = torch.optim.SGD( + model_3.parameters(), lr=0.01, momentum=0.9) + for iter in range(10): + optimizer.zero_grad() + out = model_3(input) + loss = criterion(out, target) + loss.backward() + optimizer.step() + print( + f"Weights before and after are the same? {before_opt[c_out:].detach()} | {model_3.conv1.bias[c_out:].detach()} ") + print(model_3.conv1.bias, model_2.conv1.bias) diff --git a/src/videogen_hub/pipelines/streamingt2v/model/pl_module_extension.py b/src/videogen_hub/pipelines/streamingt2v/model/pl_module_extension.py new file mode 100644 index 0000000000000000000000000000000000000000..544465e746e4d0b70b7a4d85c307ff0ff428e421 --- /dev/null +++ b/src/videogen_hub/pipelines/streamingt2v/model/pl_module_extension.py @@ -0,0 +1,297 @@ +import torch +from copy import deepcopy +from einops import repeat +import math + + +class FrameConditioning(): + def __init__(self, + add_frame_to_input: bool = False, + add_frame_to_layers: bool = False, + fill_zero: bool = False, + randomize_mask: bool = False, + concatenate_mask: bool = False, + injection_probability: float = 0.9, + ) -> None: + self.use = None + self.add_frame_to_input = add_frame_to_input + self.add_frame_to_layers = add_frame_to_layers + self.fill_zero = fill_zero + self.randomize_mask = randomize_mask + self.concatenate_mask = concatenate_mask + self.injection_probability = injection_probability + self.add_frame_to_input or self.add_frame_to_layers + + assert not add_frame_to_layers or not add_frame_to_input + + def set_random_mask(self, random_mask: bool): + frame_conditioning = deepcopy(self) + frame_conditioning.randomize_mask = random_mask + return frame_conditioning + + @property + def use(self): + return self.add_frame_to_input or self.add_frame_to_layers + + @use.setter + def use(self, value): + if value is not None: + raise NotImplementedError("Direct access not allowed") + + def attach_video_frames(self, pl_module, z_0: torch.Tensor = None, batch: torch.Tensor = None, random_mask: bool = False): + assert self.fill_zero, "Not filling with zero not implemented yet" + n_frames_inference = self.inference_params.video_length + with torch.no_grad(): + if z_0 is None: + assert batch is not None + z_0 = pl_module.encode_frame(batch) + assert n_frames_inference == z_0.shape[1], "For frame injection, the number of frames sampled by the dataloader must match the number of frames used for video generation" + shape = list(z_0.shape) + + shape[1] = pl_module.inference_params.video_length + M = torch.zeros(shape, dtype=z_0.dtype, + device=pl_module.device) # [B F C W H] + bsz = z_0.shape[0] + if random_mask: + p_inject_frame = self.injection_probability + use_masks = torch.bernoulli( + torch.tensor(p_inject_frame).repeat(bsz)).long() + keep_frame_idx = torch.randint( + 0, n_frames_inference, (bsz,), device=pl_module.device).long() + else: + use_masks = torch.ones((bsz,), device=pl_module.device).long() + # keep only first frame + keep_frame_idx = 0 * use_masks + frame_idx = [] + + for batch_idx, (keep_frame, use_mask) in enumerate(zip(keep_frame_idx, use_masks)): + M[batch_idx, keep_frame] = use_mask + frame_idx.append(keep_frame if use_mask == 1 else -1) + + x0 = z_0*M + if self.concatenate_mask: + # flatten mask + M = M[:, :, 0, None] + x0 = torch.cat([x0, M], dim=2) + if getattr(pl_module.opt_params.noise_decomposition, "use", False) and random_mask: + assert x0.shape[0] == 1, "randomizing frame injection with noise decomposition not implemented for batch size >1" + return x0, frame_idx + + +class NoiseDecomposition(): + + def __init__(self, + use: bool = False, + random_frame: bool = False, + lambda_f: float = 0.5, + use_base_model: bool = True, + ): + self.use = use + self.random_frame = random_frame + self.lambda_f = lambda_f + self.use_base_model = use_base_model + + def get_loss(self, x0, unet_base, unet, noise_scheduler, frame_idx, z_t_base, timesteps, encoder_hidden_states, base_noise, z_t_residual, composed_noise): + if x0 is not None: + # x0.shape = [B,F,C,W,H], if extrapolation_params.fill_zero=true, only one frame per batch non-zero + assert not self.random_frame + + # TODO add x0 injection + x0_base = [] + for batch_idx, frame in enumerate(frame_idx): + x0_base.append(x0[batch_idx, frame, None, None]) + + x0_base = torch.cat(x0_base, dim=0) + x0_residual = repeat( + x0[:, 0], "B C W H -> B F C W H", F=x0.shape[1]-1) + else: + x0_residual = None + + if self.use_base_model: + base_pred = unet_base(z_t_base, timesteps, + encoder_hidden_states, x0=x0_base).sample + else: + base_pred = base_noise + + timesteps_alphas = [ + noise_scheduler.alphas_cumprod[t.cpu()] for t in timesteps] + timesteps_alphas = torch.stack( + timesteps_alphas).to(base_pred.device) + timesteps_alphas = repeat(timesteps_alphas, "B -> B F C W H", + F=base_pred.shape[1], C=base_pred.shape[2], W=base_pred.shape[3], H=base_pred.shape[4]) + base_correction = math.sqrt( + lambda_f) * torch.sqrt(1-timesteps_alphas) * base_pred + + z_t_residual_dash = z_t_residual - base_correction + + residual_pred = unet( + z_t_residual_dash, timesteps, encoder_hidden_states, x0=x0_residual).sample + composed_pred = math.sqrt( + lambda_f)*base_pred.detach() + math.sqrt(1-lambda_f) * residual_pred + + loss_residual = torch.nn.functional.mse_loss( + composed_noise.float(), composed_pred.float(), reduction=reduction) + if self.use_base_model: + loss_base = torch.nn.functional.mse_loss( + base_noise.float(), base_pred.float(), reduction=reduction) + loss = loss_residual+loss_base + else: + loss = loss_residual + return loss + + def add_noise(self, z_base, base_noise, z_residual, composed_noise, noise_scheduler, timesteps): + z_t_base = noise_scheduler.add_noise( + z_base, base_noise, timesteps) + z_t_residual = noise_scheduler.add_noise( + z_residual, composed_noise, timesteps) + return z_t_base, z_t_residual + + def split_latent_into_base_residual(self, z_0, pl_module, noise_generator): + if self.random_frame: + raise NotImplementedError("Must be synced with x0 mask!") + fr_select = torch.randint( + 0, z_0.shape[1], (bsz,), device=pl_module.device).long() + z_base = z_0[:, fr_Select, None] + fr_residual = [fr for fr in range( + z_0.shape[1]) if fr != fr_select] + z_residual = z_0[:, fr_residual, None] + else: + if not pl_module.unet_params.frame_conditioning.randomize_mask: + z_base = z_0[:, 0, None] + z_residual = z_0[:, 1:] + else: + z_base = [] + for batch_idx, frame_at_batch in enumerate(frame_idx): + z_base.append( + z_0[batch_idx, frame_at_batch, None, None]) + z_base = torch.cat(z_base, dim=0) + # z_residual = z_0[[:, 1:] + z_residual = [] + + for batch_idx, frame_idx_batch in enumerate(frame_idx): + z_residual_batch = [] + for frame in range(z_0.shape[1]): + if frame_idx_batch != frame: + z_residual_batch.append( + z_0[batch_idx, frame, None, None]) + z_residual_batch = torch.cat( + z_residual_batch, dim=1) + z_residual.append(z_residual_batch) + z_residual = torch.cat(z_residual, dim=0) + base_noise = noise_generator.sample_noise(z_base) # b_t + residual_noise = noise_generator.sample_noise(z_residual) # r^f_t + lambda_f = self.lambda_f + composed_noise = math.sqrt( + lambda_f) * base_noise + math.sqrt(1-lambda_f) * residual_noise # dimension issue? + + return z_base, base_noise, z_residual, composed_noise + + +class NoiseGenerator(): + + def __init__(self, mode="vanilla") -> None: + self.mode = mode + + def set_seed(self, seed: int): + self.seed = seed + + def reset_seed(self, seed: int): + pass + + def sample_noise(self, z_0: torch.tensor = None, shape=None, device=None, dtype=None, generator=None): + + assert (z_0 is not None) != ( + shape is not None), f"either z_0 must be None, or shape must be None. Both provided." + kwargs = {} + + if z_0 is None: + if device is not None: + kwargs["device"] = device + if dtype is not None: + kwargs["dtype"] = dtype + + else: + kwargs["device"] = z_0.device + kwargs["dtype"] = z_0.dtype + shape = z_0.shape + + if generator is not None: + kwargs["generator"] = generator + + B, F, C, W, H = shape + + if self.mode == "vanilla": + noise = torch.randn( + shape, **kwargs) + elif self.mode == "free_noise": + noise = torch.randn(shape, **kwargs) + if noise.shape[1] > 4: + # HARD CODED + noise = noise[:, :8] + noise = torch.cat( + [noise, noise[:, torch.randperm(noise.shape[1])]], dim=1) + elif noise.shape[2] > 4: + noise = noise[:, :, :8] + noise = torch.cat( + [noise, noise[:, :, torch.randperm(noise.shape[2])]], dim=2) + else: + raise NotImplementedError( + f"Shape of noise vector not as expected {noise.shape}") + elif self.mode == "equal": + shape = list(shape) + shape[1] = 1 + noise_init = torch.randn( + shape, **kwargs) + shape[1] = F + noise = torch.zeros( + shape, device=noise_init.device, dtype=noise_init.dtype) + for fr in range(F): + noise[:, fr] = noise_init[:, 0] + elif self.mode == "fusion": + shape = list(shape) + shape[1] = 1 + noise_init = torch.randn( + shape, **kwargs) + noises = [] + noises.append(noise_init) + for fr in range(F-1): + + shift = 2*(fr+1) + local_copy = noise_init + shifted_noise = torch.cat( + [local_copy[:, :, :, shift:, :], local_copy[:, :, :, :shift, :]], dim=3) + noises.append(math.sqrt(0.2)*shifted_noise + + math.sqrt(1-0.2)*torch.rand(shape, **kwargs)) + noise = torch.cat(noises, dim=1) + + elif self.mode == "motion_dynamics" or self.mode == "equal_noise_per_sequence": + + shape = list(shape) + normal_frames = 1 + shape[1] = normal_frames + init_noise = torch.randn( + shape, **kwargs) + noises = [] + noises.append(init_noise) + init_noise = init_noise[:, -1, None] + print(f"UPDATE with noise = {init_noise.shape}") + + if self.mode == "motion_dynamics": + for fr in range(F-normal_frames): + + shift = 2*(fr+1) + print(fr, shift) + local_copy = init_noise + shifted_noise = torch.cat( + [local_copy[:, :, :, shift:, :], local_copy[:, :, :, :shift, :]], dim=3) + noises.append(shifted_noise) + elif self.mode == "equal_noise_per_sequence": + for fr in range(F-1): + noises.append(init_noise) + else: + raise NotImplementedError() + # noises[0] = noises[0] * 0 + noise = torch.cat(noises, dim=1) + print(noise.shape) + + return noise diff --git a/src/videogen_hub/pipelines/streamingt2v/model/pl_module_params_controlnet.py b/src/videogen_hub/pipelines/streamingt2v/model/pl_module_params_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..089310e6ef03396b2482d9c1199f248c4e5751c8 --- /dev/null +++ b/src/videogen_hub/pipelines/streamingt2v/model/pl_module_params_controlnet.py @@ -0,0 +1,383 @@ +from typing import Union, Any, Dict, List, Optional, Callable +from . import pl_module_extension +from .diffusers_conditional.models.controlnet.image_embedder import AbstractEncoder +from .requires_grad_setter import LayerConfig as LayerConfigNew +from . import video_noise_generator + + +def auto_str(cls): + def __str__(self): + return "%s(%s)" % ( + type(self).__name__, + ", ".join("%s=%s" % item for item in vars(self).items()), + ) + + cls.__str__ = __str__ + return cls + + +class LayerConfig: + def __init__( + self, + update_with_full_lr: Optional[Union[List[str], List[List[str]]]] = None, + exclude: Optional[List[str]] = None, + deactivate_all_grads: bool = True, + ) -> None: + self.deactivate_all_grads = deactivate_all_grads + if exclude is not None: + self.exclude = exclude + if update_with_full_lr is not None: + self.update_with_full_lr = update_with_full_lr + + def __str__(self) -> str: + str = f"Deactivate all gradients first={self.deactivate_all_grads}. " + if hasattr(self, "update_with_full_lr"): + str += f"Then activating gradients for: {self.update_with_full_lr}. " + if hasattr(self, "exclude"): + str += f"Finally, excluding: {self.exclude}. " + return str + + +class OptimizerParams: + def __init__( + self, + learning_rate: float, + # Default value due to legacy + layers_config: Union[LayerConfig, LayerConfigNew] = None, + layers_config_base: LayerConfig = None, # Default value due to legacy + use_warmup: bool = False, + warmup_steps: int = 10000, + warmup_start_factor: float = 1e-5, + learning_rate_spatial: float = 0.0, + use_8_bit_adam: bool = False, + noise_generator: Union[ + pl_module_extension.NoiseGenerator, video_noise_generator.NoiseGenerator + ] = None, + noise_decomposition: pl_module_extension.NoiseDecomposition = None, + perceptual_loss: bool = False, + noise_offset: float = 0.0, + split_opt_by_node: bool = False, + reset_prediction_type_to_eps: bool = False, + train_val_sampler_may_differ: bool = False, + measure_similarity: bool = False, + similarity_loss: bool = False, + similarity_loss_weight: float = 1.0, + loss_conditional_weight: float = 0.0, + loss_conditional_weight_convex: bool = False, + loss_conditional_change_after_step: int = 0, + mask_conditional_frames: bool = False, + sample_from_noise: bool = True, + mask_alternating: bool = False, + uncondition_freq: int = -1, + no_text_condition_control: bool = False, + inject_image_into_input: bool = False, + inject_at_T: bool = False, + resampling_steps: int = 1, + control_freq_in_resample: int = 1, + resample_to_T: bool = False, + adaptive_loss_reweight: bool = False, + load_resampler_from_ckpt: str = "", + skip_controlnet_branch: bool = False, + use_fps_conditioning: bool = False, + num_frame_embeddings_range: int = 16, + start_frame_training: int = 0, + start_frame_ctrl: int = 0, + load_trained_base_model_and_resampler_from_ckpt: str = "", + load_trained_controlnet_from_ckpt: str = "", + # fill_up_frame_to_video: bool = False, + ) -> None: + self.use_warmup = use_warmup + self.warmup_steps = warmup_steps + self.warmup_start_factor = warmup_start_factor + self.learning_rate_spatial = learning_rate_spatial + self.learning_rate = learning_rate + self.use_8_bit_adam = use_8_bit_adam + self.layers_config = layers_config + self.noise_generator = noise_generator + self.perceptual_loss = perceptual_loss + self.noise_decomposition = noise_decomposition + self.noise_offset = noise_offset + self.split_opt_by_node = split_opt_by_node + self.reset_prediction_type_to_eps = reset_prediction_type_to_eps + self.train_val_sampler_may_differ = train_val_sampler_may_differ + self.measure_similarity = measure_similarity + self.similarity_loss = similarity_loss + self.similarity_loss_weight = similarity_loss_weight + self.loss_conditional_weight = loss_conditional_weight + self.loss_conditional_change_after_step = loss_conditional_change_after_step + self.mask_conditional_frames = mask_conditional_frames + self.loss_conditional_weight_convex = loss_conditional_weight_convex + self.sample_from_noise = sample_from_noise + self.layers_config_base = layers_config_base + self.mask_alternating = mask_alternating + self.uncondition_freq = uncondition_freq + self.no_text_condition_control = no_text_condition_control + self.inject_image_into_input = inject_image_into_input + self.inject_at_T = inject_at_T + self.resampling_steps = resampling_steps + self.control_freq_in_resample = control_freq_in_resample + self.resample_to_T = resample_to_T + self.adaptive_loss_reweight = adaptive_loss_reweight + self.load_resampler_from_ckpt = load_resampler_from_ckpt + self.skip_controlnet_branch = skip_controlnet_branch + self.use_fps_conditioning = use_fps_conditioning + self.num_frame_embeddings_range = num_frame_embeddings_range + self.start_frame_training = start_frame_training + self.load_trained_base_model_and_resampler_from_ckpt = ( + load_trained_base_model_and_resampler_from_ckpt + ) + self.load_trained_controlnet_from_ckpt = load_trained_controlnet_from_ckpt + self.start_frame_ctrl = start_frame_ctrl + if start_frame_ctrl < 0: + print("new format start frame cannot be negative") + exit() + + # self.fill_up_frame_to_video = fill_up_frame_to_video + + @property + def learning_rate_spatial(self): + return self._learning_rate_spatial + + # legacy code that maps the state None or '-1' to '0.0' + # so 0.0 indicated no spatial learning rate is selected + @learning_rate_spatial.setter + def learning_rate_spatial(self, value): + if value is None or value == -1: + value = 0 + self._learning_rate_spatial = value + + +# Legacy class +class SchedulerParams: + def __init__( + self, + use_warmup: bool = False, + warmup_steps: int = 10000, + warmup_start_factor: float = 1e-5, + ) -> None: + self.use_warmup = use_warmup + self.warmup_steps = warmup_steps + self.warmup_start_factor = warmup_start_factor + + +class CrossFrameAttentionParams: + + def __init__(self, attent_on: List[int], masking=False) -> None: + self.attent_on = attent_on + self.masking = masking + + +class InferenceParams: + def __init__( + self, + width: int, + height: int, + video_length: int, + guidance_scale: float = 7.5, + use_dec_scaling: bool = True, + frame_rate: int = 2, + num_inference_steps: int = 50, + eta: float = 0.0, + n_autoregressive_generations: int = 1, + mode: str = "long_video", + start_from_real_input: bool = True, + eval_loss_metrics: bool = False, + scheduler_cls: str = "", + negative_prompt: str = "", + conditioning_from_all_past: bool = False, + validation_samples: int = 80, + conditioning_type: str = "last_chunk", + result_formats: List[str] = ["eval_gif", "gif", "mp4"], + concat_video: bool = True, + seed: int = 33, + ): + self.width = width + self.height = height + self.video_length = ( + video_length if isinstance(video_length, int) else int(video_length) + ) + self.guidance_scale = guidance_scale + self.use_dec_scaling = use_dec_scaling + self.frame_rate = frame_rate + self.num_inference_steps = num_inference_steps + self.eta = eta + self.negative_prompt = negative_prompt + self.n_autoregressive_generations = n_autoregressive_generations + self.mode = mode + self.start_from_real_input = start_from_real_input + self.eval_loss_metrics = eval_loss_metrics + self.scheduler_cls = scheduler_cls + self.conditioning_from_all_past = conditioning_from_all_past + self.validation_samples = validation_samples + self.conditioning_type = conditioning_type + self.result_formats = result_formats + self.concat_video = concat_video + self.seed = seed + + def to_dict(self): + + keys = [ + entry + for entry in dir(self) + if not callable(getattr(self, entry)) and not entry.startswith("__") + ] + + result_dict = {} + for key in keys: + result_dict[key] = getattr(self, key) + return result_dict + + +@auto_str +class AttentionMaskParams: + + def __init__( + self, + temporal_self_attention_only_on_conditioning: bool = False, + temporal_self_attention_mask_included_itself: bool = False, + spatial_attend_on_condition_frames: bool = False, + temp_attend_on_neighborhood_of_condition_frames: bool = False, + temp_attend_on_uncond_include_past: bool = False, + ) -> None: + self.temporal_self_attention_mask_included_itself = ( + temporal_self_attention_mask_included_itself + ) + self.spatial_attend_on_condition_frames = spatial_attend_on_condition_frames + self.temp_attend_on_neighborhood_of_condition_frames = ( + temp_attend_on_neighborhood_of_condition_frames + ) + self.temporal_self_attention_only_on_conditioning = ( + temporal_self_attention_only_on_conditioning + ) + self.temp_attend_on_uncond_include_past = temp_attend_on_uncond_include_past + + assert ( + not temp_attend_on_neighborhood_of_condition_frames + or not temporal_self_attention_only_on_conditioning + ) + + +class UNetParams: + + def __init__( + self, + conditioning_embedding_out_channels: List[int], + ckpt_spatial_layers: str = "", + pipeline_repo: str = "", + unet_from_diffusers: bool = True, + spatial_latent_input: bool = False, + num_frame_conditioning: int = 1, + pipeline_class: str = "t2v_enhanced.model.model.controlnet.pipeline_text_to_video_w_controlnet_synth.TextToVideoSDPipeline", + frame_expansion: str = "last_frame", + downsample_controlnet_cond: bool = True, + num_frames: int = 1, + pre_transformer_in_cond: bool = False, + num_tranformers: int = 1, + zero_conv_3d: bool = False, + merging_mode: str = "addition", + compute_only_conditioned_frames: bool = False, + condition_encoder: str = "", + zero_conv_mode: str = "2d", + clean_model: bool = False, + merging_mode_base: str = "addition", + attention_mask_params: AttentionMaskParams = None, + attention_mask_params_base: AttentionMaskParams = None, + modelscope_input_format: bool = True, + temporal_self_attention_only_on_conditioning: bool = False, + temporal_self_attention_mask_included_itself: bool = False, + use_post_merger_zero_conv: bool = False, + weight_control_sample: float = 1.0, + use_controlnet_mask: bool = False, + random_mask_shift: bool = False, + random_mask: bool = False, + use_resampler: bool = False, + unet_from_pipe: bool = False, + unet_operates_on_2d: bool = False, + image_encoder: str = "CLIP", + use_standard_attention_processor: bool = True, + num_frames_before_chunk: int = 0, + resampler_type: str = "single_frame", + resampler_cls: str = "", + resampler_merging_layers: int = 1, + image_encoder_obj: AbstractEncoder = None, + cfg_text_image: bool = False, + aggregation: str = "last_out", + resampler_random_shift: bool = False, + img_cond_alpha_per_frame: bool = False, + num_control_input_frames: int = -1, + use_image_encoder_normalization: bool = False, + use_of: bool = False, + ema_param: float = -1.0, + concat: bool = False, + use_image_tokens_main: bool = True, + use_image_tokens_ctrl: bool = False, + ): + + self.ckpt_spatial_layers = ckpt_spatial_layers + self.pipeline_repo = pipeline_repo + self.unet_from_diffusers = unet_from_diffusers + self.spatial_latent_input = spatial_latent_input + self.pipeline_class = pipeline_class + self.num_frame_conditioning = num_frame_conditioning + if num_control_input_frames == -1: + self.num_control_input_frames = num_frame_conditioning + else: + self.num_control_input_frames = num_control_input_frames + + self.conditioning_embedding_out_channels = conditioning_embedding_out_channels + self.frame_expansion = frame_expansion + self.downsample_controlnet_cond = downsample_controlnet_cond + self.num_frames = num_frames + self.pre_transformer_in_cond = pre_transformer_in_cond + self.num_tranformers = num_tranformers + self.zero_conv_3d = zero_conv_3d + self.merging_mode = merging_mode + self.compute_only_conditioned_frames = compute_only_conditioned_frames + self.clean_model = clean_model + self.condition_encoder = condition_encoder + self.zero_conv_mode = zero_conv_mode + self.merging_mode_base = merging_mode_base + self.modelscope_input_format = modelscope_input_format + assert ( + not temporal_self_attention_only_on_conditioning + ), "This parameter is only here for backward compatibility. Set AttentionMaskParams instead." + assert ( + not temporal_self_attention_mask_included_itself + ), "This parameter is only here for backward compatibility. Set AttentionMaskParams instead." + if attention_mask_params is not None and attention_mask_params_base is None: + attention_mask_params_base = attention_mask_params + if attention_mask_params is None: + attention_mask_params = AttentionMaskParams() + if attention_mask_params_base is None: + attention_mask_params_base = AttentionMaskParams() + self.attention_mask_params = attention_mask_params + self.attention_mask_params_base = attention_mask_params_base + self.weight_control_sample = weight_control_sample + self.use_controlnet_mask = use_controlnet_mask + self.random_mask_shift = random_mask_shift + self.random_mask = random_mask + self.use_resampler = use_resampler + self.unet_from_pipe = unet_from_pipe + self.unet_operates_on_2d = unet_operates_on_2d + self.image_encoder = image_encoder_obj + self.use_standard_attention_processor = use_standard_attention_processor + self.num_frames_before_chunk = num_frames_before_chunk + self.resampler_type = resampler_type + self.resampler_cls = resampler_cls + self.resampler_merging_layers = resampler_merging_layers + self.cfg_text_image = cfg_text_image + self.aggregation = aggregation + self.resampler_random_shift = resampler_random_shift + self.img_cond_alpha_per_frame = img_cond_alpha_per_frame + self.use_image_encoder_normalization = use_image_encoder_normalization + self.use_of = use_of + self.ema_param = ema_param + self.concat = concat + self.use_image_tokens_main = use_image_tokens_main + self.use_image_tokens_ctrl = use_image_tokens_ctrl + assert not use_post_merger_zero_conv + + if spatial_latent_input: + assert ( + unet_from_diffusers + ), "Spatial latent input only implemented by original diffusers model. Set 'model.unet_params.unet_from_diffusers=True'." diff --git a/src/videogen_hub/pipelines/streamingt2v/model/requires_grad_setter.py b/src/videogen_hub/pipelines/streamingt2v/model/requires_grad_setter.py new file mode 100644 index 0000000000000000000000000000000000000000..5d6cb997ac3b91c8a76390b5885985678bb402cb --- /dev/null +++ b/src/videogen_hub/pipelines/streamingt2v/model/requires_grad_setter.py @@ -0,0 +1,36 @@ +from typing import Union, Any, Dict, List, Optional, Tuple +import pytorch_lightning as pl + + +class LayerConfig(): + def __init__(self, + gradient_setup: List[Tuple[bool, List[str]]] = None, + ) -> None: + + if gradient_setup is not None: + self.gradient_setup = gradient_setup + self.new_config = True + # TODO add option to specify quantization per layer + + def set_requires_grad(self, pl_module: pl.LightningModule): + # [["True","unet.a.b","c"],["True,[]"]] + + for selected_module_setup in self.gradient_setup: + for model_name, p in pl_module.named_parameters(): + grad_mode = selected_module_setup[0] == True + selected_module_path = selected_module_setup[1] + path_is_matching = True + model_name_selection = model_name + for selected_module in selected_module_path: + position = model_name_selection.find(selected_module) + if position == -1: + path_is_matching = False + continue + else: + shift = len(selected_module) + model_name_selection = model_name_selection[position+shift:] + if path_is_matching: + # if grad_mode: + # print( + # f"Setting gradient for {model_name} to {grad_mode}") + p.requires_grad = grad_mode diff --git a/src/videogen_hub/pipelines/streamingt2v/model/video_ldm.py b/src/videogen_hub/pipelines/streamingt2v/model/video_ldm.py new file mode 100644 index 0000000000000000000000000000000000000000..8983ea2bbe719e91a1feaf6b8841900499255c79 --- /dev/null +++ b/src/videogen_hub/pipelines/streamingt2v/model/video_ldm.py @@ -0,0 +1,487 @@ +from pathlib import Path +from typing import Any, Optional, Union, Callable + +import pytorch_lightning as pl +import torch +from diffusers import DDPMScheduler, DiffusionPipeline, AutoencoderKL, DDIMScheduler +from diffusers.utils.import_utils import is_xformers_available +from einops import rearrange, repeat + +from transformers import CLIPTextModel, CLIPTokenizer +from videogen_hub.pipelines.streamingt2v.utils.video_utils import ( + ResultProcessor, + save_videos_grid, + video_naming, +) + +from . import pl_module_params_controlnet + +from .diffusers_conditional.models.controlnet.controlnet import ( + ControlNetModel, +) +from .diffusers_conditional.models.controlnet.unet_3d_condition import ( + UNet3DConditionModel, +) +from .diffusers_conditional.models.controlnet.pipeline_text_to_video_w_controlnet_synth import ( + TextToVideoSDPipeline, +) + +from .diffusers_conditional.models.controlnet.processor import ( + set_use_memory_efficient_attention_xformers, +) +from .diffusers_conditional.models.controlnet.mask_generator import ( + MaskGenerator, +) + +import warnings + +# from warnings import warn +from videogen_hub.pipelines.streamingt2v.utils.iimage import IImage +from videogen_hub.pipelines.streamingt2v.utils.object_loader import instantiate_object +from videogen_hub.pipelines.streamingt2v.utils.object_loader import get_class + + +class VideoLDM(pl.LightningModule): + + def __init__( + self, + inference_params: pl_module_params_controlnet.InferenceParams, + opt_params: pl_module_params_controlnet.OptimizerParams = None, + unet_params: pl_module_params_controlnet.UNetParams = None, + ): + super().__init__() + + self.inference_generator = torch.Generator(device=self.device) + + self.opt_params = opt_params + self.unet_params = unet_params + + print(f"Base pipeline from: {unet_params.pipeline_repo}") + print(f"Pipeline class {unet_params.pipeline_class}") + # load entire pipeline (unet, vq, text encoder,..) + state_dict_control_model = None + state_dict_fusion = None + state_dict_base_model = None + + if len(opt_params.load_trained_controlnet_from_ckpt) > 0: + state_dict_ckpt = torch.load( + opt_params.load_trained_controlnet_from_ckpt, + map_location=torch.device("cpu"), + ) + state_dict_ckpt = state_dict_ckpt["state_dict"] + state_dict_control_model = dict( + filter(lambda x: x[0].startswith("unet"), state_dict_ckpt.items()) + ) + state_dict_control_model = { + k.split("unet.")[1]: v for (k, v) in state_dict_control_model.items() + } + + state_dict_fusion = dict( + filter( + lambda x: "cross_attention_merger" in x[0], state_dict_ckpt.items() + ) + ) + state_dict_fusion = { + k.split("base_model.")[1]: v for (k, v) in state_dict_fusion.items() + } + del state_dict_ckpt + + state_dict_proj = None + state_dict_ckpt = None + + if hasattr(unet_params, "use_resampler") and unet_params.use_resampler: + num_queries = unet_params.num_frames if unet_params.num_frames > 1 else None + if unet_params.use_image_tokens_ctrl: + num_queries = unet_params.num_control_input_frames + assert unet_params.frame_expansion == "none" + image_encoder = self.unet_params.image_encoder + embedding_dim = image_encoder.embedding_dim + + resampler = instantiate_object( + self.unet_params.resampler_cls, + video_length=num_queries, + embedding_dim=embedding_dim, + input_tokens=image_encoder.num_tokens, + num_layers=self.unet_params.resampler_merging_layers, + aggregation=self.unet_params.aggregation, + ) + + state_dict_proj = None + + self.resampler = resampler + self.image_encoder = image_encoder + + noise_scheduler = DDPMScheduler.from_pretrained( + self.unet_params.pipeline_repo, subfolder="scheduler" + ) + tokenizer = CLIPTokenizer.from_pretrained( + self.unet_params.pipeline_repo, subfolder="tokenizer" + ) + text_encoder = CLIPTextModel.from_pretrained( + self.unet_params.pipeline_repo, subfolder="text_encoder" + ) + vae = AutoencoderKL.from_pretrained( + self.unet_params.pipeline_repo, subfolder="vae" + ) + base_model = UNet3DConditionModel.from_pretrained( + self.unet_params.pipeline_repo, + subfolder="unet", + low_cpu_mem_usage=False, + device_map=None, + merging_mode=self.unet_params.merging_mode_base, + use_image_embedding=unet_params.use_resampler + and unet_params.use_image_tokens_main, + use_fps_conditioning=self.opt_params.use_fps_conditioning, + unet_params=unet_params, + ) + + if state_dict_base_model is not None: + miss, unex = base_model.load_state_dict(state_dict_base_model, strict=False) + assert len(unex) == 0 + if len(miss) > 0: + warnings.warn(f"Missing keys when loading base_mode:{miss}") + del state_dict_base_model + if state_dict_fusion is not None: + miss, unex = base_model.load_state_dict(state_dict_fusion, strict=False) + assert len(unex) == 0 + del state_dict_fusion + + print("PIPE LOADING DONE") + self.noise_scheduler = noise_scheduler + self.tokenizer = tokenizer + self.text_encoder = text_encoder + self.vae = vae + + self.unet = ControlNetModel.from_unet( + unet=base_model, + conditioning_embedding_out_channels=unet_params.conditioning_embedding_out_channels, + downsample_controlnet_cond=unet_params.downsample_controlnet_cond, + num_frames=( + unet_params.num_frames + if ( + unet_params.frame_expansion != "none" + or self.unet_params.use_controlnet_mask + ) + else unet_params.num_control_input_frames + ), + num_frame_conditioning=unet_params.num_control_input_frames, + frame_expansion=unet_params.frame_expansion, + pre_transformer_in_cond=unet_params.pre_transformer_in_cond, + num_tranformers=unet_params.num_tranformers, + vae=AutoencoderKL.from_pretrained( + self.unet_params.pipeline_repo, subfolder="vae" + ), + zero_conv_mode=unet_params.zero_conv_mode, + merging_mode=unet_params.merging_mode, + condition_encoder=unet_params.condition_encoder, + use_controlnet_mask=unet_params.use_controlnet_mask, + use_image_embedding=unet_params.use_resampler + and unet_params.use_image_tokens_ctrl, + unet_params=unet_params, + use_image_encoder_normalization=unet_params.use_image_encoder_normalization, + ) + if state_dict_control_model is not None: + miss, unex = self.unet.load_state_dict( + state_dict_control_model, strict=False + ) + if len(miss) > 0: + print("WARNING: Loading checkpoint for controlnet misses states") + print(miss) + + if unet_params.frame_expansion == "none": + attention_params = self.unet_params.attention_mask_params + assert ( + not attention_params.temporal_self_attention_only_on_conditioning + and not attention_params.spatial_attend_on_condition_frames + and not attention_params.temp_attend_on_neighborhood_of_condition_frames + ) + + self.mask_generator = MaskGenerator( + self.unet_params.attention_mask_params, + num_frame_conditioning=self.unet_params.num_control_input_frames, + num_frames=self.unet_params.num_frames, + ) + self.mask_generator_base = MaskGenerator( + self.unet_params.attention_mask_params_base, + num_frame_conditioning=self.unet_params.num_control_input_frames, + num_frames=self.unet_params.num_frames, + ) + + if state_dict_proj is not None and unet_params.use_image_tokens_main: + if unet_params.use_image_tokens_main: + missing, unexpected = base_model.load_state_dict( + state_dict_proj, strict=False + ) + elif unet_params.use_image_tokens_ctrl: + missing, unexpected = unet.load_state_dict( + state_dict_proj, strict=False + ) + assert len(unexpected) == 0, f"Unexpected entries {unexpected}" + print(f"Missing keys state proj = {missing}") + del state_dict_proj + + base_model.requires_grad_(False) + self.base_model = base_model + self.unet.requires_grad_(False) + self.text_encoder.requires_grad_(False) + self.vae.requires_grad_(False) + + layers_config = opt_params.layers_config + layers_config.set_requires_grad(self) + + print("CUSTOM XFORMERS ATTENTION USED.") + if is_xformers_available(): + set_use_memory_efficient_attention_xformers( + self.unet, + num_frame_conditioning=self.unet_params.num_control_input_frames, + num_frames=self.unet_params.num_frames, + attention_mask_params=self.unet_params.attention_mask_params, + ) + set_use_memory_efficient_attention_xformers( + self.base_model, + num_frame_conditioning=self.unet_params.num_control_input_frames, + num_frames=self.unet_params.num_frames, + attention_mask_params=self.unet_params.attention_mask_params_base, + ) + + if len(inference_params.scheduler_cls) > 0: + inf_scheduler_class = get_class(inference_params.scheduler_cls) + else: + inf_scheduler_class = DDIMScheduler + + inf_scheduler = inf_scheduler_class.from_pretrained( + self.unet_params.pipeline_repo, subfolder="scheduler" + ) + inference_pipeline = TextToVideoSDPipeline( + vae=self.vae, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + unet=self.base_model, + controlnet=self.unet, + scheduler=inf_scheduler, + ) + + inference_pipeline.set_noise_generator(self.opt_params.noise_generator) + inference_pipeline.enable_vae_slicing() + + inference_pipeline.set_progress_bar_config(disable=True) + + self.inference_params = inference_params + self.inference_pipeline = inference_pipeline + + self.result_processor = ResultProcessor( + fps=self.inference_params.frame_rate, + n_frames=self.inference_params.video_length, + ) + + def on_start(self): + datamodule = self.trainer._data_connector._datahook_selector.datamodule + pipe_id_model = self.unet_params.pipeline_repo + for dataset_key in ["video_dataset", "image_dataset", "predict_dataset"]: + dataset = getattr(datamodule, dataset_key, None) + if dataset is not None and hasattr(dataset, "model_id"): + pipe_id_data = dataset.model_id + assert ( + pipe_id_model == pipe_id_data + ), f"Model and Dataloader need the same pipeline path. Found '{pipe_id_model}' and '{dataset_key}.model_id={pipe_id_data}'. Consider setting '--data.{dataset_key}.model_id={pipe_id_data}'" + self.result_processor.set_logger(self.logger) + + def on_predict_start(self) -> None: + self.on_start() + # pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16") + # pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + # pipe.set_progress_bar_config(disable=True) + # self.first_stage = pipe.to(self.device) + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + cfg = self.trainer.predict_cfg + + result_file_stem = cfg["result_file_stem"] + storage_fol = Path(cfg["predict_dir"]) + prompts = [cfg["prompt"]] + + inference_params: pl_module_params_controlnet.InferenceParams = ( + self.inference_params + ) + conditioning_type = inference_params.conditioning_type + # n_autoregressive_generations = inference_params.n_autoregressive_generations + n_autoregressive_generations = cfg["n_autoregressive_generations"] + mode = inference_params.mode + start_from_real_input = inference_params.start_from_real_input + assert isinstance(prompts, list) + + prompts = n_autoregressive_generations * prompts + + self.inference_generator.manual_seed(self.inference_params.seed) + + assert ( + self.unet_params.num_control_input_frames + == self.inference_params.video_length // 2 + ), f"currently we assume to have an equal size for and second half of the frame interval, e.g. 16 frames, and we condition on 8. Current setup: {self.unet_params.num_frame_conditioning} and {self.inference_params.video_length}" + + chunks_conditional = [] + batch_size = 1 + shape = ( + batch_size, + self.inference_pipeline.unet.config.in_channels, + self.inference_params.video_length, + self.inference_pipeline.unet.config.sample_size, + self.inference_pipeline.unet.config.sample_size, + ) + for idx, prompt in enumerate(prompts): + if idx > 0: + content = sample * 2 - 1 + content_latent = ( + self.vae.encode(content).latent_dist.sample() + * self.vae.config.scaling_factor + ) + content_latent = rearrange(content_latent, "F C W H -> 1 C F W H") + content_latent = ( + content_latent[:, :, self.unet_params.num_control_input_frames :] + .detach() + .clone() + ) + + if hasattr(self.inference_pipeline, "noise_generator"): + latents = self.inference_pipeline.noise_generator.sample_noise( + shape=shape, + device=self.device, + dtype=self.dtype, + generator=self.inference_generator, + content=content_latent if idx > 0 else None, + ) + else: + latents = None + if idx == 0: + sample = cfg["video"].to(self.device) + else: + if inference_params.conditioning_type == "fixed": + context = chunks_conditional[0][ + : self.unet_params.num_frame_conditioning + ] + context = [context] + context = [2 * sample - 1 for sample in context] + + input_frames_conditioning = torch.cat(context).detach().clone() + input_frames_conditioning = rearrange( + input_frames_conditioning, "F C W H -> 1 F C W H" + ) + elif inference_params.conditioning_type == "last_chunk": + input_frames_conditioning = ( + condition_input[:, -self.unet_params.num_frame_conditioning :] + .detach() + .clone() + ) + elif inference_params.conditioning_type == "past": + context = [ + sample[: self.unet_params.num_control_input_frames] + for sample in chunks_conditional + ] + context = [2 * sample - 1 for sample in context] + + input_frames_conditioning = torch.cat(context).detach().clone() + input_frames_conditioning = rearrange( + input_frames_conditioning, "F C W H -> 1 F C W H" + ) + else: + raise NotImplementedError() + + input_frames = ( + condition_input[:, self.unet_params.num_control_input_frames :] + .detach() + .clone() + ) + + sample = self( + prompt, + input_frames=input_frames, + input_frames_conditioning=input_frames_conditioning, + latents=latents, + ) + + if hasattr(self.inference_pipeline, "reset_noise_generator_state"): + self.inference_pipeline.reset_noise_generator_state() + + condition_input = rearrange(sample, "F C W H -> 1 F C W H") + condition_input = (2 * condition_input) - 1 # range: [-1,1] + + # store first 16 frames, then always last 8 of a chunk + chunks_conditional.append(sample) + + result_formats = self.inference_params.result_formats + # result_formats = [gif", "mp4"] + concat_video = self.inference_params.concat_video + + def IImage_normalized(x): + return IImage(x, vmin=0, vmax=1) + + for result_format in result_formats: + save_format = result_format.replace("eval_", "") + + merged_video = None + for chunk_idx, (prompt, video) in enumerate( + zip(prompts, chunks_conditional) + ): + if chunk_idx == 0: + current_video = IImage_normalized(video) + else: + current_video = IImage_normalized( + video[self.unet_params.num_control_input_frames :] + ) + + if merged_video is None: + merged_video = current_video + else: + merged_video &= current_video + + if concat_video: + filename = video_naming(prompts[0], save_format, batch_idx, 0) + result_file_video = (storage_fol / filename).absolute().as_posix() + result_file_video = ( + Path(result_file_video).parent + / (result_file_stem + Path(result_file_video).suffix) + ).as_posix() + self.result_processor.save_to_file( + video=merged_video.torch(vmin=0, vmax=1), + prompt=prompts[0], + video_filename=result_file_video, + prompt_on_vid=False, + ) + + def forward( + self, prompt, input_frames=None, input_frames_conditioning=None, latents=None + ): + call_params = self.inference_params.to_dict() + print(f"INFERENCE PARAMS = {call_params}") + call_params["prompt"] = prompt + + call_params["image"] = input_frames + call_params["num_frames"] = self.inference_params.video_length + call_params["return_dict"] = False + call_params["output_type"] = "pt_t2v" + call_params["mask_generator"] = self.mask_generator + call_params["precision"] = ( + "16" if self.trainer.precision.startswith("16") else "32" + ) + call_params["no_text_condition_control"] = ( + self.opt_params.no_text_condition_control + ) + call_params["weight_control_sample"] = self.unet_params.weight_control_sample + call_params["use_controlnet_mask"] = self.unet_params.use_controlnet_mask + call_params["skip_controlnet_branch"] = self.opt_params.skip_controlnet_branch + call_params["img_cond_resampler"] = ( + self.resampler if self.unet_params.use_resampler else None + ) + call_params["img_cond_encoder"] = ( + self.image_encoder if self.unet_params.use_resampler else None + ) + call_params["input_frames_conditioning"] = input_frames_conditioning + call_params["cfg_text_image"] = self.unet_params.cfg_text_image + call_params["use_of"] = self.unet_params.use_of + if latents is not None: + call_params["latents"] = latents + + sample = self.inference_pipeline( + generator=self.inference_generator, **call_params + ) + return sample diff --git a/src/videogen_hub/pipelines/streamingt2v/model/video_noise_generator.py b/src/videogen_hub/pipelines/streamingt2v/model/video_noise_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..49ffabe721a44c45242bbd3d7811925b948c3184 --- /dev/null +++ b/src/videogen_hub/pipelines/streamingt2v/model/video_noise_generator.py @@ -0,0 +1,225 @@ +import torch +import torch.fft as fft +from torch import nn +from torch.nn import functional +from math import sqrt +from einops import rearrange +import math +import numbers +from typing import List + +# adapted from https://discuss.pytorch.org/t/is-there-anyway-to-do-gaussian-filtering-for-an-image-2d-3d-in-pytorch/12351/10 +# and https://discuss.pytorch.org/t/is-there-anyway-to-do-gaussian-filtering-for-an-image-2d-3d-in-pytorch/12351/19 + + +def gaussian_smoothing_kernel(shape, kernel_size, sigma, dim=2): + """ + Apply gaussian smoothing on a + 1d, 2d or 3d tensor. Filtering is performed seperately for each channel + in the input using a depthwise convolution. + Arguments: + channels (int, sequence): Number of channels of the input tensors. Output will + have this number of channels as well. + kernel_size (int, sequence): Size of the gaussian kernel. + sigma (float, sequence): Standard deviation of the gaussian kernel. + dim (int, optional): The number of dimensions of the data. + Default value is 2 (spatial). + """ + if isinstance(kernel_size, numbers.Number): + kernel_size = [kernel_size] * dim + if isinstance(sigma, numbers.Number): + sigma = [sigma] * dim + + # The gaussian kernel is the product of the + # gaussian function of each dimension. + kernel = 1 + meshgrids = torch.meshgrid( + [ + torch.arange(size, dtype=torch.float32) + for size in kernel_size + ] + ) + + for size, std, mgrid in zip(kernel_size, sigma, meshgrids): + mean = (size - 1) / 2 + + kernel *= torch.exp(-((mgrid - mean) / std) ** 2 / 2) + # kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \ + # torch.exp(-((mgrid - mean) / std) ** 2 / 2) + + # Make sure sum of values in gaussian kernel equals 1. + kernel = kernel / torch.sum(kernel) + + pad_length = (math.floor( + (shape[-1]-kernel_size[-1])/2), math.floor((shape[-1]-kernel_size[-1])/2), math.floor((shape[-2]-kernel_size[-2])/2), math.floor((shape[-2]-kernel_size[-2])/2), math.floor((shape[-3]-kernel_size[-3])/2), math.floor((shape[-3]-kernel_size[-3])/2)) + + kernel = functional.pad(kernel, pad_length) + assert kernel.shape == shape[-3:] + return kernel + + ''' + # Reshape to depthwise convolutional weight + kernel = kernel.view(1, 1, *kernel.size()) + kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) + + + self.register_buffer('weight', kernel) + self.groups = channels + + if dim == 1: + self.conv = functional.conv1d + elif dim == 2: + self.conv = functional.conv2d + elif dim == 3: + self.conv = functional.conv3d + else: + raise RuntimeError( + 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format( + dim) + ) + ''' + + +class NoiseGenerator(): + + def __init__(self, alpha: float = 0.0, shared_noise_across_chunks: bool = False, mode="vanilla", forward_steps: int = 850, radius: List[float] = None) -> None: + self.mode = mode + self.alpha = alpha + self.shared_noise_across_chunks = shared_noise_across_chunks + self.forward_steps = forward_steps + self.radius = radius + + def set_seed(self, seed: int): + self.seed = seed + + def reset_seed(self, seed: int): + pass + + def reset_noise_generator_state(self): + if hasattr(self, "e_shared"): + del self.e_shared + + def sample_noise(self, z_0: torch.tensor = None, shape=None, device=None, dtype=None, generator=None, content=None): + assert (z_0 is not None) != ( + shape is not None), f"either z_0 must be None, or shape must be None. Both provided." + kwargs = {} + noise = torch.randn(shape, **kwargs) + + if z_0 is None: + if device is not None: + kwargs["device"] = device + if dtype is not None: + kwargs["dtype"] = dtype + + else: + kwargs["device"] = z_0.device + kwargs["dtype"] = z_0.dtype + shape = z_0.shape + + if generator is not None: + kwargs["generator"] = generator + + B, F, C, W, H = shape + if F == 4 and C > 4: + frame_idx = 2 + F, C = C, F + else: + frame_idx = 1 + + if "mixed_noise" in self.mode: + + shape_per_frame = [dim for dim in shape] + shape_per_frame[frame_idx] = 1 + zero_mean = torch.zeros( + shape_per_frame, device=kwargs["device"], dtype=kwargs["dtype"]) + std = torch.ones( + shape_per_frame, device=kwargs["device"], dtype=kwargs["dtype"]) + alpha = self.alpha + std_coeff_shared = (alpha**2) / (1 + alpha**2) + if self.shared_noise_across_chunks and hasattr(self, "e_shared"): + e_shared = self.e_shared + else: + e_shared = torch.normal(mean=zero_mean, std=sqrt( + std_coeff_shared)*std, generator=kwargs["generator"] if "generator" in kwargs else None) + if self.shared_noise_across_chunks: + self.e_shared = e_shared + + e_inds = [] + for frame in range(shape[frame_idx]): + std_coeff_ind = 1 / (1 + alpha**2) + e_ind = torch.normal( + mean=zero_mean, std=sqrt(std_coeff_ind)*std, generator=kwargs["generator"] if "generator" in kwargs else None) + e_inds.append(e_ind) + noise = torch.cat( + [e_shared + e_ind for e_ind in e_inds], dim=frame_idx) + + if "consistI2V" in self.mode and content is not None: + # if self.mode == "mixed_noise_consistI2V", we will use 'noise' from 'mixed_noise'. Otherwise, it is randn noise. + + if frame_idx == 1: + assert content.shape[0] == noise.shape[0] and content.shape[2:] == noise.shape[2:] + content = torch.concat([content, content[:, -1:].repeat( + 1, noise.shape[1]-content.shape[1], 1, 1, 1)], dim=1) + noise = rearrange(noise, "B F C W H -> (B C) F W H") + content = rearrange(content, "B F C W H -> (B C) F W H") + + else: + assert content.shape[:2] == noise.shape[: + 2] and content.shape[3:] == noise.shape[3:] + content = torch.concat( + [content, content[:, :, -1:].repeat(1, 1, noise.shape[2]-content.shape[2], 1, 1)], dim=2) + noise = rearrange(noise, "B C F W H -> (B C) F W H") + content = rearrange(content, "B C F W H -> (B C) F W H") + + # TODO implement DDPM_forward using diffusers framework + ''' + content_noisy = ddpm_forward( + content, noise, self.forward_steps) + ''' + + # A 2D low pass filter was given in the blog: + # see https://pytorch.org/blog/the-torch.fft-module-accelerated-fast-fourier-transforms-with-autograd-in-pyTorch/ + + # alternative + # do we have to specify more (s,dim,norm?) + noise_fft = fft.fftn(noise) + content_noisy_fft = fft.fftn(content_noisy) + + # shift low frequency parts to center + noise_fft_shifted = fft.fftshift(noise_fft) + content_noisy_fft_shifted = fft.fftshift(content_noisy_fft) + + # create gaussian low pass filter 'gaussian_low_pass_filter' (specify std!) + # mask out high frequencies using 'cutoff_frequence', something like gaussian_low_pass_filter[freq > cut_off_frequency] = 0.0 + # TODO define 'gaussian_low_pass_filter', apply frequency cutoff filter using self.cutoff_frequency. We need to apply fft.fftshift too probably. + # TODO what exactly is the "normalized space-time stop frequency" used for the cutoff? + + gaussian_3d = gaussian_smoothing_kernel(noise_fft.shape, kernel_size=( + noise_fft.shape[-3], noise_fft.shape[-2], noise_fft.shape[-1]), sigma=1, dim=3).to(noise.device) + + # define cutoff frequency around the kernel center + # TODO define center and cut off radius, e.g. somethink like gaussian_3d[...,:c_x-r_x,:c_y-r_y:,:c_z-r_z] = 0.0 and gaussian_3d[...,c_x+r_x:,c_y+r_y:,c_z+r_z:] = 0.0 + # as we have 16 x 32 x 32, center should be (7.5,15.5,15.5) + radius = self.radius + + # TODO we need to use rounding (ceil?) + + gaussian_3d[:center[0]-radius[0], :center[1] - + radius[1], :center[2]-radius[2]] = 0.0 + gaussian_3d[center[0]+radius[0]:, + center[1]+radius[1]:, center[2]+radius[2]:] = 0.0 + + noise_fft_shifted_hp = noise_fft_shifted * (1 - gaussian_3d) + content_noisy_fft_shifted_lp = content_noisy_fft_shifted * gaussian_3d + + noise = fft.ifftn(fft.ifftshift( + noise_fft_shifted_hp+content_noisy_fft_shifted_lp)) + if frame_idx == 1: + noise = rearrange( + noise, "(B C) F W H -> B F C W H", B=B) + else: + noise = rearrange( + noise, "(B C) F W H -> B C F W H", B=B) + + assert noise.shape == shape + return noise diff --git a/src/videogen_hub/pipelines/streamingt2v/model_func.py b/src/videogen_hub/pipelines/streamingt2v/model_func.py new file mode 100644 index 0000000000000000000000000000000000000000..81cf7eefa3bead2e265d3168fa43eb16788dc11a --- /dev/null +++ b/src/videogen_hub/pipelines/streamingt2v/model_func.py @@ -0,0 +1,231 @@ +# General +import os +from os.path import join as opj +import datetime +import torch +from einops import rearrange, repeat + +# Utilities +from videogen_hub.pipelines.streamingt2v.inference_utils import * + +from modelscope.outputs import OutputKeys +import imageio +from PIL import Image +import numpy as np + +import torch.nn.functional as F +import torchvision.transforms as transforms +from diffusers.utils import load_image + +transform = transforms.Compose([transforms.PILToTensor()]) + + +def ms_short_gen(prompt, ms_model, inference_generator, t=50, device="cuda"): + frames = ms_model( + prompt, + num_inference_steps=t, + generator=inference_generator, + eta=1.0, + height=256, + width=256, + latents=None, + ).frames + frames = torch.stack([torch.from_numpy(frame) for frame in frames]) + frames = frames.to(device).to(torch.float32) + return rearrange(frames[0], "F W H C -> F C W H") + + +def ad_short_gen(prompt, ad_model, inference_generator, t=25, device="cuda"): + frames = ad_model( + prompt, + negative_prompt="bad quality, worse quality", + num_frames=16, + num_inference_steps=t, + generator=inference_generator, + guidance_scale=7.5, + ).frames[0] + frames = torch.stack([transform(frame) for frame in frames]) + frames = frames.to(device).to(torch.float32) + frames = F.interpolate(frames, size=256) + frames = frames / 255.0 + return frames + + +def sdxl_image_gen(prompt, sdxl_model): + image = sdxl_model(prompt=prompt).images[0] + return image + + +def svd_short_gen( + image, prompt, svd_model, sdxl_model, inference_generator, t=25, device="cuda" +): + if image is None or image == "": + image = sdxl_image_gen(prompt, sdxl_model) + image = image.resize((576, 576)) + image = add_margin(image, 0, 224, 0, 224, (0, 0, 0)) + elif type(image) is str: + image = load_image(image) + image = resize_and_keep(image) + image = center_crop(image) + image = add_margin(image, 0, 224, 0, 224, (0, 0, 0)) + else: + image = Image.fromarray(np.uint8(image)) + image = resize_and_keep(image) + image = center_crop(image) + image = add_margin(image, 0, 224, 0, 224, (0, 0, 0)) + + frames = svd_model( + image, decode_chunk_size=8, generator=inference_generator + ).frames[0] + frames = torch.stack([transform(frame) for frame in frames]) + frames = frames.to(device).to(torch.float32) + frames = frames[:16, :, :, 224:-224] + frames = F.interpolate(frames, size=256) + frames = frames / 255.0 + return frames + + +def stream_long_gen( + prompt, + short_video, + n_autoreg_gen, + negative_prompt, + seed, + t, + image_guidance, + result_file_stem, + stream_cli, + stream_model, +): + trainer = stream_cli.trainer + trainer.limit_predict_batches = 1 + + trainer.predict_cfg = { + "predict_dir": stream_cli.config["result_fol"].as_posix(), + "result_file_stem": result_file_stem, + "prompt": prompt, + "video": short_video, + "seed": seed, + "num_inference_steps": t, + "guidance_scale": image_guidance, + "n_autoregressive_generations": n_autoreg_gen, + } + stream_model.inference_params.negative_prompt = negative_prompt + trainer.predict(model=stream_model, datamodule=stream_cli.datamodule) + + +def video2video(prompt, video, where_to_log, cfg_v2v, model_v2v, square=True): + downscale = cfg_v2v["downscale"] + upscale_size = cfg_v2v["upscale_size"] + pad = cfg_v2v["pad"] + + now = datetime.datetime.now() + now = str(now.time()).replace(":", "_").replace(".", "_") + name = prompt[:100].replace(" ", "_") + "_" + now + enhanced_video_mp4 = opj(where_to_log, name + "_enhanced.mp4") + + video_frames = imageio.mimread(video) + h, w, _ = video_frames[0].shape + + # Downscale video, then resize to fit the upscale size + video = [ + Image.fromarray(frame).resize((w // downscale, h // downscale)) + for frame in video_frames + ] + video = [resize_to_fit(frame, upscale_size) for frame in video] + + if pad: + video = [pad_to_fit(frame, upscale_size) for frame in video] + # video = [np.array(frame) for frame in video] + + imageio.mimsave(opj(where_to_log, "temp_" + now + ".mp4"), video, fps=8) + + p_input = { + "video_path": opj(where_to_log, "temp_" + now + ".mp4"), + "text": prompt, + "positive_prompt": prompt, + "total_noise_levels": 600, + } + model_v2v(p_input, output_video=enhanced_video_mp4)[OutputKeys.OUTPUT_VIDEO] + + # Remove padding + video_frames = imageio.mimread(enhanced_video_mp4) + video_frames_square = [] + for frame in video_frames: + frame = frame[:, 280:-280, :] + video_frames_square.append(frame) + imageio.mimsave(enhanced_video_mp4, video_frames_square) + + return enhanced_video_mp4 + + +# The main functionality for video to video +def video2video_randomized( + prompt, + video, + where_to_log, + cfg_v2v, + model_v2v, + square=True, + chunk_size=24, + overlap_size=8, + negative_prompt="", +): + downscale = cfg_v2v["downscale"] + upscale_size = cfg_v2v["upscale_size"] + pad = cfg_v2v["pad"] + + now = datetime.datetime.now() + name = ( + prompt[:100].replace(" ", "_") + + "_" + + str(now.time()).replace(":", "_").replace(".", "_") + ) + enhanced_video_mp4 = opj(where_to_log, name + "_enhanced.mp4") + + video_frames = imageio.mimread(video) + h, w, _ = video_frames[0].shape + + n_chunks = (len(video_frames) - overlap_size) // (chunk_size - overlap_size) + trim_length = n_chunks * (chunk_size - overlap_size) + overlap_size + if trim_length < chunk_size: + raise ValueError( + f"Chunk size [{chunk_size}] cannot be larger than the number of frames in the video [{len(video_frames)}], please provide smaller chunk size" + ) + if trim_length < len(video_frames): + print( + "Video cannot be processed with chunk size {chunk_size} and overlap size {overlap_size}, " + "trimming it to length {trim_length} to be able to process it" + ) + video_frames = video_frames[:trim_length] + + model_v2v.chunk_size = chunk_size + model_v2v.overlap_size = overlap_size + + # Downscale video, then resize to fit the upscale size + video = [ + Image.fromarray(frame).resize((w // downscale, h // downscale)) + for frame in video_frames + ] + video = [resize_to_fit(frame, upscale_size) for frame in video] + + if pad: + video = [pad_to_fit(frame, upscale_size) for frame in video] + + video = list(map(np.array, video)) + + imageio.mimsave(opj(where_to_log, "temp.mp4"), video, fps=8) + + p_input = { + "video_path": opj(where_to_log, "temp.mp4"), + "text": prompt, + "positive_prompt": "", + "negative_prompt": negative_prompt, + "total_noise_levels": 600, + } + + output_video_path = model_v2v(p_input, output_video=enhanced_video_mp4)[ + OutputKeys.OUTPUT_VIDEO + ] + + return enhanced_video_mp4 diff --git a/src/videogen_hub/pipelines/streamingt2v/model_init.py b/src/videogen_hub/pipelines/streamingt2v/model_init.py new file mode 100644 index 0000000000000000000000000000000000000000..67f8ea5c9347a4415b9e0f209003103b59fa4203 --- /dev/null +++ b/src/videogen_hub/pipelines/streamingt2v/model_init.py @@ -0,0 +1,164 @@ +# General +import sys +from pathlib import Path +import torch +from pytorch_lightning import LightningDataModule + +# For Stage-1 +from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler +from diffusers import AnimateDiffPipeline, DDIMScheduler, MotionAdapter +from diffusers import StableVideoDiffusionPipeline, AutoPipelineForText2Image + +# For Stage-2 +import tempfile +import yaml +from videogen_hub.pipelines.streamingt2v.model.video_ldm import VideoLDM +from videogen_hub.pipelines.streamingt2v.model.callbacks import SaveConfigCallback +from videogen_hub.pipelines.streamingt2v.inference_utils import ( + legacy_transformation, + remove_value, + CustomCLI, + v2v_to_device, +) + +# For Stage-3 +import sys + +sys.path.append(Path(__file__).parent / "thirdparty") + + +# Initialize Stage-1 model1. +def init_modelscope(device="cuda"): + pipe = DiffusionPipeline.from_pretrained( + "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16" + ) + # pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + # pipe.set_progress_bar_config(disable=True) + pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) + pipe.enable_model_cpu_offload() + pipe.enable_vae_slicing() + pipe.set_progress_bar_config(disable=True) + return pipe.to(device) + + +def init_zeroscope(device="cuda"): + pipe = DiffusionPipeline.from_pretrained( + "cerspense/zeroscope_v2_576w", torch_dtype=torch.float16 + ) + pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + pipe.enable_model_cpu_offload() + return pipe.to(device) + + +def init_animatediff(device="cuda"): + adapter = MotionAdapter.from_pretrained( + "guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16 + ) + model_id = "SG161222/Realistic_Vision_V5.1_noVAE" + pipe = AnimateDiffPipeline.from_pretrained( + model_id, motion_adapter=adapter, torch_dtype=torch.float16 + ) + scheduler = DDIMScheduler.from_pretrained( + model_id, + subfolder="scheduler", + clip_sample=False, + timestep_spacing="linspace", + beta_schedule="linear", + steps_offset=1, + ) + pipe.scheduler = scheduler + pipe.enable_vae_slicing() + pipe.enable_model_cpu_offload() + return pipe.to(device) + + +def init_sdxl(device="cuda"): + pipe = AutoPipelineForText2Image.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + torch_dtype=torch.float16, + variant="fp16", + use_safetensors=True, + ) + # pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True) + return pipe.to(device) + + +def init_svd(device="cuda"): + pipe = StableVideoDiffusionPipeline.from_pretrained( + "stabilityai/stable-video-diffusion-img2vid-xt", + torch_dtype=torch.float16, + variant="fp16", + ) + pipe.enable_model_cpu_offload() + return pipe.to(device) + + +# Initialize StreamingT2V model. +def init_streamingt2v_model(ckpt_file, result_fol, device): + accelerator = "gpu" if device.startswith("cuda") else "cpu" + import os + + base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + # print("base dir", base_dir) + config_file = f"{base_dir}/streamingt2v/configs/text_to_video/config.yaml" + print("config dir", config_file) + sys.argv = sys.argv[:1] + with tempfile.TemporaryDirectory() as tmpdirname: + storage_fol = Path(tmpdirname) + with open(config_file, "r") as yaml_handle: + yaml_obj = yaml.safe_load(yaml_handle) + + yaml_obj_orig_data_cfg = legacy_transformation(yaml_obj) + yaml_obj_orig_data_cfg = remove_value(yaml_obj_orig_data_cfg, "video_dataset") + + with open(storage_fol / "config.yaml", "w") as outfile: + yaml.dump(yaml_obj_orig_data_cfg, outfile, default_flow_style=False) + sys.argv.append("--config") + sys.argv.append((storage_fol / "config.yaml").as_posix()) + sys.argv.append("--ckpt") + sys.argv.append(ckpt_file.as_posix()) + sys.argv.append("--result_fol") + sys.argv.append(result_fol.as_posix()) + sys.argv.append("--config") + sys.argv.append("configs/inference/inference_long_video.yaml") + sys.argv.append("--data.prompt_cfg.type=prompt") + sys.argv.append(f"--data.prompt_cfg.content='test prompt for initialization'") + sys.argv.append(f"--trainer.accelerator={accelerator}") + sys.argv.append("--trainer.devices=1") + sys.argv.append("--trainer.num_nodes=1") + sys.argv.append(f"--model.inference_params.num_inference_steps=50") + sys.argv.append(f"--model.inference_params.n_autoregressive_generations=4") + sys.argv.append("--model.inference_params.concat_video=True") + sys.argv.append("--model.inference_params.result_formats=[eval_mp4]") + + cli = CustomCLI( + VideoLDM, + LightningDataModule, + run=False, + subclass_mode_data=True, + auto_configure_optimizers=False, + parser_kwargs={"parser_mode": "omegaconf"}, + save_config_callback=SaveConfigCallback, + save_config_kwargs={"log_dir": result_fol, "overwrite": True}, + ) + + model = cli.model + model.load_state_dict( + torch.load(cli.config["ckpt"].as_posix(), map_location=torch.device("cpu"))[ + "state_dict" + ] + ) + return cli, model + + +# Initialize Stage-3 model. +def init_v2v_model(cfg, device): + from modelscope.pipelines import pipeline + + model_id = cfg["model_id"] + pipe_enhance = pipeline( + task="video-to-video", model=model_id, model_revision="v1.1.0", device="cpu" + ) + pipe_enhance.model.cfg.max_frames = 10000 + pipe_enhance = v2v_to_device(pipe_enhance, device) + return pipe_enhance diff --git a/src/videogen_hub/pipelines/streamingt2v/streamingt2v_pipeline.py b/src/videogen_hub/pipelines/streamingt2v/streamingt2v_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..497f58efb16d137a14a301a7e166eb8bf3f05803 --- /dev/null +++ b/src/videogen_hub/pipelines/streamingt2v/streamingt2v_pipeline.py @@ -0,0 +1,225 @@ +# General +import os +from os.path import join as opj +import argparse +import datetime +from pathlib import Path +import torch +import gradio as gr +import tempfile +import yaml + +# from t2v_enhanced.model.video_ldm import VideoLDM +from typing import List, Optional + +# from model.callbacks import SaveConfigCallback +from PIL.Image import Image, fromarray + +# from einops import rearrange, repeat + +import sys + +from ... import MODEL_PATH + +sys.path.append("thirdparty") +# from modelscope.pipelines import pipeline +# from modelscope.outputs import OutputKeys +import imageio +import pathlib +import numpy as np + +# Utilities +from .inference_utils import * + +from .model_init import ( + init_modelscope, + init_animatediff, + init_svd, + init_sdxl, + init_v2v_model, + init_streamingt2v_model, +) +from .model_func import * + + +def pipeline(prompt, size, seconds, fps, seed): + parser = argparse.ArgumentParser() + parser.add_argument( + "--prompt", + type=str, + default=prompt, + help="The prompt to guide video generation.", + ) + parser.add_argument( + "--image", type=str, default="", help="Path to image conditioning." + ) + # parser.add_argument('--video', type=str, default="", help="Path to video conditioning.") + parser.add_argument( + "--base_model", + type=str, + default="ModelscopeT2V", + help="Base model to generate first chunk from", + choices=["ModelscopeT2V", "AnimateDiff", "SVD"], + ) + parser.add_argument( + "--num_frames", + type=int, + default=seconds * fps, + help="The number of video frames to generate.", + ) + parser.add_argument( + "--negative_prompt", + type=str, + default="", + help="The prompt to guide what to not include in video generation.", + ) + parser.add_argument( + "--negative_prompt_enhancer", + type=str, + default=None, + help="The prompt to guide what to not include in video enhancement. " + "By default is the same as --negative_prompt", + ) + parser.add_argument( + "--num_steps", type=int, default=50, help="The number of denoising steps." + ) + parser.add_argument( + "--image_guidance", type=float, default=9.0, help="The guidance scale." + ) + + parser.add_argument( + "--output_dir", + type=str, + default="results", + help="Path where to save the generated videos.", + ) + parser.add_argument("--device", type=str, default="cpu") + parser.add_argument("--seed", type=int, default=seed, help="Random seed") + + parser.add_argument( + "--chunk", type=int, default=24, help="chunk_size for randomized blending" + ) + parser.add_argument( + "--overlap", type=int, default=8, help="overlap_size for randomized blending" + ) + + parser.add_argument( + "--offload_models", + action="store_true", + help="Load/Offload models to gpu/cpu before and after inference", + ) + args = parser.parse_args() + + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + result_fol = Path(args.output_dir).absolute() + device = args.device + + # -------------------------- + # ----- Configurations ----- + # -------------------------- + ckpt_file_streaming_t2v = os.path.join(MODEL_PATH, "streamingtv2", "streaming_t2v.ckpt") + cfg_v2v = { + "downscale": 1, + "upscale_size": size, + "model_id": "damo/Video-to-Video", + "pad": True, + } + + # -------------------------- + # ----- Initialization ----- + # -------------------------- + if args.base_model == "ModelscopeT2V": + if args.offload_models: + model = init_modelscope("cpu") + else: + model = init_modelscope(device) + elif args.base_model == "AnimateDiff": + if args.offload_models: + model = init_animatediff("cpu") + else: + model = init_animatediff(device) + elif args.base_model == "SVD": + if args.offload_models: + model = init_svd("cpu") + sdxl_model = init_sdxl("cpu") + else: + model = init_svd(device) + sdxl_model = init_sdxl(device) + + if args.offload_models: + msxl_model = init_v2v_model(cfg_v2v, "cpu") + else: + msxl_model = init_v2v_model(cfg_v2v, device) + + stream_cli, stream_model = init_streamingt2v_model( + ckpt_file_streaming_t2v, result_fol, "cuda" + ) + if args.offload_models: + stream_model = st2v_to_device(stream_model, "cpu") + inference_generator = torch.Generator(device="cuda") + + # ------------------ + # ----- Inputs ----- + # ------------------ + now = datetime.datetime.now() + name = ( + args.prompt[:100].replace(" ", "_") + + "_" + + str(now.time()).replace(":", "_").replace(".", "_") + ) + + inference_generator = torch.Generator(device="cuda") + inference_generator.manual_seed(args.seed) + + if args.offload_models: + model = model.to(device) + if args.base_model == "ModelscopeT2V": + short_video = ms_short_gen(args.prompt, model, inference_generator) + elif args.base_model == "AnimateDiff": + short_video = ad_short_gen(args.prompt, model, inference_generator) + elif args.base_model == "SVD": + if args.offload_models: + sdxl_model = sdxl_model.to(device) + short_video = svd_short_gen( + args.image, args.prompt, model, sdxl_model, inference_generator + ) + if args.offload_models: + sdxl_model = sdxl_model.to("cpu") + if args.offload_models: + model = model.to("cpu") + + n_autoreg_gen = (args.num_frames - 8) // 8 + stream_long_gen( + args.prompt, + short_video, + n_autoreg_gen, + args.negative_prompt, + args.seed, + args.num_steps, + args.image_guidance, + name, + stream_cli, + stream_model, + ) + if args.offload_models: + stream_model = st2v_to_device(stream_model, "cpu") + + args.negative_prompt_enhancer = ( + args.negative_prompt_enhancer + if args.negative_prompt_enhancer is not None + else args.negative_prompt + ) + if args.offload_models: + msxl_model = v2v_to_device(msxl_model, device) + return video2video_randomized( + args.prompt, + opj(result_fol, name + ".mp4"), + result_fol, + cfg_v2v, + msxl_model, + chunk_size=args.chunk, + overlap_size=args.overlap, + negative_prompt=args.negative_prompt_enhancer, + ) + # if args.offload_models: + # msxl_model = v2v_to_device(msxl_model, "cpu") diff --git a/src/videogen_hub/pipelines/streamingt2v/utils/__init__.py b/src/videogen_hub/pipelines/streamingt2v/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/streamingt2v/utils/conversions.py b/src/videogen_hub/pipelines/streamingt2v/utils/conversions.py new file mode 100644 index 0000000000000000000000000000000000000000..aeb8ba1b7ebb402aa9fdf77c7bf02f68f0d951ce --- /dev/null +++ b/src/videogen_hub/pipelines/streamingt2v/utils/conversions.py @@ -0,0 +1,48 @@ +from pathlib import Path +import PIL +from PIL import Image +import numpy as np +from dataclasses import dataclass + +# TODO add register new converter so that it is accessible via converters.to_x + +def ensure_class(func, params): + def func_wrapper(function): + def wrapper(self=None, *args, **kwargs): + for key in kwargs: + if key in params: + kwargs[key] = func(kwargs[key]) + if self is not None: + return function(self, *args, **kwargs) + else: + return function(*args, **kwargs) + + return wrapper + + return func_wrapper + + +def as_PIL(img): + if not isinstance(img, PIL.Image.Image): + if isinstance(img, Path): + img = img.as_posix() + if isinstance(img, str): + img = Image.open(img) + elif isinstance(img, np.ndarray): + img = Image.fromarray(img) + + else: + raise NotImplementedError + return img + + +def to_ndarray(input): + if not isinstance(input, np.ndarray): + input = np.array(input) + return input + + +def to_Path(input): + if not isinstance(input, Path): + input = Path(input) + return input diff --git a/src/videogen_hub/pipelines/streamingt2v/utils/iimage.py b/src/videogen_hub/pipelines/streamingt2v/utils/iimage.py new file mode 100644 index 0000000000000000000000000000000000000000..f5aba519cfef6b50105e21f9e932910ef4bab8d2 --- /dev/null +++ b/src/videogen_hub/pipelines/streamingt2v/utils/iimage.py @@ -0,0 +1,517 @@ +import io +import math +import os +import PIL.Image +import numpy as np +import imageio.v3 as iio +import warnings + + +import torch +import torchvision.transforms.functional as TF +from scipy.ndimage import binary_dilation, binary_erosion +import cv2 + +import re + +import matplotlib.pyplot as plt +from matplotlib import animation +from IPython.display import HTML, Image, display + + +IMG_THUMBSIZE = None + +def torch2np(x, vmin=-1, vmax=1): + if x.ndim != 4: + # raise Exception("Please only use (B,C,H,W) torch tensors!") + warnings.warn( + "Warning! Shape of the image was not provided in (B,C,H,W) format, the shape was inferred automatically!") + if x.ndim == 3: + x = x[None] + if x.ndim == 2: + x = x[None, None] + x = x.detach().cpu().float() + if x.dtype == torch.uint8: + return x.numpy().astype(np.uint8) + elif vmin is not None and vmax is not None: + x = (255 * (x.clip(vmin, vmax) - vmin) / (vmax - vmin)) + x = x.permute(0, 2, 3, 1).to(torch.uint8) + return x.numpy() + else: + raise NotImplementedError() + + +class IImage: + ''' + Generic media storage. Can store both images and videos. + Stores data as a numpy array by default. + Can be viewed in a jupyter notebook. + ''' + @staticmethod + def open(path): + + iio_obj = iio.imopen(path, 'r') + data = iio_obj.read() + try: + # .properties() does not work for images but for gif files + if not iio_obj.properties().is_batch: + data = data[None] + except AttributeError as e: + # this one works for gif files + if not "duration" in iio_obj.metadata(): + data = data[None] + if data.ndim == 3: + data = data[..., None] + image = IImage(data) + image.link = os.path.abspath(path) + return image + + @staticmethod + def normalized(x, dims=[-1, -2]): + x = (x - x.amin(dims, True)) / \ + (x.amax(dims, True) - x.amin(dims, True)) + return IImage(x, 0) + + def numpy(self): return self.data + + def torch(self, vmin=-1, vmax=1): + if self.data.ndim == 3: + data = self.data.transpose(2, 0, 1) / 255. + else: + data = self.data.transpose(0, 3, 1, 2) / 255. + return vmin + torch.from_numpy(data).float().to(self.device) * (vmax - vmin) + + def cuda(self): + self.device = 'cuda' + return self + + def cpu(self): + self.device = 'cpu' + return self + + def pil(self): + ans = [] + for x in self.data: + if x.shape[-1] == 1: + x = x[..., 0] + + ans.append(PIL.Image.fromarray(x)) + if len(ans) == 1: + return ans[0] + return ans + + def is_iimage(self): + return True + + @property + def shape(self): return self.data.shape + @property + def size(self): return (self.data.shape[-2], self.data.shape[-3]) + + def setFps(self, fps): + self.fps = fps + self.generate_display() + return self + + def __init__(self, x, vmin=-1, vmax=1, fps=None): + if isinstance(x, PIL.Image.Image): + self.data = np.array(x) + if self.data.ndim == 2: + self.data = self.data[..., None] # (H,W,C) + self.data = self.data[None] # (B,H,W,C) + elif isinstance(x, IImage): + self.data = x.data.copy() # Simple Copy + elif isinstance(x, np.ndarray): + self.data = x.copy().astype(np.uint8) + if self.data.ndim == 2: + self.data = self.data[None, ..., None] + if self.data.ndim == 3: + warnings.warn( + "Inferred dimensions for a 3D array as (H,W,C), but could've been (B,H,W)") + self.data = self.data[None] + elif isinstance(x, torch.Tensor): + self.data = torch2np(x, vmin, vmax) + self.display_str = None + self.device = 'cpu' + self.fps = fps if fps is not None else ( + 1 if len(self.data) < 10 else 30) + self.link = None + + def generate_display(self): + if IMG_THUMBSIZE is not None: + if self.size[1] < self.size[0]: + thumb = self.resize( + (self.size[1]*IMG_THUMBSIZE//self.size[0], IMG_THUMBSIZE)) + else: + thumb = self.resize( + (IMG_THUMBSIZE, self.size[0]*IMG_THUMBSIZE//self.size[1])) + else: + thumb = self + if self.is_video(): + self.anim = Animation(thumb.data, fps=self.fps) + self.anim.render() + self.display_str = self.anim.anim_str + else: + b = io.BytesIO() + data = thumb.data[0] + if data.shape[-1] == 1: + data = data[..., 0] + PIL.Image.fromarray(data).save(b, "PNG") + self.display_str = b.getvalue() + return self.display_str + + def resize(self, size, *args, **kwargs): + if size is None: + return self + use_small_edge_when_int = kwargs.pop('use_small_edge_when_int', False) + + # Backward compatibility + resample = kwargs.pop('filter', PIL.Image.BICUBIC) + resample = kwargs.pop('resample', resample) + + if isinstance(size, int): + if use_small_edge_when_int: + h, w = self.data.shape[1:3] + aspect_ratio = h / w + size = (max(size, int(size * aspect_ratio)), + max(size, int(size / aspect_ratio))) + else: + h, w = self.data.shape[1:3] + aspect_ratio = h / w + size = (min(size, int(size * aspect_ratio)), + min(size, int(size / aspect_ratio))) + + if self.size == size[::-1]: + return self + return stack([IImage(x.pil().resize(size[::-1], *args, resample=resample, **kwargs)) for x in self]) + + def pad(self, padding, *args, **kwargs): + return IImage(TF.pad(self.torch(0), padding=padding, *args, **kwargs), 0) + + def padx(self, multiplier, *args, **kwargs): + size = np.array(self.size) + padding = np.concatenate( + [[0, 0], np.ceil(size / multiplier).astype(int) * multiplier - size]) + return self.pad(list(padding), *args, **kwargs) + + def pad2wh(self, w=0, h=0, **kwargs): + cw, ch = self.size + return self.pad([0, 0, max(0, w - cw), max(0, h-ch)], **kwargs) + + def pad2square(self, *args, **kwargs): + if self.size[0] > self.size[1]: + dx = self.size[0] - self.size[1] + return self.pad([0, dx//2, 0, dx-dx//2], *args, **kwargs) + elif self.size[0] < self.size[1]: + dx = self.size[1] - self.size[0] + return self.pad([dx//2, 0, dx-dx//2, 0], *args, **kwargs) + return self + + def crop2square(self, *args, **kwargs): + if self.size[0] > self.size[1]: + dx = self.size[0] - self.size[1] + return self.crop([dx//2, 0, self.size[1], self.size[1]], *args, **kwargs) + elif self.size[0] < self.size[1]: + dx = self.size[1] - self.size[0] + return self.crop([0, dx//2, self.size[0], self.size[0]], *args, **kwargs) + return self + + def alpha(self): + return IImage(self.data[..., -1, None], fps=self.fps) + + def rgb(self): + return IImage(self.pil().convert('RGB'), fps=self.fps) + + def png(self): + return IImage(np.concatenate([self.data, 255 * np.ones_like(self.data)[..., :1]], -1)) + + def grid(self, nrows=None, ncols=None): + if nrows is not None: + ncols = math.ceil(self.data.shape[0] / nrows) + elif ncols is not None: + nrows = math.ceil(self.data.shape[0] / ncols) + else: + warnings.warn( + "No dimensions specified, creating a grid with 5 columns (default)") + ncols = 5 + nrows = math.ceil(self.data.shape[0] / ncols) + + pad = nrows * ncols - self.data.shape[0] + data = np.pad(self.data, ((0, pad), (0, 0), (0, 0), (0, 0))) + rows = [np.concatenate(x, 1, dtype=np.uint8) + for x in np.array_split(data, nrows)] + return IImage(np.concatenate(rows, 0, dtype=np.uint8)[None]) + + def hstack(self): + return IImage(np.concatenate(self.data, 1, dtype=np.uint8)[None]) + + def vstack(self): + return IImage(np.concatenate(self.data, 0, dtype=np.uint8)[None]) + + def vsplit(self, number_of_splits): + return IImage(np.concatenate(np.split(self.data, number_of_splits, 1))) + + def hsplit(self, number_of_splits): + return IImage(np.concatenate(np.split(self.data, number_of_splits, 2))) + + def heatmap(self, resize=None, cmap=cv2.COLORMAP_JET): + data = np.stack([cv2.cvtColor(cv2.applyColorMap( + x, cmap), cv2.COLOR_BGR2RGB) for x in self.data]) + return IImage(data).resize(resize, use_small_edge_when_int=True) + + def display(self): + try: + display(self) + except: + print("No display") + return self + + def dilate(self, iterations=1, *args, **kwargs): + if iterations == 0: + return IImage(self.data) + return IImage((binary_dilation(self.data, iterations=iterations, *args, *kwargs)*255.).astype(np.uint8)) + + def erode(self, iterations=1, *args, **kwargs): + return IImage((binary_erosion(self.data, iterations=iterations, *args, *kwargs)*255.).astype(np.uint8)) + + def hull(self): + convex_hulls = [] + for frame in self.data: + contours, hierarchy = cv2.findContours( + frame, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) + contours = [x.astype(np.int32) for x in contours] + mask_contours = [cv2.convexHull(np.concatenate(contours))] + canvas = np.zeros(self.data[0].shape, np.uint8) + convex_hull = cv2.drawContours( + canvas, mask_contours, -1, (255, 0, 0), -1) + convex_hulls.append(convex_hull) + return IImage(np.array(convex_hulls)) + + def is_video(self): + return self.data.shape[0] > 1 + + def __getitem__(self, idx): + return IImage(self.data[None, idx], fps=self.fps) + # if self.is_video(): return IImage(self.data[idx], fps = self.fps) + # return self + + def _repr_png_(self): + if self.is_video(): + return None + if self.display_str is None: + self.generate_display() + return self.display_str + + def _repr_html_(self): + if not self.is_video(): + return None + if self.display_str is None: + self.generate_display() + return self.display_str + + def save(self, path): + _, ext = os.path.splitext(path) + if self.is_video(): + # if ext in ['.jpg', '.png']: + if self.display_str is None: + self.generate_display() + if ext == ".apng": + self.anim.anim_obj.save(path, writer="pillow") + else: + self.anim.anim_obj.save(path) + else: + data = self.data if self.data.ndim == 3 else self.data[0] + if data.shape[-1] == 1: + data = data[:, :, 0] + PIL.Image.fromarray(data).save(path) + return self + + def write(self, text, center=(0, 25), font_scale=0.8, color=(255, 255, 255), thickness=2): + if not isinstance(text, list): + text = [text for _ in self.data] + data = np.stack([cv2.putText(x.copy(), t, center, cv2.FONT_HERSHEY_COMPLEX, + font_scale, color, thickness) for x, t in zip(self.data, text)]) + return IImage(data) + + def append_text(self, text, padding, font_scale=0.8, color=(255, 255, 255), thickness=2, scale_factor=0.9, center=(0, 0), fill=0): + + assert np.count_nonzero(padding) == 1 + axis_padding = np.nonzero(padding)[0][0] + scale_padding = padding[axis_padding] + + y_0 = 0 + x_0 = 0 + if axis_padding == 0: + width = scale_padding + y_max = self.shape[1] + elif axis_padding == 1: + width = self.shape[2] + y_max = scale_padding + elif axis_padding == 2: + x_0 = self.shape[2] + width = scale_padding + y_max = self.shape[1] + elif axis_padding == 3: + width = self.shape[2] + y_0 = self.shape[1] + y_max = self.shape[1]+scale_padding + + width -= center[0] + x_0 += center[0] + y_0 += center[1] + + self = self.pad(padding, fill=fill) + + def wrap_text(text, width, _font_scale): + allowed_seperator = ' |-|_|/|\n' + words = re.split(allowed_seperator, text) + # words = text.split() + lines = [] + current_line = words[0] + sep_list = [] + start_idx = 0 + for start_word in words[:-1]: + pos = text.find(start_word, start_idx) + pos += len(start_word) + sep_list.append(text[pos]) + start_idx = pos+1 + + for word, separator in zip(words[1:], sep_list): + if cv2.getTextSize(current_line + separator + word, cv2.FONT_HERSHEY_COMPLEX, _font_scale, thickness)[0][0] <= width: + current_line += separator + word + else: + if cv2.getTextSize(current_line, cv2.FONT_HERSHEY_COMPLEX, _font_scale, thickness)[0][0] <= width: + lines.append(current_line) + current_line = word + else: + return [] + + if cv2.getTextSize(current_line, cv2.FONT_HERSHEY_COMPLEX, _font_scale, thickness)[0][0] <= width: + lines.append(current_line) + else: + return [] + return lines + + def wrap_text_and_scale(text, width, _font_scale, y_0, y_max): + height = y_max+1 + while height > y_max: + text_lines = wrap_text(text, width, _font_scale) + if len(text) > 0 and len(text_lines) == 0: + + height = y_max+1 + else: + line_height = cv2.getTextSize( + text_lines[0], cv2.FONT_HERSHEY_COMPLEX, _font_scale, thickness)[0][1] + height = line_height * len(text_lines) + y_0 + + # scale font if out of frame + if height > y_max: + _font_scale = _font_scale * scale_factor + + return text_lines, line_height, _font_scale + + result = [] + if not isinstance(text, list): + text = [text for _ in self.data] + else: + assert len(text) == len(self.data) + + for x, t in zip(self.data, text): + x = x.copy() + text_lines, line_height, _font_scale = wrap_text_and_scale( + t, width, font_scale, y_0, y_max) + y = line_height + for line in text_lines: + x = cv2.putText( + x, line, (x_0, y_0+y), cv2.FONT_HERSHEY_COMPLEX, _font_scale, color, thickness) + y += line_height + result.append(x) + data = np.stack(result) + + return IImage(data) + + # ========== OPERATORS ============= + + def __or__(self, other): + # TODO: fix for variable sizes + return IImage(np.concatenate([self.data, other.data], 2)) + + def __truediv__(self, other): + # TODO: fix for variable sizes + return IImage(np.concatenate([self.data, other.data], 1)) + + def __and__(self, other): + return IImage(np.concatenate([self.data, other.data], 0)) + + def __add__(self, other): + return IImage(0.5 * self.data + 0.5 * other.data) + + def __mul__(self, other): + if isinstance(other, IImage): + return IImage(self.data / 255. * other.data) + return IImage(self.data * other / 255.) + + def __xor__(self, other): + return IImage(0.5 * self.data + 0.5 * other.data + 0.5 * self.data * (other.data.sum(-1, keepdims=True) == 0)) + + def __invert__(self): + return IImage(255 - self.data) + __rmul__ = __mul__ + + def bbox(self): + return [cv2.boundingRect(x) for x in self.data] + + def fill_bbox(self, bbox_list, fill=255): + data = self.data.copy() + for bbox in bbox_list: + x, y, w, h = bbox + data[:, y:y+h, x:x+w, :] = fill + return IImage(data) + + def crop(self, bbox): + assert len(bbox) in [2, 4] + if len(bbox) == 2: + x, y = 0, 0 + w, h = bbox + elif len(bbox) == 4: + x, y, w, h = bbox + return IImage(self.data[:, y:y+h, x:x+w, :]) + +def stack(images, axis = 0): + return IImage(np.concatenate([x.data for x in images], axis)) + +class Animation: + JS = 0 + HTML = 1 + ANIMATION_MODE = HTML + def __init__(self, frames, fps = 30): + """_summary_ + + Args: + frames (np.ndarray): _description_ + """ + self.frames = frames + self.fps = fps + self.anim_obj = None + self.anim_str = None + def render(self): + size = (self.frames.shape[2],self.frames.shape[1]) + self.fig = plt.figure(figsize = size, dpi = 1) + plt.axis('off') + img = plt.imshow(self.frames[0], cmap = 'gray') + self.fig.subplots_adjust(0,0,1,1) + self.anim_obj = animation.FuncAnimation( + self.fig, + lambda i: img.set_data(self.frames[i,:,:,:]), + frames=self.frames.shape[0], + interval = 1000 / self.fps + ) + plt.close() + if Animation.ANIMATION_MODE == Animation.HTML: + self.anim_str = self.anim_obj.to_html5_video(embed_limit=1000.0) + elif Animation.ANIMATION_MODE == Animation.JS: + self.anim_str = self.anim_obj.to_jshtml() + return self.anim_obj + def _repr_html_(self): + if self.anim_obj is None: self.render() + return self.anim_str \ No newline at end of file diff --git a/src/videogen_hub/pipelines/streamingt2v/utils/image_converter.py b/src/videogen_hub/pipelines/streamingt2v/utils/image_converter.py new file mode 100644 index 0000000000000000000000000000000000000000..6891da9ed39bacea8699599f76037727e07a5156 --- /dev/null +++ b/src/videogen_hub/pipelines/streamingt2v/utils/image_converter.py @@ -0,0 +1,45 @@ +import cv2 +import numpy as np +from albumentations.augmentations.geometric import functional as F +from albumentations.core.transforms_interface import DualTransform + +__all__ = ["ProportionalMinScale"] + + +class ProportionalMinScale(DualTransform): + + def __init__( + self, + width: int, + height: int, + interpolation: int = cv2.INTER_LINEAR, + always_apply: bool = False, + p: float = 1, + ): + super(ProportionalMinScale, self).__init__(always_apply, p) + self.width = width + self.height = height + + def apply( + self, img: np.ndarray, width: int = 256, height: int = 256, interpolation: int = cv2.INTER_LINEAR, **params): + h_img, w_img, _ = img.shape + + min_side = np.min([h_img, w_img]) + + if (height/h_img)*w_img >= width: + if h_img == min_side: + return F.smallest_max_size(img, max_size=height, interpolation=interpolation) + else: + return F.longest_max_size(img, max_size=height, interpolation=interpolation) + if (width/w_img)*h_img >= height: + if w_img == min_side: + return F.smallest_max_size(img, max_size=width, interpolation=interpolation) + else: + return F.longest_max_size(img, max_size=width, interpolation=interpolation) + return F.longest_max_size(img, max_size=width, interpolation=interpolation) + + def get_params(self): + return {"width": self.width, "height": self.height} + + def get_transform_init_args_names(self): + return ("width", "height", "intepolation") diff --git a/src/videogen_hub/pipelines/streamingt2v/utils/object_loader.py b/src/videogen_hub/pipelines/streamingt2v/utils/object_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..a121cea14154c4e61c39e6bd961737e735f75879 --- /dev/null +++ b/src/videogen_hub/pipelines/streamingt2v/utils/object_loader.py @@ -0,0 +1,26 @@ +import importlib +from functools import partialmethod + + +def instantiate_object(cls_path: str, *args, **kwargs): + class_ = get_class(cls_path, *args, **kwargs) + obj = class_() + return obj + + +def get_class(cls_path: str, *args, **kwargs): + module_name = ".".join(cls_path.split(".")[:-1]) + module = importlib.import_module(module_name) + + class_ = getattr(module, cls_path.split(".")[-1]) + class_.__init__ = partialmethod(class_.__init__, *args, **kwargs) + return class_ + + +if __name__ == "__main__": + + class_ = get_class( + "diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler") + scheduler = class_.from_config("stabilityai/stable-diffusion-2-1", + subfolder="scheduler") + print(scheduler) diff --git a/src/videogen_hub/pipelines/streamingt2v/utils/video_utils.py b/src/videogen_hub/pipelines/streamingt2v/utils/video_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..799e7fa18ca03d15977e1d60aa4fe458061998fe --- /dev/null +++ b/src/videogen_hub/pipelines/streamingt2v/utils/video_utils.py @@ -0,0 +1,454 @@ +import os +import subprocess +import tempfile +from pathlib import Path +from typing import Union +import shutil + +import cv2 +import imageio +import numpy as np +import torch +import torchvision + +# from decord import VideoReader, cpu +from einops import rearrange, repeat +from .iimage import IImage +from PIL import Image, ImageDraw, ImageFont +from torchvision.utils import save_image + +channel_first = 0 +channel_last = -1 + + +def video_naming(prompt, extension, batch_idx, idx): + prompt_identifier = prompt.replace(" ", "_") + prompt_identifier = prompt_identifier.replace("/", "_") + if len(prompt_identifier) > 40: + prompt_identifier = prompt_identifier[:40] + filename = f"{batch_idx:04d}_{idx:04d}_{prompt_identifier}.{extension}" + return filename + + +def video_naming_chunk(prompt, extension, batch_idx, idx, chunk_idx): + prompt_identifier = prompt.replace(" ", "_") + prompt_identifier = prompt_identifier.replace("/", "_") + if len(prompt_identifier) > 40: + prompt_identifier = prompt_identifier[:40] + filename = f"{batch_idx}_{idx}_{chunk_idx}_{prompt_identifier}.{extension}" + return filename + + +class ResultProcessor: + + def __init__(self, fps: int, n_frames: int, logger=None) -> None: + self.fps = fps + self.logger = logger + self.n_frames = n_frames + + def set_logger(self, logger): + self.logger = logger + + def _create_video( + self, + video, + prompt, + filename: Union[str, Path], + append_video: torch.FloatTensor = None, + input_flow=None, + ): + + if video.ndim == 5: + # can be batches if we provide list of filenames + assert video.shape[0] == 1 + video = video[0] + + if video.shape[0] == 3 and video.shape[1] == self.n_frames: + video = rearrange(video, "C F W H -> F C W H") + assert video.shape[1] == 3, f"Wrong video format. Got {video.shape}" + if isinstance(filename, Path): + filename = filename.as_posix() + # assert video.max() <= 1 and video.min() >= 0 + assert ( + video.max() <= 1.1 and video.min() >= -0.1 + ), f"video has unexpected range: [{video.min()}, {video.max()}]" + vid_obj = IImage(video, vmin=0, vmax=1) + + if prompt is not None: + vid_obj = vid_obj.append_text(prompt, padding=(0, 50, 0, 0)) + + if append_video is not None: + if append_video.ndim == 5: + assert append_video.shape[0] == 1 + append_video = append_video[0] + if append_video.shape[0] < video.shape[0]: + append_video = torch.concat( + [ + append_video, + repeat( + append_video[-1, None], + "F C W H -> (rep F) C W H", + rep=video.shape[0] - append_video.shape[0], + ), + ], + dim=0, + ) + if append_video.ndim == 3 and video.ndim == 4: + append_video = repeat( + append_video, "C W H -> F C W H", F=video.shape[0] + ) + append_video = IImage(append_video, vmin=-1, vmax=1) + if prompt is not None: + append_video = append_video.append_text( + "input_frame", padding=(0, 50, 0, 0) + ) + vid_obj = vid_obj | append_video + vid_obj = vid_obj.setFps(self.fps) + vid_obj.save(filename) + + def _create_prompt_file(self, prompt, filename, video_path: str = None): + filename = Path(filename) + filename = filename.parent / (filename.stem + ".txt") + + with open(filename.as_posix(), "w") as file_writer: + file_writer.write(prompt) + file_writer.write("\n") + if video_path is not None: + file_writer.write(video_path) + else: + file_writer.write(" no_source") + + def log_video( + self, + video: torch.FloatTensor, + prompt: str, + video_id: str, + log_folder: str, + input_flow=None, + video_path_input: str = None, + extension: str = "gif", + prompt_on_vid: bool = True, + append_video: torch.FloatTensor = None, + ): + + with tempfile.TemporaryDirectory() as tmpdirname: + storage_fol = Path(tmpdirname) + filename = f"{video_id}.{extension}".replace("/", "_") + vid_filename = storage_fol / filename + self._create_video( + video, + prompt if prompt_on_vid else None, + vid_filename, + append_video, + input_flow=input_flow, + ) + + prompt_file = storage_fol / f"{video_id}.txt" + self._create_prompt_file(prompt, prompt_file, video_path_input) + + if self.logger.experiment.__class__.__name__ == "_DummyExperiment": + run_fol = ( + Path(self.logger.save_dir) + / self.logger.experiment_id + / self.logger.run_id + / "artifacts" + / log_folder + ) + if not run_fol.exists(): + run_fol.mkdir(parents=True, exist_ok=True) + shutil.copy( + prompt_file.as_posix(), (run_fol / f"{video_id}.txt").as_posix() + ) + shutil.copy(vid_filename, (run_fol / filename).as_posix()) + else: + self.logger.experiment.log_artifact( + self.logger.run_id, prompt_file.as_posix(), log_folder + ) + self.logger.experiment.log_artifact( + self.logger.run_id, vid_filename, log_folder + ) + + def save_to_file( + self, + video: torch.FloatTensor, + prompt: str, + video_filename: Union[str, Path], + input_flow=None, + conditional_video_path: str = None, + prompt_on_vid: bool = True, + conditional_video: torch.FloatTensor = None, + ): + self._create_video( + video, + prompt if prompt_on_vid else None, + video_filename, + conditional_video, + input_flow=input_flow, + ) + self._create_prompt_file(prompt, video_filename, conditional_video_path) + + +def add_text_to_image( + image_array, text, position, font_size, text_color, font_path=None +): + + # Convert the NumPy array to PIL Image + image_pil = Image.fromarray(image_array) + + # Create a drawing object + draw = ImageDraw.Draw(image_pil) + + if font_path is not None: + font = ImageFont.truetype(font_path, font_size) + else: + try: + # Load the font + font = ImageFont.truetype( + "/usr/share/fonts/truetype/liberation/LiberationMono-Regular.ttf", + font_size, + ) + except: + font = ImageFont.load_default() + + # Draw the text on the image + draw.text(position, text, font=font, fill=text_color) + + # Convert the PIL Image back to NumPy array + modified_image_array = np.array(image_pil) + + return modified_image_array + + +def add_text_to_video(video_path, prompt): + + outputs_with_overlay = [] + with open(video_path, "rb") as f: + vr = VideoReader(f, ctx=cpu(0)) + + for i in range(len(vr)): + frame = vr[i] + frame = add_text_to_image( + frame, + prompt, + position=(10, 10), + font_size=15, + text_color=(255, 0, 0), + ) + outputs_with_overlay.append(frame) + outputs = outputs_with_overlay + video_path = video_path.replace("mp4", "gif") + imageio.mimsave(video_path, outputs, duration=100, loop=0) + + +def save_videos_grid( + videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=30, prompt=None +): + videos = rearrange(videos, "b c t h w -> t b c h w") + outputs = [] + for x in videos: + x = torchvision.utils.make_grid(x, nrow=n_rows) + x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) + if rescale: + x = (x + 1.0) / 2.0 # -1,1 -> 0,1 + x = (x * 255).numpy().astype(np.uint8) + outputs.append(x) + + os.makedirs(os.path.dirname(path), exist_ok=True) + + if prompt is not None: + outputs_with_overlay = [] + for frame in outputs: + frame_out = add_text_to_image( + frame, + prompt, + position=(10, 10), + font_size=10, + text_color=(255, 0, 0), + ) + outputs_with_overlay.append(frame_out) + outputs = outputs_with_overlay + imageio.mimsave(path, outputs, duration=round(1 / fps * 1000), loop=0) + # iio.imwrite(path, outputs) + # optimize(path) + + +def set_channel_pos(data, shape_dict, channel_pos): + + assert data.ndim == 5 or data.ndim == 4 + batch_dim = data.shape[0] + frame_dim = shape_dict["frame_dim"] + channel_dim = shape_dict["channel_dim"] + width_dim = shape_dict["width_dim"] + height_dim = shape_dict["height_dim"] + + assert batch_dim != frame_dim + assert channel_dim != frame_dim + assert channel_dim != batch_dim + + video_shape = list(data.shape) + batch_pos = video_shape.index(batch_dim) + + channel_pos = video_shape.index(channel_dim) + w_pos = video_shape.index(width_dim) + h_pos = video_shape.index(height_dim) + if w_pos == h_pos: + video_shape[w_pos] = -1 + h_pos = video_shape.index(height_dim) + pattern_order = {} + pattern_order[batch_pos] = "B" + pattern_order[channel_pos] = "C" + + pattern_order[w_pos] = "W" + pattern_order[h_pos] = "H" + + if data.ndim == 5: + frame_pos = video_shape.index(frame_dim) + pattern_order[frame_pos] = "F" + if channel_pos == channel_first: + pattern = " -> B F C W H" + else: + pattern = " -> B F W H C" + else: + if channel_pos == channel_first: + pattern = " -> B C W H" + else: + pattern = " -> B W H C" + pattern_input = [pattern_order[idx] for idx in range(data.ndim)] + pattern_input = " ".join(pattern_input) + pattern = pattern_input + pattern + data = rearrange(data, pattern) + + +def merge_first_two_dimensions(tensor): + dims = tensor.ndim + letters = [] + for letter_idx in range(dims - 2): + letters.append(chr(letter_idx + 67)) + latters_pattern = " ".join(letters) + tensor = rearrange( + tensor, "A B " + latters_pattern + " -> (A B) " + latters_pattern + ) + # TODO merging first two dimensions might be easier with reshape so no need to create letters + # should be 'tensor.view(*tensor.shape[:2], -1)' + return tensor + + +def apply_spatial_function_to_video_tensor(video, shape, func): + # TODO detect batch, frame, channel, width, and height + + assert video.ndim == 5 + batch_dim = shape["batch_dim"] + frame_dim = shape["frame_dim"] + channel_dim = shape["channel_dim"] + width_dim = shape["width_dim"] + height_dim = shape["height_dim"] + + assert batch_dim != frame_dim + assert channel_dim != frame_dim + assert channel_dim != batch_dim + + video_shape = list(video.shape) + batch_pos = video_shape.index(batch_dim) + frame_pos = video_shape.index(frame_dim) + channel_pos = video_shape.index(channel_dim) + w_pos = video_shape.index(width_dim) + h_pos = video_shape.index(height_dim) + if w_pos == h_pos: + video_shape[w_pos] = -1 + h_pos = video_shape.index(height_dim) + pattern_order = {} + pattern_order[batch_pos] = "B" + pattern_order[channel_pos] = "C" + pattern_order[frame_pos] = "F" + pattern_order[w_pos] = "W" + pattern_order[h_pos] = "H" + pattern_order = sorted(pattern_order.items(), key=lambda x: x[1]) + pattern_order = [x[0] for x in pattern_order] + input_pattern = " ".join(pattern_order) + video = rearrange(video, input_pattern + " -> (B F) C W H") + + video = func(video) + video = rearrange(video, "(B F) C W H -> " + input_pattern, F=frame_dim) + return video + + +def dump_frames(videos, as_mosaik, storage_fol, save_image_kwargs): + + # assume videos is in format B F C H W, range [0,1] + num_frames = videos.shape[1] + num_videos = videos.shape[0] + + if videos.shape[2] != 3 and videos.shape[-1] == 3: + videos = rearrange(videos, "B F W H C -> B F C W H") + + frame_counter = 0 + if not isinstance(storage_fol, Path): + storage_fol = Path(storage_fol) + + for frame_idx in range(num_frames): + print(f" Creating frame {frame_idx}") + batch_frame = videos[:, frame_idx, ...] + + if as_mosaik: + filename = storage_fol / f"frame_{frame_counter:03d}.png" + save_image(batch_frame, fp=filename.as_posix(), **save_image_kwargs) + frame_counter += 1 + else: + for video_idx in range(num_videos): + frame = batch_frame[video_idx] + + filename = storage_fol / f"frame_{frame_counter:03d}.png" + save_image(frame, fp=filename.as_posix(), **save_image_kwargs) + frame_counter += 1 + + +def gif_from_videos(videos): + + assert videos.dim() == 5 + assert videos.min() >= 0 + assert videos.max() <= 1 + gif_file = Path("tmp.gif").absolute() + + with tempfile.TemporaryDirectory() as tmpdirname: + storage_fol = Path(tmpdirname) + nrows = min(4, videos.shape[0]) + dump_frames( + videos=videos, + storage_fol=storage_fol, + as_mosaik=True, + save_image_kwargs={"nrow": nrows}, + ) + cmd = f"ffmpeg -y -f image2 -framerate 4 -i {storage_fol / 'frame_%03d.png'} {gif_file.as_posix()}" + subprocess.check_call( + cmd, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT + ) + return gif_file + + +def add_margin(pil_img, top, right, bottom, left, color): + width, height = pil_img.size + new_width = width + right + left + new_height = height + top + bottom + result = Image.new(pil_img.mode, (new_width, new_height), color) + result.paste(pil_img, (left, top)) + return result + + +def resize_to_fit(image, size): + W, H = size + w, h = image.size + if H / h > W / w: + H_ = int(h * W / w) + W_ = W + else: + W_ = int(w * H / h) + H_ = H + return image.resize((W_, H_)) + + +def pad_to_fit(image, size): + W, H = size + w, h = image.size + pad_h = (H - h) // 2 + pad_w = (W - w) // 2 + return add_margin(image, pad_h, pad_w, pad_h, pad_w, (0, 0, 0)) diff --git a/src/videogen_hub/pipelines/streamingt2v/utils/visualisation.py b/src/videogen_hub/pipelines/streamingt2v/utils/visualisation.py new file mode 100644 index 0000000000000000000000000000000000000000..1a749cb955f27a029645ef2c4f2a2f2e5f199317 --- /dev/null +++ b/src/videogen_hub/pipelines/streamingt2v/utils/visualisation.py @@ -0,0 +1,139 @@ +from collections import defaultdict +import torch +from torchvision.utils import make_grid +from torchvision.transforms import ToPILImage +import numpy as np +from PIL import Image, ImageDraw, ImageFont +import matplotlib.pyplot as plt +import matplotlib.patches as mpatches +from matplotlib.colors import Normalize +from matplotlib import cm + +def pil_concat_v(images): + width = images[0].width + height = sum([image.height for image in images]) + dst = Image.new('RGB', (width, height)) + h = 0 + for image_idx, image in enumerate(images): + dst.paste(image, (0, h)) + h += image.height + return dst + +def pil_concat_h(images): + width = sum([image.width for image in images]) + height = images[0].height + dst = Image.new('RGB', (width, height)) + w = 0 + for image_idx, image in enumerate(images): + dst.paste(image, (w, 0)) + w += image.width + return dst + +def add_label(image, text, fontsize=12): + dst = Image.new('RGB', (image.width, image.height + fontsize*3)) + dst.paste(image, (0, 0)) + draw = ImageDraw.Draw(dst) + font = ImageFont.truetype("../misc/fonts/OpenSans.ttf", fontsize) + draw.text((fontsize, image.height + fontsize),text,(255,255,255),font=font) + return dst + +def pil_concat(images, labels=None, col=8, fontsize=12): + col = min(col, len(images)) + if labels is not None: + labeled_images = [add_label(image, labels[image_idx], fontsize=fontsize) for image_idx, image in enumerate(images)] + else: + labeled_images = images + labeled_images_rows = [] + for row_idx in range(int(np.ceil(len(labeled_images) / col))): + labeled_images_rows.append(pil_concat_h(labeled_images[col*row_idx:col*(row_idx+1)])) + return pil_concat_v(labeled_images_rows) + + +def draw_panoptic_segmentation(model, segmentation, segments_info): + # get the used color map + viridis = cm.get_cmap('viridis') + norm = Normalize(vmin=segmentation.min().item(), vmax=segmentation.max().item()) + fig, ax = plt.subplots() + ax.imshow(segmentation, cmap=viridis, norm=norm) + instances_counter = defaultdict(int) + handles = [] + for segment in segments_info: + segment_id = segment['id'] + segment_label_id = segment['label_id'] + segment_label = model.config.id2label[segment_label_id] + label = f"{segment_label}-{instances_counter[segment_label_id]}" + instances_counter[segment_label_id] += 1 + color = viridis(norm(segment_id)) + handles.append(mpatches.Patch(color=color, label=label)) + ax.legend(handles=handles) + + + +rescale_ = lambda x: (x + 1.) / 2. + +def pil_grid_display(x, mask=None, nrow=4, rescale=True): + if rescale: + x = rescale_(x) + if mask is not None: + mask = mask_to_3_channel(mask) + x = torch.concat([mask, x]) + grid = make_grid(torch.clip(x, 0, 1), nrow=nrow) + return ToPILImage()(grid) + +def pil_display(x, rescale=True): + if rescale: + x = rescale_(x) + image = torch.clip(rescale_(x), 0, 1) + return ToPILImage()(image) + +def mask_to_3_channel(mask): + if mask.dim() == 3: + mask_c_idx = 0 + elif mask.dim() == 4: + mask_c_idx = 1 + else: + raise Exception("mask should be a 3d or 4d tensor") + + if mask.shape[mask_c_idx] == 3: + return mask + elif mask.shape[mask_c_idx] == 1: + sizes = [1] * mask.dim() + sizes[mask_c_idx] = 3 + mask = mask.repeat(*sizes) + else: + raise Exception("mask should have size 1 in channel dim") + return mask + + +def get_first_k_token_head_att_maps(atts_normed, k, h, w, output_h=256, output_w=256, labels=None, max_scale=False): + n_heads = atts_normed.shape[0] + att_images = [] + for head_idx in range(n_heads): + atts_head = atts_normed[head_idx, :, :k].reshape(h, w, k).movedim(2, 0) + for token_idx in range(k): + att_head_np = atts_head[token_idx].detach().cpu().numpy() + if max_scale: + att_head_np = att_head_np / att_head_np.max() + att_image = Image.fromarray((att_head_np * 255).astype(np.uint8)) + att_image = att_image.resize((output_h, output_w), Image.Resampling.NEAREST) + att_images.append(att_image) + return pil_concat(att_images, col=k, labels=None) + +def get_first_k_token_att_maps(atts_normed, k, h, w, output_h=256, output_w=256, labels=None, max_scale=False): + att_images = [] + atts_head = atts_normed.mean(0)[:, :k].reshape(h, w, k).movedim(2, 0) + for token_idx in range(k): + att_head_np = atts_head[token_idx].detach().cpu().numpy() + if max_scale: + att_head_np = att_head_np / att_head_np.max() + att_image = Image.fromarray((att_head_np * 255).astype(np.uint8)) + att_image = att_image.resize((output_h, output_w), Image.Resampling.NEAREST) + att_images.append(att_image) + return pil_concat(att_images, col=k, labels=None) + +def draw_bbox(image, bbox): + image = image.copy() + left, top, right, bottom = bbox + image_draw = ImageDraw.Draw(image) + image_draw.rectangle(((left, top),(right, bottom)), outline='Red') + return image \ No newline at end of file diff --git a/src/videogen_hub/pipelines/t2v_turbo/README.md b/src/videogen_hub/pipelines/t2v_turbo/README.md new file mode 100644 index 0000000000000000000000000000000000000000..3b6b3428ac12d0393e4f68d0598806be23e4c0f6 --- /dev/null +++ b/src/videogen_hub/pipelines/t2v_turbo/README.md @@ -0,0 +1,101 @@ +# T2V-Turbo: Breaking the Quality Bottleneck of Video Consistency Model with Mixed Reward Feedback + +## Fast and High-Quality Text-to-video Generation 🚀 + +### 4-Step Results + + + + + + + + + +
With the style of low-poly game art, A majestic, white horse gallops gracefully across a moonlit beach.medium shot of Christine, a beautiful 25-year-old brunette resembling Selena Gomez, anxiously looking up as she walks down a New York street, cinematic stylea cartoon pig playing his guitar, Andrew Warhol style
+ + + + + + + + + + + +
a dog wearing vr goggles on a boatPikachu snowboardinga girl floating underwater
+ +### 8-Step Results + + + + + + + + + + +
Mickey Mouse is dancing on white backgroundlight wind, feathers moving, she moves her gaze, 4kfashion portrait shoot of a girl in colorful glasses, a breeze moves her hair
+ + + + + + + + + + + +
With the style of abstract cubism, The flowers swayed in the gentle breeze, releasing their sweet fragrance.impressionist style, a yellow rubber duck floating on the wave on the sunsetA Egyptian tomp hieroglyphics painting ofA regal lion, decked out in a jeweled crown, surveys his kingdom.
+ +## 🏭 Installation + +``` +pip install accelerate transformers diffusers webdataset loralib peft pytorch_lightning open_clip_torch hpsv2 peft wandb av einops packaging omegaconf opencv-python kornia + +pip install flash-attn --no-build-isolation +git clone https://github.com/Dao-AILab/flash-attention.git +cd flash-attention +pip install csrc/fused_dense_lib csrc/layer_norm + +pip install git+https://github.com/iejMac/video2dataset.git + +conda install xformers +``` +## 🛞 Model Checkpoints + +|Model|Resolution|Checkpoints| +|:---------|:---------|:--------| +|T2V-Turbo (VC2)|320x512|[![HuggingFace](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue)](https://huggingface.co/jiachenli-ucsb/T2V-Turbo-VC2/blob/main/unet_lora.pt) +|T2V-Turbo (MS)|256x256|[![HuggingFace](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue)](https://huggingface.co/jiachenli-ucsb/T2V-Turbo-MS/blob/main/unet_lora.pt) + + +## 🚀 Inference + +We provide local demo codes supported with gradio (For MacOS users, need to set the device="mps" in app.py; For Intel GPU users, set device="xpu" in app.py). +1. Download the `unet_lora.pt` of our T2V-Turbo (VC2) [here](https://huggingface.co/jiachenli-ucsb/T2V-Turbo-VC2/blob/main/unet_lora.pt). + +2. Download the model checkpoint of VideoCrafter2 [here](https://huggingface.co/VideoCrafter/VideoCrafter2/blob/main/model.ckpt). + +3. Launch the gradio demo with the following command: +``` +pip install gradio==3.48.0 +python app.py --unet_dir PATH_TO_UNET_LORA.pt --base_model_dir PATH_TO_VideoCrafter2_MODEL_CKPT +``` + +## 🏋️ Training + +To train T2V-Turbo (VC2), run the following command + +``` +bash train_t2v_turbo_vc2.sh +``` + +To train T2V-Turbo (MS), run the following command + +``` +bash train_t2v_turbo_ms.sh +``` diff --git a/src/videogen_hub/pipelines/t2v_turbo/__init__.py b/src/videogen_hub/pipelines/t2v_turbo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8707c90510ceb989df34fb83e96795541d764fd9 --- /dev/null +++ b/src/videogen_hub/pipelines/t2v_turbo/__init__.py @@ -0,0 +1,4 @@ +import sys +sys.path.insert(0, './src/videogen_hub/pipelines/') +sys.path.insert(0, './src/videogen_hub/pipelines/t2v_turbo/') +sys.path.insert(0, './src/videogen_hub/pipelines/t2v_turbo/lvdm/') \ No newline at end of file diff --git a/src/videogen_hub/pipelines/t2v_turbo/inference_ms.py b/src/videogen_hub/pipelines/t2v_turbo/inference_ms.py new file mode 100644 index 0000000000000000000000000000000000000000..82f624a37e3bc7016f713077faf108b0b327f48f --- /dev/null +++ b/src/videogen_hub/pipelines/t2v_turbo/inference_ms.py @@ -0,0 +1,149 @@ +# Adapted from https://github.com/luosiallen/latent-consistency-model +from __future__ import annotations + +import os +import random + +import numpy as np + +from .pipeline.t2v_turbo_ms_pipeline import T2VTurboMSPipeline +from .scheduler.t2v_turbo_scheduler import T2VTurboScheduler +from .utils.common_utils import set_torch_2_attn + +try: + import intel_extension_for_pytorch as ipex +except: + pass + +from transformers import CLIPTokenizer, CLIPTextModel +from .model_scope.unet_3d_condition import UNet3DConditionModel + +from .utils.lora import collapse_lora, monkeypatch_remove_lora +from .utils.lora_handler import LoraHandler + +import torch +from diffusers.models import AutoencoderKL + +DESCRIPTION = """# T2V-Turbo 🚀 +We provide T2V-Turbo (MS) distilled from [ModelScopeT2V](https://huggingface.co/ali-vilab/text-to-video-ms-1.7b/) with the reward feedback from [HPSv2.1](https://github.com/tgxs002/HPSv2/tree/master) and [ViCLIP](https://huggingface.co/OpenGVLab/ViCLIP). + +You can download the the models from [here](https://huggingface.co/jiachenli-ucsb/T2V-Turbo-MS). Check out our [Project page](https://t2v-turbo.github.io) 😄 +""" +if torch.cuda.is_available(): + DESCRIPTION += "\n

Running on CUDA 😀

" +elif hasattr(torch, "xpu") and torch.xpu.is_available(): + DESCRIPTION += "\n

Running on XPU 🤓

" +else: + DESCRIPTION += "\n

Running on CPU 🥶 This demo does not work on CPU.

" + +MAX_SEED = np.iinfo(np.int32).max +CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1" +USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1" + +""" +Operation System Options: + If you are using MacOS, please set the following (device="mps") ; + If you are using Linux & Windows with Nvidia GPU, please set the device="cuda"; + If you are using Linux & Windows with Intel Arc GPU, please set the device="xpu"; +""" +# device = "mps" # MacOS +# device = "xpu" # Intel Arc GPU +device = "cuda" # Linux & Windows + +""" + DTYPE Options: + To reduce GPU memory you can set "DTYPE=torch.float16", + but image quality might be compromised +""" +DTYPE = ( + torch.float16 +) # torch.float16 works as well, but pictures seem to be a bit worse + + +def randomize_seed_fn(seed: int, randomize_seed: bool) -> int: + if randomize_seed: + seed = random.randint(0, MAX_SEED) + return seed + + +class T2VTurboMSPipeline1: + def __init__(self, device, unet_dir, base_model_dir): + pretrained_model_path = base_model_dir + tokenizer = CLIPTokenizer.from_pretrained( + pretrained_model_path, subfolder="tokenizer" + ) + text_encoder = CLIPTextModel.from_pretrained( + pretrained_model_path, subfolder="text_encoder" + ) + vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae") + teacher_unet = UNet3DConditionModel.from_pretrained( + pretrained_model_path, subfolder="unet" + ) + + time_cond_proj_dim = 256 + unet = UNet3DConditionModel.from_config( + teacher_unet.config, time_cond_proj_dim=time_cond_proj_dim + ) + # load teacher_unet weights into unet + unet.load_state_dict(teacher_unet.state_dict(), strict=False) + del teacher_unet + set_torch_2_attn(unet) + use_unet_lora = True + lora_manager = LoraHandler( + version="cloneofsimo", + use_unet_lora=use_unet_lora, + save_for_webui=True, + ) + lora_manager.add_lora_to_model( + use_unet_lora, + unet, + lora_manager.unet_replace_modules, + lora_path=unet_dir, + dropout=0.1, + r=32, + ) + collapse_lora(unet, lora_manager.unet_replace_modules) + monkeypatch_remove_lora(unet) + unet.eval() + + noise_scheduler = T2VTurboScheduler() + self.pipeline = T2VTurboMSPipeline( + unet=unet, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=noise_scheduler, + ) + self.pipeline.to(device) + + def inference( + self, + prompt: str, + height: int = 320, + width: int = 512, + seed: int = 0, + guidance_scale: float = 7.5, + num_inference_steps: int = 4, + num_frames: int = 16, + fps: int = 16, + randomize_seed: bool = False, + param_dtype="torch.float16" + ): + seed = randomize_seed_fn(seed, randomize_seed) + torch.manual_seed(seed) + self.pipeline.to( + torch_device=device, + torch_dtype=torch.float16 if param_dtype == "torch.float16" else torch.float32, + ) + + result = self.pipeline( + prompt=prompt, + height=height, + width=width, + frames=num_frames, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + num_videos_per_prompt=1, + ) + + return result diff --git a/src/videogen_hub/pipelines/t2v_turbo/inference_vc2.py b/src/videogen_hub/pipelines/t2v_turbo/inference_vc2.py new file mode 100644 index 0000000000000000000000000000000000000000..37ea250d17e6e7526010c9f1fdb6ee9ed0668c18 --- /dev/null +++ b/src/videogen_hub/pipelines/t2v_turbo/inference_vc2.py @@ -0,0 +1,146 @@ +# Adapted from https://github.com/luosiallen/latent-consistency-model +from __future__ import annotations + +import os +import random +from omegaconf import OmegaConf + +import numpy as np + +try: + import intel_extension_for_pytorch as ipex +except: + pass + +from .utils.lora import collapse_lora, monkeypatch_remove_lora +from .utils.lora_handler import LoraHandler +from .utils.common_utils import load_model_checkpoint +from .utils.utils import instantiate_from_config +from .scheduler.t2v_turbo_scheduler import T2VTurboScheduler +from .pipeline.t2v_turbo_vc2_pipeline import T2VTurboVC2Pipeline + +import torch + +DESCRIPTION = """# T2V-Turbo 🚀 +We provide T2V-Turbo (VC2) distilled from [VideoCrafter2](https://ailab-cvc.github.io/videocrafter2/) with the reward feedback from [HPSv2.1](https://github.com/tgxs002/HPSv2/tree/master) and [InternVid2 Stage 2 Model](https://huggingface.co/OpenGVLab/InternVideo2-Stage2_1B-224p-f4). + +You can download the the models from [here](https://huggingface.co/jiachenli-ucsb/T2V-Turbo-VC2). Check out our [Project page](https://t2v-turbo.github.io) 😄 +""" +if torch.cuda.is_available(): + DESCRIPTION += "\n

Running on CUDA 😀

" +elif hasattr(torch, "xpu") and torch.xpu.is_available(): + DESCRIPTION += "\n

Running on XPU 🤓

" +else: + DESCRIPTION += "\n

Running on CPU 🥶 This demo does not work on CPU.

" + +MAX_SEED = np.iinfo(np.int32).max +CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1" +USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1" + +""" +Operation System Options: + If you are using MacOS, please set the following (device="mps") ; + If you are using Linux & Windows with Nvidia GPU, please set the device="cuda"; + If you are using Linux & Windows with Intel Arc GPU, please set the device="xpu"; +""" +# device = "mps" # MacOS +# device = "xpu" # Intel Arc GPU +device = "cuda" # Linux & Windows + +""" + DTYPE Options: + To reduce GPU memory you can set "DTYPE=torch.float16", + but image quality might be compromised +""" +DTYPE = ( + torch.float16 +) # torch.float16 works as well, but pictures seem to be a bit worse + + +def randomize_seed_fn(seed: int, randomize_seed: bool) -> int: + if randomize_seed: + seed = random.randint(0, MAX_SEED) + return seed + + +class T2VTurboVC2Pipeline1: + def __init__(self, config, merged, device, unet_dir, base_model_dir): + config = OmegaConf.create(config) + model_config = config.pop("model", OmegaConf.create()) + pretrained_t2v = instantiate_from_config(model_config) + + unet_config = model_config["params"]["unet_config"] + unet_config["params"]["time_cond_proj_dim"] = 256 + unet = instantiate_from_config(unet_config) + + if merged: + pretrained_t2v.model.diffusion_model = unet + pretrained_t2v = load_model_checkpoint(pretrained_t2v, base_model_dir) + + else: + pretrained_t2v = load_model_checkpoint(pretrained_t2v, base_model_dir) + + unet.load_state_dict( + pretrained_t2v.model.diffusion_model.state_dict(), strict=False + ) + + use_unet_lora = True + lora_manager = LoraHandler( + version="cloneofsimo", + use_unet_lora=use_unet_lora, + save_for_webui=True, + unet_replace_modules=["UNetModel"], + ) + lora_manager.add_lora_to_model( + use_unet_lora, + unet, + lora_manager.unet_replace_modules, + lora_path=unet_dir, + dropout=0.1, + r=64, + ) + unet.eval() + collapse_lora(unet, lora_manager.unet_replace_modules) + monkeypatch_remove_lora(unet) + + pretrained_t2v.model.diffusion_model = unet + + scheduler = T2VTurboScheduler( + linear_start=model_config["params"]["linear_start"], + linear_end=model_config["params"]["linear_end"], + ) + self.pipeline = T2VTurboVC2Pipeline(pretrained_t2v, scheduler, model_config) + self.pipeline.to(device) + + def inference( + self, + prompt: str, + height: int = 320, + width: int = 512, + seed: int = 0, + guidance_scale: float = 7.5, + num_inference_steps: int = 4, + num_frames: int = 16, + fps: int = 16, + randomize_seed: bool = False, + param_dtype="torch.float16" + ): + seed = randomize_seed_fn(seed, randomize_seed) + torch.manual_seed(seed) + self.pipeline.to( + torch_device=device, + torch_dtype=torch.float16 if param_dtype == "torch.float16" else torch.float32, + ) + + result = self.pipeline( + prompt=prompt, + height=height, + width=width, + frames=num_frames, + fps=fps, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + num_videos_per_prompt=1, + ) + + return result diff --git a/src/videogen_hub/pipelines/t2v_turbo/lvdm/__init__.py b/src/videogen_hub/pipelines/t2v_turbo/lvdm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/t2v_turbo/lvdm/basics.py b/src/videogen_hub/pipelines/t2v_turbo/lvdm/basics.py new file mode 100644 index 0000000000000000000000000000000000000000..bedcc326b019de615e883d0ce6635c9bce378b66 --- /dev/null +++ b/src/videogen_hub/pipelines/t2v_turbo/lvdm/basics.py @@ -0,0 +1,102 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + +import torch.nn as nn +from videogen_hub.pipelines.t2v_turbo.utils.utils import instantiate_from_config + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def nonlinearity(type="silu"): + if type == "silu": + return nn.SiLU() + elif type == "leaky_relu": + return nn.LeakyReLU() + + +class GroupNormSpecific(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def normalization(channels, num_groups=32): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNormSpecific(num_groups, channels) + + +class HybridConditioner(nn.Module): + + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]} diff --git a/src/videogen_hub/pipelines/t2v_turbo/lvdm/common.py b/src/videogen_hub/pipelines/t2v_turbo/lvdm/common.py new file mode 100644 index 0000000000000000000000000000000000000000..88f0db0d8b668ed5d189c3774e76f28ca0cb83b3 --- /dev/null +++ b/src/videogen_hub/pipelines/t2v_turbo/lvdm/common.py @@ -0,0 +1,112 @@ +import math +from inspect import isfunction +import torch +from torch import nn +import torch.distributed as dist + + +def gather_data(data, return_np=True): + """gather data from multiple processes to one list""" + data_list = [torch.zeros_like(data) for _ in range(dist.get_world_size())] + dist.all_gather(data_list, data) # gather not supported with NCCL + if return_np: + data_list = [data.cpu().numpy() for data in data_list] + return data_list + + +def autocast(f): + def do_autocast(*args, **kwargs): + with torch.cuda.amp.autocast( + enabled=True, + dtype=torch.get_autocast_gpu_dtype(), + cache_enabled=torch.is_autocast_cache_enabled(), + ): + return f(*args, **kwargs) + + return do_autocast + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat( + shape[0], *((1,) * (len(shape) - 1)) + ) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def exists(val): + return val is not None + + +def identity(*args, **kwargs): + return nn.Identity() + + +def uniq(arr): + return {el: True for el in arr}.keys() + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def shape_to_str(x): + shape_str = "x".join([str(x) for x in x.shape]) + return shape_str + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +ckpt = torch.utils.checkpoint.checkpoint + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + return ckpt(func, *inputs) + else: + return func(*inputs) diff --git a/src/videogen_hub/pipelines/t2v_turbo/lvdm/distributions.py b/src/videogen_hub/pipelines/t2v_turbo/lvdm/distributions.py new file mode 100644 index 0000000000000000000000000000000000000000..caf57f71f2bd0f51f5515abe197ced6430f31177 --- /dev/null +++ b/src/videogen_hub/pipelines/t2v_turbo/lvdm/distributions.py @@ -0,0 +1,103 @@ +import torch +import numpy as np + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to( + device=self.parameters.device + ) + + def sample(self, noise=None): + if noise is None: + noise = torch.randn(self.mean.shape) + + x = self.mean + self.std * noise.to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3], + ) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/src/videogen_hub/pipelines/t2v_turbo/lvdm/ema.py b/src/videogen_hub/pipelines/t2v_turbo/lvdm/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..0e1447b06b710151e769fc820049db54fe132510 --- /dev/null +++ b/src/videogen_hub/pipelines/t2v_turbo/lvdm/ema.py @@ -0,0 +1,84 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + + self.m_name2s_name = {} + self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) + self.register_buffer( + "num_updates", + ( + torch.tensor(0, dtype=torch.int) + if use_num_upates + else torch.tensor(-1, dtype=torch.int) + ), + ) + + for name, p in model.named_parameters(): + if p.requires_grad: + # remove as '.'-character is not allowed in buffers + s_name = name.replace(".", "") + self.m_name2s_name.update({name: s_name}) + self.register_buffer(s_name, p.clone().detach().data) + + self.collected_params = [] + + def forward(self, model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_( + one_minus_decay * (shadow_params[sname] - m_param[key]) + ) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/src/videogen_hub/pipelines/t2v_turbo/lvdm/models/__init__.py b/src/videogen_hub/pipelines/t2v_turbo/lvdm/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/t2v_turbo/lvdm/models/autoencoder.py b/src/videogen_hub/pipelines/t2v_turbo/lvdm/models/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..31b3ff2ffd71a1009cd7bdbce1e32ddddb74cd40 --- /dev/null +++ b/src/videogen_hub/pipelines/t2v_turbo/lvdm/models/autoencoder.py @@ -0,0 +1,276 @@ +import os +from contextlib import contextmanager +import torch +import numpy as np +from einops import rearrange +import torch.nn.functional as F +import pytorch_lightning as pl +from videogen_hub.pipelines.t2v_turbo.lvdm.modules.networks.ae_modules import Encoder, Decoder +from videogen_hub.pipelines.t2v_turbo.lvdm.distributions import DiagonalGaussianDistribution +from videogen_hub.pipelines.t2v_turbo.utils.utils import instantiate_from_config + + +class AutoencoderKL(pl.LightningModule): + def __init__( + self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + test=False, + logdir=None, + input_dim=4, + test_args=None, + ): + super().__init__() + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + self.input_dim = input_dim + self.test = test + self.test_args = test_args + self.logdir = logdir + if colorize_nlabels is not None: + assert type(colorize_nlabels) == int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + if self.test: + self.init_test() + + def init_test( + self, + ): + self.test = True + save_dir = os.path.join(self.logdir, "test") + if "ckpt" in self.test_args: + ckpt_name = ( + os.path.basename(self.test_args.ckpt).split(".ckpt")[0] + + f"_epoch{self._cur_epoch}" + ) + self.root = os.path.join(save_dir, ckpt_name) + else: + self.root = save_dir + if "test_subdir" in self.test_args: + self.root = os.path.join(save_dir, self.test_args.test_subdir) + + self.root_zs = os.path.join(self.root, "zs") + self.root_dec = os.path.join(self.root, "reconstructions") + self.root_inputs = os.path.join(self.root, "inputs") + os.makedirs(self.root, exist_ok=True) + + if self.test_args.save_z: + os.makedirs(self.root_zs, exist_ok=True) + if self.test_args.save_reconstruction: + os.makedirs(self.root_dec, exist_ok=True) + if self.test_args.save_input: + os.makedirs(self.root_inputs, exist_ok=True) + assert self.test_args is not None + self.test_maximum = getattr(self.test_args, "test_maximum", None) + self.count = 0 + self.eval_metrics = {} + self.decodes = [] + self.save_decode_samples = 2048 + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu") + try: + self._cur_epoch = sd["epoch"] + sd = sd["state_dict"] + except: + self._cur_epoch = "null" + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + # self.load_state_dict(sd, strict=True) + print(f"Restored from {path}") + + def encode(self, x, **kwargs): + + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z, **kwargs): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if x.dim() == 5 and self.input_dim == 4: + b, c, t, h, w = x.shape + self.b = b + self.t = t + x = rearrange(x, "b c t h w -> (b t) c h w") + + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + + if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss( + inputs, + reconstructions, + posterior, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split="train", + ) + self.log( + "aeloss", + aeloss, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True, + ) + self.log_dict( + log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False + ) + return aeloss + + if optimizer_idx == 1: + # train the discriminator + discloss, log_dict_disc = self.loss( + inputs, + reconstructions, + posterior, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split="train", + ) + + self.log( + "discloss", + discloss, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True, + ) + self.log_dict( + log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False + ) + return discloss + + def validation_step(self, batch, batch_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss( + inputs, + reconstructions, + posterior, + 0, + self.global_step, + last_layer=self.get_last_layer(), + split="val", + ) + + discloss, log_dict_disc = self.loss( + inputs, + reconstructions, + posterior, + 1, + self.global_step, + last_layer=self.get_last_layer(), + split="val", + ) + + self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + opt_ae = torch.optim.Adam( + list(self.encoder.parameters()) + + list(self.decoder.parameters()) + + list(self.quant_conv.parameters()) + + list(self.post_quant_conv.parameters()), + lr=lr, + betas=(0.5, 0.9), + ) + opt_disc = torch.optim.Adam( + self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9) + ) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 + return x + + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x diff --git a/src/videogen_hub/pipelines/t2v_turbo/lvdm/models/ddpm3d.py b/src/videogen_hub/pipelines/t2v_turbo/lvdm/models/ddpm3d.py new file mode 100644 index 0000000000000000000000000000000000000000..52bfccdfb8b956c0c83e4d25210256cd6dff81f1 --- /dev/null +++ b/src/videogen_hub/pipelines/t2v_turbo/lvdm/models/ddpm3d.py @@ -0,0 +1,967 @@ +""" +wild mixture of +https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +https://github.com/CompVis/taming-transformers +-- merci +""" + +from functools import partial +from contextlib import contextmanager +import numpy as np +from tqdm import tqdm +from einops import rearrange, repeat +import logging + +mainlogger = logging.getLogger("mainlogger") +import torch +import torch.nn as nn +from torchvision.utils import make_grid +import pytorch_lightning as pl +from videogen_hub.pipelines.t2v_turbo.utils.utils import instantiate_from_config +from videogen_hub.pipelines.t2v_turbo.lvdm.ema import LitEma +from videogen_hub.pipelines.t2v_turbo.lvdm.distributions import DiagonalGaussianDistribution +from videogen_hub.pipelines.t2v_turbo.lvdm.models.utils_diffusion import make_beta_schedule +from videogen_hub.pipelines.t2v_turbo.lvdm.modules.encoders.ip_resampler import ImageProjModel, Resampler +from videogen_hub.pipelines.t2v_turbo.lvdm.basics import disabled_train +from videogen_hub.pipelines.t2v_turbo.lvdm.common import extract_into_tensor, noise_like, exists, default + + +__conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"} + + +class DDPM(pl.LightningModule): + # classic DDPM with Gaussian diffusion, in image space + def __init__( + self, + unet_config, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor=None, + use_ema=True, + first_stage_key="image", + image_size=256, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0.0, + v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1.0, + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0.0, + ): + super().__init__() + assert parameterization in [ + "eps", + "x0", + ], 'currently only supporting "eps" and "x0"' + self.parameterization = parameterization + mainlogger.info( + f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode" + ) + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + self.channels = channels + self.temporal_length = unet_config.params.temporal_length + self.image_size = image_size + if isinstance(self.image_size, int): + self.image_size = [self.image_size, self.image_size] + self.use_positional_encodings = use_positional_encodings + self.model = DiffusionWrapper(unet_config, conditioning_key) + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model) + mainlogger.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.use_scheduler = scheduler_config is not None + if self.use_scheduler: + self.scheduler_config = scheduler_config + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt( + ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet + ) + + self.register_schedule( + given_betas=given_betas, + beta_schedule=beta_schedule, + timesteps=timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) + + self.loss_type = loss_type + + self.learn_logvar = learn_logvar + self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + + def register_schedule( + self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + ): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule( + beta_schedule, + timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) + alphas = 1.0 - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) + + (timesteps,) = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert ( + alphas_cumprod.shape[0] == self.num_timesteps + ), "alphas have to be defined for each timestep" + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer("betas", to_torch(betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)) + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1)) + ) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * ( + 1.0 - alphas_cumprod_prev + ) / (1.0 - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer("posterior_variance", to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer( + "posterior_log_variance_clipped", + to_torch(np.log(np.maximum(posterior_variance, 1e-20))), + ) + self.register_buffer( + "posterior_mean_coef1", + to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)), + ) + self.register_buffer( + "posterior_mean_coef2", + to_torch( + (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod) + ), + ) + + if self.parameterization == "eps": + lvlb_weights = self.betas**2 / ( + 2 + * self.posterior_variance + * to_torch(alphas) + * (1 - self.alphas_cumprod) + ) + elif self.parameterization == "x0": + lvlb_weights = ( + 0.5 + * np.sqrt(torch.Tensor(alphas_cumprod)) + / (2.0 * 1 - torch.Tensor(alphas_cumprod)) + ) + else: + raise NotImplementedError("mu not supported") + # TODO how to choose this term + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer("lvlb_weights", lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + mainlogger.info(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + mainlogger.info(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + mainlogger.info("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = ( + self.load_state_dict(sd, strict=False) + if not only_model + else self.model.load_state_dict(sd, strict=False) + ) + mainlogger.info( + f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" + ) + if len(missing) > 0: + mainlogger.info(f"Missing Keys: {missing}") + if len(unexpected) > 0: + mainlogger.info(f"Unexpected Keys: {unexpected}") + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract_into_tensor( + self.log_one_minus_alphas_cumprod, t, x_start.shape + ) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + model_out = self.model(x, t) + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + if clip_denoised: + x_recon.clamp_(-1.0, 1.0) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior( + x_start=x_recon, x_t=x, t=t + ) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance( + x=x, t=t, clip_denoised=clip_denoised + ) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, shape, return_intermediates=False): + device = self.betas.device + b = shape[0] + img = torch.randn(shape, device=device) + intermediates = [img] + for i in tqdm( + reversed(range(0, self.num_timesteps)), + desc="Sampling t", + total=self.num_timesteps, + ): + img = self.p_sample( + img, + torch.full((b,), i, device=device, dtype=torch.long), + clip_denoised=self.clip_denoised, + ) + if i % self.log_every_t == 0 or i == self.num_timesteps - 1: + intermediates.append(img) + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, batch_size=16, return_intermediates=False): + image_size = self.image_size + channels = self.channels + return self.p_sample_loop( + (batch_size, channels, image_size, image_size), + return_intermediates=return_intermediates, + ) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) + * x_start + * extract_into_tensor(self.scale_arr, t, x_start.shape) + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) + * noise + ) + + def get_input(self, batch, k): + x = batch[k] + x = x.to(memory_format=torch.contiguous_format).float() + return x + + def _get_rows_from_list(self, samples): + n_imgs_per_row = len(samples) + denoise_grid = rearrange(samples, "n b c h w -> b n c h w") + denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w") + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): + log = dict() + x = self.get_input(batch, self.first_stage_key) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + x = x.to(self.device)[:N] + log["inputs"] = x + + # get diffusion row + diffusion_row = list() + x_start = x[:n_row] + + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), "1 -> b", b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(x_start) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + diffusion_row.append(x_noisy) + + log["diffusion_row"] = self._get_rows_from_list(diffusion_row) + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, denoise_row = self.sample( + batch_size=N, return_intermediates=True + ) + + log["samples"] = samples + log["denoise_row"] = self._get_rows_from_list(denoise_row) + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + +class LatentDiffusion(DDPM): + """main class""" + + def __init__( + self, + first_stage_config, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_key="caption", + cond_stage_trainable=False, + cond_stage_forward=None, + conditioning_key=None, + uncond_prob=0.2, + uncond_type="empty_seq", + scale_factor=1.0, + scale_by_std=False, + encoder_type="2d", + only_model=False, + use_scale=False, + scale_a=1, + scale_b=0.3, + mid_step=400, + fix_scale_bug=False, + *args, + **kwargs, + ): + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs["timesteps"] + # for backwards compatibility after implementation of DiffusionWrapper + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", []) + conditioning_key = default(conditioning_key, "crossattn") + super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + + # scale factor + self.use_scale = use_scale + if self.use_scale: + self.scale_a = scale_a + self.scale_b = scale_b + if fix_scale_bug: + scale_step = self.num_timesteps - mid_step + else: # bug + scale_step = self.num_timesteps + + scale_arr1 = np.linspace(scale_a, scale_b, mid_step) + scale_arr2 = np.full(scale_step, scale_b) + scale_arr = np.concatenate((scale_arr1, scale_arr2)) + scale_arr_prev = np.append(scale_a, scale_arr[:-1]) + to_torch = partial(torch.tensor, dtype=torch.float32) + self.register_buffer("scale_arr", to_torch(scale_arr)) + + try: + self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 + except: + self.num_downs = 0 + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer("scale_factor", torch.tensor(scale_factor)) + self.instantiate_first_stage(first_stage_config) + self.instantiate_cond_stage(cond_stage_config) + self.first_stage_config = first_stage_config + self.cond_stage_config = cond_stage_config + self.clip_denoised = False + + self.cond_stage_forward = cond_stage_forward + self.encoder_type = encoder_type + assert encoder_type in ["2d", "3d"] + self.uncond_prob = uncond_prob + self.classifier_free_guidance = True if uncond_prob > 0 else False + assert uncond_type in ["zero_embed", "empty_seq"] + self.uncond_type = uncond_type + + self.restarted_from_ckpt = False + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys, only_model=only_model) + self.restarted_from_ckpt = True + + def make_cond_schedule( + self, + ): + self.cond_ids = torch.full( + size=(self.num_timesteps,), + fill_value=self.num_timesteps - 1, + dtype=torch.long, + ) + ids = torch.round( + torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond) + ).long() + self.cond_ids[: self.num_timesteps_cond] = ids + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + if self.use_scale: + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) + * x_start + * extract_into_tensor(self.scale_arr, t, x_start.shape) + + extract_into_tensor( + self.sqrt_one_minus_alphas_cumprod, t, x_start.shape + ) + * noise + ) + else: + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) + * x_start + + extract_into_tensor( + self.sqrt_one_minus_alphas_cumprod, t, x_start.shape + ) + * noise + ) + + def _freeze_model(self): + for name, para in self.model.diffusion_model.named_parameters(): + para.requires_grad = False + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def instantiate_cond_stage(self, config): + if not self.cond_stage_trainable: + model = instantiate_from_config(config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + else: + model = instantiate_from_config(config) + self.cond_stage_model = model + + def get_learned_conditioning(self, c): + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, "encode") and callable( + self.cond_stage_model.encode + ): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + c = self.cond_stage_model(c) + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + return c + + def get_first_stage_encoding(self, encoder_posterior, noise=None): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample(noise=noise) + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError( + f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented" + ) + return self.scale_factor * z + + @torch.no_grad() + def encode_first_stage(self, x): + if self.encoder_type == "2d" and x.dim() == 5: + b, _, t, _, _ = x.shape + x = rearrange(x, "b c t h w -> (b t) c h w") + reshape_back = True + else: + reshape_back = False + + encoder_posterior = self.first_stage_model.encode(x) + results = self.get_first_stage_encoding(encoder_posterior).detach() + + if reshape_back: + results = rearrange(results, "(b t) c h w -> b c t h w", b=b, t=t) + + return results + + @torch.no_grad() + def encode_first_stage_2DAE(self, x): + + b, _, t, _, _ = x.shape + results = torch.cat( + [ + self.get_first_stage_encoding(self.first_stage_model.encode(x[:, :, i])) + .detach() + .unsqueeze(2) + for i in range(t) + ], + dim=2, + ) + + return results + + def decode_core(self, z, **kwargs): + if self.encoder_type == "2d" and z.dim() == 5: + b, _, t, _, _ = z.shape + z = rearrange(z, "b c t h w -> (b t) c h w") + reshape_back = True + else: + reshape_back = False + + z = 1.0 / self.scale_factor * z + + results = self.first_stage_model.decode(z, **kwargs) + + if reshape_back: + results = rearrange(results, "(b t) c h w -> b c t h w", b=b, t=t) + return results + + @torch.no_grad() + def decode_first_stage(self, z, **kwargs): + return self.decode_core(z, **kwargs) + + def apply_model(self, x_noisy, t, cond, **kwargs): + if isinstance(cond, dict): + # hybrid case, cond is exptected to be a dict + pass + else: + if not isinstance(cond, list): + cond = [cond] + key = ( + "c_concat" if self.model.conditioning_key == "concat" else "c_crossattn" + ) + cond = {key: cond} + + x_recon = self.model(x_noisy, t, **cond, **kwargs) + + if isinstance(x_recon, tuple): + return x_recon[0] + else: + return x_recon + + def _get_denoise_row_from_list(self, samples, desc=""): + denoise_row = [] + for zd in tqdm(samples, desc=desc): + denoise_row.append(self.decode_first_stage(zd.to(self.device))) + n_log_timesteps = len(denoise_row) + + denoise_row = torch.stack(denoise_row) # n_log_timesteps, b, C, H, W + + if denoise_row.dim() == 5: + # img, num_imgs= n_log_timesteps * bs, grid_size=[bs,n_log_timesteps] + denoise_grid = rearrange(denoise_row, "n b c h w -> b n c h w") + denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w") + denoise_grid = make_grid(denoise_grid, nrow=n_log_timesteps) + elif denoise_row.dim() == 6: + # video, grid_size=[n_log_timesteps*bs, t] + video_length = denoise_row.shape[3] + denoise_grid = rearrange(denoise_row, "n b c t h w -> b n c t h w") + denoise_grid = rearrange(denoise_grid, "b n c t h w -> (b n) c t h w") + denoise_grid = rearrange(denoise_grid, "n c t h w -> (n t) c h w") + denoise_grid = make_grid(denoise_grid, nrow=video_length) + else: + raise ValueError + + return denoise_grid + + @torch.no_grad() + def decode_first_stage_2DAE(self, z, **kwargs): + + b, _, t, _, _ = z.shape + z = 1.0 / self.scale_factor * z + results = torch.cat( + [ + self.first_stage_model.decode(z[:, :, i], **kwargs).unsqueeze(2) + for i in range(t) + ], + dim=2, + ) + + return results + + def p_mean_variance( + self, + x, + c, + t, + clip_denoised: bool, + return_x0=False, + score_corrector=None, + corrector_kwargs=None, + **kwargs, + ): + t_in = t + model_out = self.apply_model(x, t_in, c, **kwargs) + + if score_corrector is not None: + assert self.parameterization == "eps" + model_out = score_corrector.modify_score( + self, model_out, x, t, c, **corrector_kwargs + ) + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + else: + raise NotImplementedError() + + if clip_denoised: + x_recon.clamp_(-1.0, 1.0) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior( + x_start=x_recon, x_t=x, t=t + ) + + if return_x0: + return model_mean, posterior_variance, posterior_log_variance, x_recon + else: + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample( + self, + x, + c, + t, + clip_denoised=False, + repeat_noise=False, + return_x0=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + **kwargs, + ): + b, *_, device = *x.shape, x.device + outputs = self.p_mean_variance( + x=x, + c=c, + t=t, + clip_denoised=clip_denoised, + return_x0=return_x0, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + **kwargs, + ) + if return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.0: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + + if return_x0: + return ( + model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, + x0, + ) + else: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop( + self, + cond, + shape, + return_intermediates=False, + x_T=None, + verbose=True, + callback=None, + timesteps=None, + mask=None, + x0=None, + img_callback=None, + start_T=None, + log_every_t=None, + **kwargs, + ): + + if not log_every_t: + log_every_t = self.log_every_t + device = self.betas.device + b = shape[0] + # sample an initial noise + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + if start_T is not None: + timesteps = min(timesteps, start_T) + + iterator = ( + tqdm(reversed(range(0, timesteps)), desc="Sampling t", total=timesteps) + if verbose + else reversed(range(0, timesteps)) + ) + + if mask is not None: + assert x0 is not None + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + + for i in iterator: + ts = torch.full((b,), i, device=device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != "hybrid" + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img = self.p_sample( + img, cond, ts, clip_denoised=self.clip_denoised, **kwargs + ) + if mask is not None: + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1.0 - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + if callback: + callback(i) + if img_callback: + img_callback(img, i) + + if return_intermediates: + return img, intermediates + return img + + +class LatentVisualDiffusion(LatentDiffusion): + def __init__( + self, cond_img_config, finegrained=False, random_cond=False, *args, **kwargs + ): + super().__init__(*args, **kwargs) + self.random_cond = random_cond + self.instantiate_img_embedder(cond_img_config, freeze=True) + num_tokens = 16 if finegrained else 4 + self.image_proj_model = self.init_projector( + use_finegrained=finegrained, + num_tokens=num_tokens, + input_dim=1024, + cross_attention_dim=1024, + dim=1280, + ) + + def instantiate_img_embedder(self, config, freeze=True): + embedder = instantiate_from_config(config) + if freeze: + self.embedder = embedder.eval() + self.embedder.train = disabled_train + for param in self.embedder.parameters(): + param.requires_grad = False + + def init_projector( + self, use_finegrained, num_tokens, input_dim, cross_attention_dim, dim + ): + if not use_finegrained: + image_proj_model = ImageProjModel( + clip_extra_context_tokens=num_tokens, + cross_attention_dim=cross_attention_dim, + clip_embeddings_dim=input_dim, + ) + else: + image_proj_model = Resampler( + dim=input_dim, + depth=4, + dim_head=64, + heads=12, + num_queries=num_tokens, + embedding_dim=dim, + output_dim=cross_attention_dim, + ff_mult=4, + ) + return image_proj_model + + ## Never delete this func: it is used in log_images() and inference stage + def get_image_embeds(self, batch_imgs): + ## img: b c h w + img_token = self.embedder(batch_imgs) + img_emb = self.image_proj_model(img_token) + return img_emb + + +class DiffusionWrapper(pl.LightningModule): + def __init__(self, diff_model_config, conditioning_key): + super().__init__() + self.diffusion_model = instantiate_from_config(diff_model_config) + self.conditioning_key = conditioning_key + + def forward( + self, + x, + t, + c_concat: list = None, + c_crossattn: list = None, + c_adm=None, + s=None, + mask=None, + **kwargs, + ): + # temporal_context = fps is foNone + if self.conditioning_key is None: + out = self.diffusion_model(x, t) + elif self.conditioning_key == "concat": + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t, **kwargs) + elif self.conditioning_key == "crossattn": + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(x, t, context=cc, **kwargs) + elif self.conditioning_key == "hybrid": + ## it is just right [b,c,t,h,w]: concatenate in channel dim + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc) + elif self.conditioning_key == "resblockcond": + cc = c_crossattn[0] + out = self.diffusion_model(x, t, context=cc) + elif self.conditioning_key == "adm": + cc = c_crossattn[0] + out = self.diffusion_model(x, t, y=cc) + elif self.conditioning_key == "hybrid-adm": + assert c_adm is not None + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc, y=c_adm) + elif self.conditioning_key == "hybrid-time": + assert s is not None + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc, s=s) + elif self.conditioning_key == "concat-time-mask": + # assert s is not None + # mainlogger.info('x & mask:',x.shape,c_concat[0].shape) + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t, context=None, s=s, mask=mask) + elif self.conditioning_key == "concat-adm-mask": + # assert s is not None + # mainlogger.info('x & mask:',x.shape,c_concat[0].shape) + if c_concat is not None: + xc = torch.cat([x] + c_concat, dim=1) + else: + xc = x + out = self.diffusion_model(xc, t, context=None, y=s, mask=mask) + elif self.conditioning_key == "hybrid-adm-mask": + cc = torch.cat(c_crossattn, 1) + if c_concat is not None: + xc = torch.cat([x] + c_concat, dim=1) + else: + xc = x + out = self.diffusion_model(xc, t, context=cc, y=s, mask=mask) + elif ( + self.conditioning_key == "hybrid-time-adm" + ): # adm means y, e.g., class index + # assert s is not None + assert c_adm is not None + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc, s=s, y=c_adm) + else: + raise NotImplementedError() + + return out diff --git a/src/videogen_hub/pipelines/t2v_turbo/lvdm/models/samplers/__init__.py b/src/videogen_hub/pipelines/t2v_turbo/lvdm/models/samplers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/t2v_turbo/lvdm/models/samplers/ddim.py b/src/videogen_hub/pipelines/t2v_turbo/lvdm/models/samplers/ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..2d0b6b6f397295c30ca0ef6f8a2d2647c84cc4a7 --- /dev/null +++ b/src/videogen_hub/pipelines/t2v_turbo/lvdm/models/samplers/ddim.py @@ -0,0 +1,493 @@ +import numpy as np +from tqdm import tqdm +import torch +from videogen_hub.pipelines.t2v_turbo.lvdm.models.utils_diffusion import ( + make_ddim_sampling_parameters, + make_ddim_timesteps, +) +from videogen_hub.pipelines.t2v_turbo.lvdm.common import noise_like + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + self.counter = 0 + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule( + self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True + ): + self.ddim_timesteps = make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose, + ) + alphas_cumprod = self.model.alphas_cumprod + assert ( + alphas_cumprod.shape[0] == self.ddpm_num_timesteps + ), "alphas have to be defined for each timestep" + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer("betas", to_torch(self.model.betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer( + "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev) + ) + self.use_scale = self.model.use_scale + print("DDIM scale", self.use_scale) + + if self.use_scale: + self.register_buffer("scale_arr", to_torch(self.model.scale_arr)) + ddim_scale_arr = self.scale_arr.cpu()[self.ddim_timesteps] + self.register_buffer("ddim_scale_arr", ddim_scale_arr) + ddim_scale_arr = np.asarray( + [self.scale_arr.cpu()[0]] + + self.scale_arr.cpu()[self.ddim_timesteps[:-1]].tolist() + ) + self.register_buffer("ddim_scale_arr_prev", ddim_scale_arr) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer( + "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", + to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())), + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", + to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)), + ) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( + alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta, + verbose=verbose, + ) + self.register_buffer("ddim_sigmas", ddim_sigmas) + self.register_buffer("ddim_alphas", ddim_alphas) + self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) + self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) + / (1 - self.alphas_cumprod) + * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) + ) + self.register_buffer( + "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps + ) + + @torch.no_grad() + def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0.0, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + verbose=True, + schedule_verbose=False, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs, + ): + + # check condition bs + if conditioning is not None: + if isinstance(conditioning, dict): + try: + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + except: + cbs = conditioning[list(conditioning.keys())[0]][0].shape[0] + + if cbs != batch_size: + print( + f"Warning: Got {cbs} conditionings but batch-size is {batch_size}" + ) + else: + if conditioning.shape[0] != batch_size: + print( + f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}" + ) + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=schedule_verbose) + + # make shape + if len(shape) == 3: + C, H, W = shape + size = (batch_size, C, H, W) + elif len(shape) == 4: + C, T, H, W = shape + size = (batch_size, C, T, H, W) + # print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling( + conditioning, + size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, + x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + verbose=verbose, + **kwargs, + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling( + self, + cond, + shape, + x_T=None, + ddim_use_original_steps=False, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + log_every_t=100, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + verbose=True, + cond_tau=1.0, + target_size=None, + start_timesteps=None, + **kwargs, + ): + device = self.model.betas.device + print("ddim device", device) + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = ( + self.ddpm_num_timesteps + if ddim_use_original_steps + else self.ddim_timesteps + ) + elif timesteps is not None and not ddim_use_original_steps: + subset_end = ( + int( + min(timesteps / self.ddim_timesteps.shape[0], 1) + * self.ddim_timesteps.shape[0] + ) + - 1 + ) + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {"x_inter": [img], "pred_x0": [img]} + time_range = ( + reversed(range(0, timesteps)) + if ddim_use_original_steps + else np.flip(timesteps) + ) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + if verbose: + iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps) + else: + iterator = time_range + + init_x0 = False + clean_cond = kwargs.pop("clean_cond", False) + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + if start_timesteps is not None: + assert x0 is not None + if step > start_timesteps * time_range[0]: + continue + elif not init_x0: + img = self.model.q_sample(x0, ts) + init_x0 = True + + # use mask to blend noised original latent (img_orig) & new sampled latent (img) + if mask is not None: + assert x0 is not None + if clean_cond: + img_orig = x0 + else: + img_orig = self.model.q_sample( + x0, ts + ) # TODO: deterministic forward pass? + img = ( + img_orig * mask + (1.0 - mask) * img + ) # keep original & modify use img + + index_clip = int((1 - cond_tau) * total_steps) + if index <= index_clip and target_size is not None: + target_size_ = [ + target_size[0], + target_size[1] // 8, + target_size[2] // 8, + ] + img = torch.nn.functional.interpolate( + img, + size=target_size_, + mode="nearest", + ) + outs = self.p_sample_ddim( + img, + cond, + ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + x0=x0, + **kwargs, + ) + + img, pred_x0 = outs + if callback: + callback(i) + if img_callback: + img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates["x_inter"].append(img) + intermediates["pred_x0"].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim( + self, + x, + c, + t, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + uc_type=None, + conditional_guidance_scale_temporal=None, + **kwargs, + ): + b, *_, device = *x.shape, x.device + if x.dim() == 5: + is_video = True + else: + is_video = False + if unconditional_conditioning is None or unconditional_guidance_scale == 1.0: + e_t = self.model.apply_model(x, t, c, **kwargs) # unet denoiser + else: + # with unconditional condition + if isinstance(c, torch.Tensor): + e_t = self.model.apply_model(x, t, c, **kwargs) + e_t_uncond = self.model.apply_model( + x, t, unconditional_conditioning, **kwargs + ) + elif isinstance(c, dict): + e_t = self.model.apply_model(x, t, c, **kwargs) + e_t_uncond = self.model.apply_model( + x, t, unconditional_conditioning, **kwargs + ) + else: + raise NotImplementedError + # text cfg + if uc_type is None: + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + else: + if uc_type == "cfg_original": + e_t = e_t + unconditional_guidance_scale * (e_t - e_t_uncond) + elif uc_type == "cfg_ours": + e_t = e_t + unconditional_guidance_scale * (e_t_uncond - e_t) + else: + raise NotImplementedError + # temporal guidance + if conditional_guidance_scale_temporal is not None: + e_t_temporal = self.model.apply_model(x, t, c, **kwargs) + e_t_image = self.model.apply_model( + x, t, c, no_temporal_attn=True, **kwargs + ) + e_t = e_t + conditional_guidance_scale_temporal * ( + e_t_temporal - e_t_image + ) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score( + self.model, e_t, x, t, c, **corrector_kwargs + ) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = ( + self.model.alphas_cumprod_prev + if use_original_steps + else self.ddim_alphas_prev + ) + sqrt_one_minus_alphas = ( + self.model.sqrt_one_minus_alphas_cumprod + if use_original_steps + else self.ddim_sqrt_one_minus_alphas + ) + sigmas = ( + self.model.ddim_sigmas_for_original_num_steps + if use_original_steps + else self.ddim_sigmas + ) + # select parameters corresponding to the currently considered timestep + + if is_video: + size = (b, 1, 1, 1, 1) + else: + size = (b, 1, 1, 1) + a_t = torch.full(size, alphas[index], device=device) + a_prev = torch.full(size, alphas_prev[index], device=device) + sigma_t = torch.full(size, sigmas[index], device=device) + sqrt_one_minus_at = torch.full( + size, sqrt_one_minus_alphas[index], device=device + ) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t + + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.0: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + if self.use_scale: + scale_arr = ( + self.model.scale_arr if use_original_steps else self.ddim_scale_arr + ) + scale_t = torch.full(size, scale_arr[index], device=device) + scale_arr_prev = ( + self.model.scale_arr_prev + if use_original_steps + else self.ddim_scale_arr_prev + ) + scale_t_prev = torch.full(size, scale_arr_prev[index], device=device) + pred_x0 /= scale_t + x_prev = a_prev.sqrt() * scale_t_prev * pred_x0 + dir_xt + noise + else: + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + + return x_prev, pred_x0 + + @torch.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + # fast, but does not allow for exact reconstruction + # t serves as an index to gather the correct alphas + if use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod + else: + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + + if noise is None: + noise = torch.randn_like(x0) + + def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + return ( + extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise + ) + + @torch.no_grad() + def decode( + self, + x_latent, + cond, + t_start, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + use_original_steps=False, + ): + + timesteps = ( + np.arange(self.ddpm_num_timesteps) + if use_original_steps + else self.ddim_timesteps + ) + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc="Decoding image", total=total_steps) + x_dec = x_latent + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full( + (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long + ) + x_dec, _ = self.p_sample_ddim( + x_dec, + cond, + ts, + index=index, + use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return x_dec diff --git a/src/videogen_hub/pipelines/t2v_turbo/lvdm/models/utils_diffusion.py b/src/videogen_hub/pipelines/t2v_turbo/lvdm/models/utils_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..eb48f26f60d0340507d8c3d443730a762e5db0db --- /dev/null +++ b/src/videogen_hub/pipelines/t2v_turbo/lvdm/models/utils_diffusion.py @@ -0,0 +1,130 @@ +import math +import numpy as np +from einops import repeat +import torch +import torch.nn.functional as F + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + else: + embedding = repeat(timesteps, "b -> b d", d=dim) + return embedding + + +def make_beta_schedule( + schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3 +): + if schedule == "linear": + betas = ( + torch.linspace( + linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64 + ) + ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace( + linear_start, linear_end, n_timestep, dtype=torch.float64 + ) + elif schedule == "sqrt": + betas = ( + torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + ** 0.5 + ) + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps( + ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True +): + if ddim_discr_method == "uniform": + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == "quad": + ddim_timesteps = ( + (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2 + ).astype(int) + else: + raise NotImplementedError( + f'There is no ddim discretization method called "{ddim_discr_method}"' + ) + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f"Selected timesteps for ddim sampler: {steps_out}") + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + # print(f'ddim_timesteps={ddim_timesteps}, len_alphacums={len(alphacums)}') + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt( + (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev) + ) + if verbose: + print( + f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}" + ) + print( + f"For the chosen value of eta, which is {eta}, " + f"this results in the following sigma_t schedule for ddim sampler {sigmas}" + ) + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) diff --git a/src/videogen_hub/pipelines/t2v_turbo/lvdm/modules/__init__.py b/src/videogen_hub/pipelines/t2v_turbo/lvdm/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/t2v_turbo/lvdm/modules/attention.py b/src/videogen_hub/pipelines/t2v_turbo/lvdm/modules/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..2870f9b03aaf05957c0b44b4cd6f9e797b67491f --- /dev/null +++ b/src/videogen_hub/pipelines/t2v_turbo/lvdm/modules/attention.py @@ -0,0 +1,584 @@ +from functools import partial +import torch +from torch import nn, einsum +import torch.nn.functional as F +from einops import rearrange, repeat + +try: + import xformers + import xformers.ops + + XFORMERS_IS_AVAILBLE = True +except: + XFORMERS_IS_AVAILBLE = False +from videogen_hub.pipelines.t2v_turbo.lvdm.common import ( + checkpoint, + exists, + default, +) +from videogen_hub.pipelines.t2v_turbo.lvdm.basics import ( + zero_module, +) + + +class RelativePosition(nn.Module): + """https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py""" + + def __init__(self, num_units, max_relative_position): + super().__init__() + self.num_units = num_units + self.max_relative_position = max_relative_position + self.embeddings_table = nn.Parameter( + torch.Tensor(max_relative_position * 2 + 1, num_units) + ) + nn.init.xavier_uniform_(self.embeddings_table) + + def forward(self, length_q, length_k): + device = self.embeddings_table.device + range_vec_q = torch.arange(length_q, device=device) + range_vec_k = torch.arange(length_k, device=device) + distance_mat = range_vec_k[None, :] - range_vec_q[:, None] + distance_mat_clipped = torch.clamp( + distance_mat, -self.max_relative_position, self.max_relative_position + ) + final_mat = distance_mat_clipped + self.max_relative_position + final_mat = final_mat.long() + embeddings = self.embeddings_table[final_mat] + return embeddings + + +class CrossAttention(nn.Module): + + def __init__( + self, + query_dim, + context_dim=None, + heads=8, + dim_head=64, + dropout=0.0, + relative_position=False, + temporal_length=None, + img_cross_attention=False, + ): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head**-0.5 + self.heads = heads + self.dim_head = dim_head + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) + ) + + self.image_cross_attention_scale = 1.0 + self.text_context_len = 77 + self.img_cross_attention = img_cross_attention + if self.img_cross_attention: + self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False) + + self.relative_position = relative_position + if self.relative_position: + assert temporal_length is not None + self.relative_position_k = RelativePosition( + num_units=dim_head, max_relative_position=temporal_length + ) + self.relative_position_v = RelativePosition( + num_units=dim_head, max_relative_position=temporal_length + ) + else: + ## only used for spatial attention, while NOT for temporal attention + if XFORMERS_IS_AVAILBLE and temporal_length is None: + self.forward = self.efficient_forward + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + ## considering image token additionally + if context is not None and self.img_cross_attention: + context, context_img = ( + context[:, : self.text_context_len, :], + context[:, self.text_context_len :, :], + ) + k = self.to_k(context) + v = self.to_v(context) + k_ip = self.to_k_ip(context_img) + v_ip = self.to_v_ip(context_img) + else: + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) + sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale + if self.relative_position: + len_q, len_k, len_v = q.shape[1], k.shape[1], v.shape[1] + k2 = self.relative_position_k(len_q, len_k) + sim2 = einsum("b t d, t s d -> b t s", q, k2) * self.scale # TODO check + sim += sim2 + del k + + if exists(mask): + ## feasible for causal attention mask only + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, "b i j -> (b h) i j", h=h) + sim.masked_fill_(~(mask > 0.5), max_neg_value) + + # attention, what we cannot get enough of + sim = sim.softmax(dim=-1) + out = torch.einsum("b i j, b j d -> b i d", sim, v) + if self.relative_position: + v2 = self.relative_position_v(len_q, len_v) + out2 = einsum("b t s, t s d -> b t d", sim, v2) # TODO check + out += out2 + out = rearrange(out, "(b h) n d -> b n (h d)", h=h) + + ## considering image token additionally + if context is not None and self.img_cross_attention: + k_ip, v_ip = map( + lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (k_ip, v_ip) + ) + sim_ip = torch.einsum("b i d, b j d -> b i j", q, k_ip) * self.scale + del k_ip + sim_ip = sim_ip.softmax(dim=-1) + out_ip = torch.einsum("b i j, b j d -> b i d", sim_ip, v_ip) + out_ip = rearrange(out_ip, "(b h) n d -> b n (h d)", h=h) + out = out + self.image_cross_attention_scale * out_ip + del q + + return self.to_out(out) + + def efficient_forward(self, x, context=None, mask=None): + q = self.to_q(x) + context = default(context, x) + + ## considering image token additionally + if context is not None and self.img_cross_attention: + context, context_img = ( + context[:, : self.text_context_len, :], + context[:, self.text_context_len :, :], + ) + k = self.to_k(context) + v = self.to_v(context) + k_ip = self.to_k_ip(context_img) + v_ip = self.to_v_ip(context_img) + else: + k = self.to_k(context) + v = self.to_v(context) + + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + # actually compute the attention, what we cannot get enough of + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=None) + + ## considering image token additionally + if context is not None and self.img_cross_attention: + k_ip, v_ip = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (k_ip, v_ip), + ) + out_ip = xformers.ops.memory_efficient_attention( + q, k_ip, v_ip, attn_bias=None, op=None + ) + out_ip = ( + out_ip.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + + if exists(mask): + raise NotImplementedError + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + if context is not None and self.img_cross_attention: + out = out + self.image_cross_attention_scale * out_ip + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + disable_self_attn=False, + attention_cls=None, + img_cross_attention=False, + ): + super().__init__() + attn_cls = CrossAttention if attention_cls is None else attention_cls + self.disable_self_attn = disable_self_attn + self.attn1 = attn_cls( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None, + ) + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = attn_cls( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + img_cross_attention=img_cross_attention, + ) + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None, mask=None): + ## implementation tricks: because checkpointing doesn't support non-tensor (e.g. None or scalar) arguments + input_tuple = ( + x, + ) ## should not be (x), otherwise *input_tuple will decouple x into multiple arguments + if context is not None: + input_tuple = (x, context) + if mask is not None: + forward_mask = partial(self._forward, mask=mask) + return checkpoint(forward_mask, (x,), self.parameters(), self.checkpoint) + if context is not None and mask is not None: + input_tuple = (x, context, mask) + return checkpoint( + self._forward, input_tuple, self.parameters(), self.checkpoint + ) + + def _forward(self, x, context=None, mask=None): + x = ( + self.attn1( + self.norm1(x), + context=context if self.disable_self_attn else None, + mask=mask, + ) + + x + ) + x = self.attn2(self.norm2(x), context=context, mask=mask) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data in spatial axis. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + NEW: use_linear for more efficiency instead of the 1x1 convs + """ + + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + context_dim=None, + use_checkpoint=True, + disable_self_attn=False, + use_linear=False, + img_cross_attention=False, + ): + super().__init__() + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + if not use_linear: + self.proj_in = nn.Conv2d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0 + ) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim, + img_cross_attention=img_cross_attention, + disable_self_attn=disable_self_attn, + checkpoint=use_checkpoint, + ) + for d in range(depth) + ] + ) + if not use_linear: + self.proj_out = zero_module( + nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + ) + else: + self.proj_out = zero_module(nn.Linear(inner_dim, in_channels)) + self.use_linear = use_linear + + def forward(self, x, context=None): + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) + return x + x_in + + +class TemporalTransformer(nn.Module): + """ + Transformer block for image-like data in temporal axis. + First, reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + context_dim=None, + use_checkpoint=True, + use_linear=False, + only_self_att=True, + causal_attention=False, + relative_position=False, + temporal_length=None, + ): + super().__init__() + self.only_self_att = only_self_att + self.relative_position = relative_position + self.causal_attention = causal_attention + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + self.proj_in = nn.Conv1d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0 + ) + if not use_linear: + self.proj_in = nn.Conv1d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0 + ) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + if relative_position: + assert temporal_length is not None + attention_cls = partial( + CrossAttention, relative_position=True, temporal_length=temporal_length + ) + else: + attention_cls = None + if self.causal_attention: + assert temporal_length is not None + self.mask = torch.tril(torch.ones([1, temporal_length, temporal_length])) + + if self.only_self_att: + context_dim = None + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim, + attention_cls=attention_cls, + checkpoint=use_checkpoint, + ) + for d in range(depth) + ] + ) + if not use_linear: + self.proj_out = zero_module( + nn.Conv1d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + ) + else: + self.proj_out = zero_module(nn.Linear(inner_dim, in_channels)) + self.use_linear = use_linear + + def forward(self, x, context=None): + b, c, t, h, w = x.shape + x_in = x + x = self.norm(x) + x = rearrange(x, "b c t h w -> (b h w) c t").contiguous() + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, "bhw c t -> bhw t c").contiguous() + if self.use_linear: + x = self.proj_in(x) + + if self.causal_attention: + mask = self.mask.to(x.device) + mask = repeat(mask, "l i j -> (l bhw) i j", bhw=b * h * w) + else: + mask = None + + if self.only_self_att: + ## note: if no context is given, cross-attention defaults to self-attention + for i, block in enumerate(self.transformer_blocks): + x = block(x, mask=mask) + x = rearrange(x, "(b hw) t c -> b hw t c", b=b).contiguous() + else: + x = rearrange(x, "(b hw) t c -> b hw t c", b=b).contiguous() + context = rearrange(context, "(b t) l con -> b t l con", t=t).contiguous() + for i, block in enumerate(self.transformer_blocks): + # calculate each batch one by one (since number in shape could not greater then 65,535 for some package) + for j in range(b): + context_j = repeat( + context[j], "t l con -> (t r) l con", r=(h * w) // t, t=t + ).contiguous() + ## note: causal mask will not applied in cross-attention case + x[j] = block(x[j], context=context_j) + + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, "b (h w) t c -> b c t h w", h=h, w=w).contiguous() + if not self.use_linear: + x = rearrange(x, "b hw t c -> (b hw) c t").contiguous() + x = self.proj_out(x) + x = rearrange(x, "(b h w) c t -> b c t h w", b=b, h=h, w=w).contiguous() + + return x + x_in + + +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = ( + nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) + if not glu + else GEGLU(dim, inner_dim) + ) + + self.net = nn.Sequential( + project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange( + qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 + ) + k = k.softmax(dim=-1) + context = torch.einsum("bhdn,bhen->bhde", k, v) + out = torch.einsum("bhde,bhdn->bhen", context, q) + out = rearrange( + out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w + ) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b (h w) c") + k = rearrange(k, "b c h w -> b c (h w)") + w_ = torch.einsum("bij,bjk->bik", q, k) + + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, "b c h w -> b c (h w)") + w_ = rearrange(w_, "b i j -> b j i") + h_ = torch.einsum("bij,bjk->bik", v, w_) + h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) + h_ = self.proj_out(h_) + + return x + h_ diff --git a/src/videogen_hub/pipelines/t2v_turbo/lvdm/modules/encoders/__init__.py b/src/videogen_hub/pipelines/t2v_turbo/lvdm/modules/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/t2v_turbo/lvdm/modules/encoders/condition.py b/src/videogen_hub/pipelines/t2v_turbo/lvdm/modules/encoders/condition.py new file mode 100644 index 0000000000000000000000000000000000000000..0fe1c5ed6c43d7b2d62c0ca524d38cbf23205261 --- /dev/null +++ b/src/videogen_hub/pipelines/t2v_turbo/lvdm/modules/encoders/condition.py @@ -0,0 +1,512 @@ +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint +import kornia +import open_clip +from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel +from videogen_hub.pipelines.t2v_turbo.lvdm.common import autocast +from videogen_hub.pipelines.t2v_turbo.utils.utils import count_params + + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class IdentityEncoder(AbstractEncoder): + + def encode(self, x): + return x + + +class ClassEmbedder(nn.Module): + def __init__(self, embed_dim, n_classes=1000, key="class", ucg_rate=0.1): + super().__init__() + self.key = key + self.embedding = nn.Embedding(n_classes, embed_dim) + self.n_classes = n_classes + self.ucg_rate = ucg_rate + + def forward(self, batch, key=None, disable_dropout=False): + if key is None: + key = self.key + # this is for use in crossattn + c = batch[key][:, None] + if self.ucg_rate > 0.0 and not disable_dropout: + mask = 1.0 - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) + c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1) + c = c.long() + c = self.embedding(c) + return c + + def get_unconditional_conditioning(self, bs, device="cuda"): + uc_class = ( + self.n_classes - 1 + ) # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) + uc = torch.ones((bs,), device=device) * uc_class + uc = {self.key: uc} + return uc + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class FrozenT5Embedder(AbstractEncoder): + """Uses the T5 transformer encoder for text""" + + def __init__( + self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True + ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl + super().__init__() + self.tokenizer = T5Tokenizer.from_pretrained(version) + self.transformer = T5EncoderModel.from_pretrained(version) + self.device = device + self.max_length = max_length # TODO: typical value? + if freeze: + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + # self.train = disabled_train + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens) + + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + + +class FrozenCLIPEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from huggingface)""" + + LAYERS = ["last", "pooled", "hidden"] + + def __init__( + self, + version="openai/clip-vit-large-patch14", + device="cuda", + max_length=77, + freeze=True, + layer="last", + layer_idx=None, + ): # clip-vit-base-patch32 + super().__init__() + assert layer in self.LAYERS + self.tokenizer = CLIPTokenizer.from_pretrained(version) + self.transformer = CLIPTextModel.from_pretrained(version) + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + self.layer_idx = layer_idx + if layer == "hidden": + assert layer_idx is not None + assert 0 <= abs(layer_idx) <= 12 + + def freeze(self): + self.transformer = self.transformer.eval() + # self.train = disabled_train + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer( + input_ids=tokens, output_hidden_states=self.layer == "hidden" + ) + if self.layer == "last": + z = outputs.last_hidden_state + elif self.layer == "pooled": + z = outputs.pooler_output[:, None, :] + else: + z = outputs.hidden_states[self.layer_idx] + return z + + def encode(self, text): + return self(text) + + +class ClipImageEmbedder(nn.Module): + def __init__( + self, + model, + jit=False, + device="cuda" if torch.cuda.is_available() else "cpu", + antialias=True, + ucg_rate=0.0, + ): + super().__init__() + from clip import load as load_clip + + self.model, _ = load_clip(name=model, device=device, jit=jit) + + self.antialias = antialias + + self.register_buffer( + "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False + ) + self.register_buffer( + "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False + ) + self.ucg_rate = ucg_rate + + def preprocess(self, x): + # normalize to [0,1] + x = kornia.geometry.resize( + x, + (224, 224), + interpolation="bicubic", + align_corners=True, + antialias=self.antialias, + ) + x = (x + 1.0) / 2.0 + # re-normalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def forward(self, x, no_dropout=False): + # x is assumed to be in range [-1,1] + out = self.model.encode_image(self.preprocess(x)) + out = out.to(x.dtype) + if self.ucg_rate > 0.0 and not no_dropout: + out = ( + torch.bernoulli( + (1.0 - self.ucg_rate) * torch.ones(out.shape[0], device=out.device) + )[:, None] + * out + ) + return out + + +class FrozenOpenCLIPEmbedder(AbstractEncoder): + """ + Uses the OpenCLIP transformer encoder for text + """ + + LAYERS = [ + # "pooled", + "last", + "penultimate", + ] + + def __init__( + self, + arch="ViT-H-14", + version="laion2b_s32b_b79k", + device="cuda", + max_length=77, + freeze=True, + layer="last", + ): + super().__init__() + assert layer in self.LAYERS + model, _, _ = open_clip.create_model_and_transforms( + arch, device=torch.device("cpu") + ) + del model.visual + self.model = model + + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + if self.layer == "last": + self.layer_idx = 0 + elif self.layer == "penultimate": + self.layer_idx = 1 + else: + raise NotImplementedError() + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + self.device = self.model.positional_embedding.device + tokens = open_clip.tokenize(text) + z = self.encode_with_transformer(tokens.to(self.device)) + return z + + def encode_with_transformer(self, text): + x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.model.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.model.ln_final(x) + return x + + def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): + for i, r in enumerate(self.model.transformer.resblocks): + if i == len(self.model.transformer.resblocks) - self.layer_idx: + break + if ( + self.model.transformer.grad_checkpointing + and not torch.jit.is_scripting() + ): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + def encode(self, text): + return self(text) + + +class FrozenOpenCLIPImageEmbedder(AbstractEncoder): + """ + Uses the OpenCLIP vision transformer encoder for images + """ + + def __init__( + self, + arch="ViT-H-14", + version="laion2b_s32b_b79k", + device="cuda", + max_length=77, + freeze=True, + layer="pooled", + antialias=True, + ucg_rate=0.0, + ): + super().__init__() + model, _, _ = open_clip.create_model_and_transforms( + arch, + device=torch.device("cpu"), + pretrained=version, + ) + del model.transformer + self.model = model + + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + if self.layer == "penultimate": + raise NotImplementedError() + self.layer_idx = 1 + + self.antialias = antialias + + self.register_buffer( + "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False + ) + self.register_buffer( + "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False + ) + self.ucg_rate = ucg_rate + + def preprocess(self, x): + # normalize to [0,1] + x = kornia.geometry.resize( + x, + (224, 224), + interpolation="bicubic", + align_corners=True, + antialias=self.antialias, + ) + x = (x + 1.0) / 2.0 + # renormalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + @autocast + def forward(self, image, no_dropout=False): + z = self.encode_with_vision_transformer(image) + if self.ucg_rate > 0.0 and not no_dropout: + z = ( + torch.bernoulli( + (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device) + )[:, None] + * z + ) + return z + + def encode_with_vision_transformer(self, img): + img = self.preprocess(img) + x = self.model.visual(img) + return x + + def encode(self, text): + return self(text) + + +class FrozenOpenCLIPImageEmbedderV2(AbstractEncoder): + """ + Uses the OpenCLIP vision transformer encoder for images + """ + + def __init__( + self, + arch="ViT-H-14", + version="laion2b_s32b_b79k", + device="cuda", + freeze=True, + layer="pooled", + antialias=True, + ): + super().__init__() + model, _, _ = open_clip.create_model_and_transforms( + arch, + device=torch.device("cpu"), + pretrained=version, + ) + del model.transformer + self.model = model + self.device = device + + if freeze: + self.freeze() + self.layer = layer + if self.layer == "penultimate": + raise NotImplementedError() + self.layer_idx = 1 + + self.antialias = antialias + self.register_buffer( + "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False + ) + self.register_buffer( + "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False + ) + + def preprocess(self, x): + # normalize to [0,1] + x = kornia.geometry.resize( + x, + (224, 224), + interpolation="bicubic", + align_corners=True, + antialias=self.antialias, + ) + x = (x + 1.0) / 2.0 + # renormalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def freeze(self): + self.model = self.model.eval() + for param in self.model.parameters(): + param.requires_grad = False + + def forward(self, image, no_dropout=False): + ## image: b c h w + z = self.encode_with_vision_transformer(image) + return z + + def encode_with_vision_transformer(self, x): + x = self.preprocess(x) + + # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1 + if self.model.visual.input_patchnorm: + # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)') + x = x.reshape( + x.shape[0], + x.shape[1], + self.model.visual.grid_size[0], + self.model.visual.patch_size[0], + self.model.visual.grid_size[1], + self.model.visual.patch_size[1], + ) + x = x.permute(0, 2, 4, 1, 3, 5) + x = x.reshape( + x.shape[0], + self.model.visual.grid_size[0] * self.model.visual.grid_size[1], + -1, + ) + x = self.model.visual.patchnorm_pre_ln(x) + x = self.model.visual.conv1(x) + else: + x = self.model.visual.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + # class embeddings and positional embeddings + x = torch.cat( + [ + self.model.visual.class_embedding.to(x.dtype) + + torch.zeros( + x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device + ), + x, + ], + dim=1, + ) # shape = [*, grid ** 2 + 1, width] + x = x + self.model.visual.positional_embedding.to(x.dtype) + + # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in + x = self.model.visual.patch_dropout(x) + x = self.model.visual.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.model.visual.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + return x + + +class FrozenCLIPT5Encoder(AbstractEncoder): + def __init__( + self, + clip_version="openai/clip-vit-large-patch14", + t5_version="google/t5-v1_1-xl", + device="cuda", + clip_max_length=77, + t5_max_length=77, + ): + super().__init__() + self.clip_encoder = FrozenCLIPEmbedder( + clip_version, device, max_length=clip_max_length + ) + self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) + print( + f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, " + f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params." + ) + + def encode(self, text): + return self(text) + + def forward(self, text): + clip_z = self.clip_encoder.encode(text) + t5_z = self.t5_encoder.encode(text) + return [clip_z, t5_z] diff --git a/src/videogen_hub/pipelines/t2v_turbo/lvdm/modules/encoders/ip_resampler.py b/src/videogen_hub/pipelines/t2v_turbo/lvdm/modules/encoders/ip_resampler.py new file mode 100644 index 0000000000000000000000000000000000000000..fa3f6e5f978588c4d0c6636d70c4caf3a02e6f34 --- /dev/null +++ b/src/videogen_hub/pipelines/t2v_turbo/lvdm/modules/encoders/ip_resampler.py @@ -0,0 +1,148 @@ +# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py +import math +import torch +import torch.nn as nn + + +class ImageProjModel(nn.Module): + """Projection Model""" + + def __init__( + self, + cross_attention_dim=1024, + clip_embeddings_dim=1024, + clip_extra_context_tokens=4, + ): + super().__init__() + self.cross_attention_dim = cross_attention_dim + self.clip_extra_context_tokens = clip_extra_context_tokens + self.proj = nn.Linear( + clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim + ) + self.norm = nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds): + # embeds = image_embeds + embeds = image_embeds.type(list(self.proj.parameters())[0].dtype) + clip_extra_context_tokens = self.proj(embeds).reshape( + -1, self.clip_extra_context_tokens, self.cross_attention_dim + ) + clip_extra_context_tokens = self.norm(clip_extra_context_tokens) + return clip_extra_context_tokens + + +# FFN +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +def reshape_tensor(x, heads): + bs, length, width = x.shape + # (bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, length, heads, -1) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) + x = x.reshape(bs, heads, length, -1) + return x + + +class PerceiverAttention(nn.Module): + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head**-0.5 + self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head * heads + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D) + """ + x = self.norm1(x) + latents = self.norm2(latents) + + b, l, _ = latents.shape + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + q = reshape_tensor(q, self.heads) + k = reshape_tensor(k, self.heads) + v = reshape_tensor(v, self.heads) + + # attention + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (q * scale) @ (k * scale).transpose( + -2, -1 + ) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + out = weight @ v + + out = out.permute(0, 2, 1, 3).reshape(b, l, -1) + + return self.to_out(out) + + +class Resampler(nn.Module): + def __init__( + self, + dim=1024, + depth=8, + dim_head=64, + heads=16, + num_queries=8, + embedding_dim=768, + output_dim=1024, + ff_mult=4, + ): + super().__init__() + + self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) + + self.proj_in = nn.Linear(embedding_dim, dim) + + self.proj_out = nn.Linear(dim, output_dim) + self.norm_out = nn.LayerNorm(output_dim) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + def forward(self, x): + + latents = self.latents.repeat(x.size(0), 1, 1) + + x = self.proj_in(x) + + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + + latents = self.proj_out(latents) + return self.norm_out(latents) diff --git a/src/videogen_hub/pipelines/t2v_turbo/lvdm/modules/networks/__init__.py b/src/videogen_hub/pipelines/t2v_turbo/lvdm/modules/networks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/t2v_turbo/lvdm/modules/networks/ae_modules.py b/src/videogen_hub/pipelines/t2v_turbo/lvdm/modules/networks/ae_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..1f54832ff95c076cd7ade1a1091f6b0799ed1c63 --- /dev/null +++ b/src/videogen_hub/pipelines/t2v_turbo/lvdm/modules/networks/ae_modules.py @@ -0,0 +1,1025 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import numpy as np +import torch.nn as nn +from einops import rearrange +from videogen_hub.pipelines.t2v_turbo.utils.utils import instantiate_from_config +from videogen_hub.pipelines.t2v_turbo.lvdm.modules.attention import LinearAttention + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm( + num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) # bcl + q = q.permute(0, 2, 1) # bcl -> blc l=hw + k = k.reshape(b, c, h * w) # bcl + + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown" + # print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + self.in_channels = in_channels + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + self.in_channels = in_channels + if self.with_conv: + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class ResnetBlock(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + else: + self.nin_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class Model(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + use_timestep=True, + use_linear_attn=False, + attn_type="vanilla", + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch * 4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList( + [ + torch.nn.Linear(self.ch, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ] + ) + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + skip_in = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + if i_block == self.num_res_blocks: + skip_in = ch * in_ch_mult[i_level] + block.append( + ResnetBlock( + in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x, t=None, context=None): + # assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb + ) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + use_linear_attn=False, + attn_type="vanilla", + **ignore_kwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, x): + # timestep embedding + temb = None + + # print(f'encoder-input={x.shape}') + # downsampling + hs = [self.conv_in(x)] + # print(f'encoder-conv in feat={hs[0].shape}') + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + # print(f'encoder-down feat={h.shape}') + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + # print(f'encoder-downsample (input)={hs[-1].shape}') + hs.append(self.down[i_level].downsample(hs[-1])) + # print(f'encoder-downsample (output)={hs[-1].shape}') + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + # print(f'encoder-mid1 feat={h.shape}') + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + # print(f'encoder-mid2 feat={h.shape}') + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + # print(f'end feat={h.shape}') + return h + + +class Decoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + attn_type="vanilla", + **ignorekwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + print( + "AE working on z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape) + ) + ) + + # z to block_in + self.conv_in = torch.nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, z): + # assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # print(f'decoder-input={z.shape}') + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + # print(f'decoder-conv in feat={h.shape}') + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + # print(f'decoder-mid feat={h.shape}') + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + # print(f'decoder-up feat={h.shape}') + if i_level != 0: + h = self.up[i_level].upsample(h) + # print(f'decoder-upsample feat={h.shape}') + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + # print(f'decoder-conv_out feat={h.shape}') + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList( + [ + nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock( + in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, + dropout=0.0, + ), + ResnetBlock( + in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, + dropout=0.0, + ), + ResnetBlock( + in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, + dropout=0.0, + ), + nn.Conv2d(2 * in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True), + ] + ) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1, 2, 3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__( + self, + in_channels, + out_channels, + ch, + num_res_blocks, + resolution, + ch_mult=(2, 2), + dropout=0.0, + ): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d( + in_channels, mid_channels, kernel_size=3, stride=1, padding=1 + ) + self.res_block1 = nn.ModuleList( + [ + ResnetBlock( + in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0, + ) + for _ in range(depth) + ] + ) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList( + [ + ResnetBlock( + in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0, + ) + for _ in range(depth) + ] + ) + + self.conv_out = nn.Conv2d( + mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate( + x, + size=( + int(round(x.shape[2] * self.factor)), + int(round(x.shape[3] * self.factor)), + ), + ) + x = self.attn(x) + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + def __init__( + self, + in_channels, + ch, + resolution, + out_ch, + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + ch_mult=(1, 2, 4, 8), + rescale_factor=1.0, + rescale_module_depth=1, + ): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder( + in_channels=in_channels, + num_res_blocks=num_res_blocks, + ch=ch, + ch_mult=ch_mult, + z_channels=intermediate_chn, + double_z=False, + resolution=resolution, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + out_ch=None, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=intermediate_chn, + mid_channels=intermediate_chn, + out_channels=out_ch, + depth=rescale_module_depth, + ) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + def __init__( + self, + z_channels, + out_ch, + resolution, + num_res_blocks, + attn_resolutions, + ch, + ch_mult=(1, 2, 4, 8), + dropout=0.0, + resamp_with_conv=True, + rescale_factor=1.0, + rescale_module_depth=1, + ): + super().__init__() + tmp_chn = z_channels * ch_mult[-1] + self.decoder = Decoder( + out_ch=out_ch, + z_channels=tmp_chn, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + in_channels=None, + num_res_blocks=num_res_blocks, + ch_mult=ch_mult, + resolution=resolution, + ch=ch, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=z_channels, + mid_channels=tmp_chn, + out_channels=tmp_chn, + depth=rescale_module_depth, + ) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size // in_size)) + 1 + factor_up = 1.0 + (out_size % in_size) + print( + f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}" + ) + self.rescaler = LatentRescaler( + factor=factor_up, + in_channels=in_channels, + mid_channels=2 * in_channels, + out_channels=in_channels, + ) + self.decoder = Decoder( + out_ch=out_channels, + resolution=out_size, + z_channels=in_channels, + num_res_blocks=2, + attn_resolutions=[], + in_channels=None, + ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)], + ) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + print( + f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode" + ) + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=4, stride=2, padding=1 + ) + + def forward(self, x, scale_factor=1.0): + if scale_factor == 1.0: + return x + else: + x = torch.nn.functional.interpolate( + x, mode=self.mode, align_corners=False, scale_factor=scale_factor + ) + return x + + +class FirstStagePostProcessor(nn.Module): + + def __init__( + self, + ch_mult: list, + in_channels, + pretrained_model: nn.Module = None, + reshape=False, + n_channels=None, + dropout=0.0, + pretrained_config=None, + ): + super().__init__() + if pretrained_config is None: + assert ( + pretrained_model is not None + ), 'Either "pretrained_model" or "pretrained_config" must not be None' + self.pretrained_model = pretrained_model + else: + assert ( + pretrained_config is not None + ), 'Either "pretrained_model" or "pretrained_config" must not be None' + self.instantiate_pretrained(pretrained_config) + + self.do_reshape = reshape + + if n_channels is None: + n_channels = self.pretrained_model.encoder.ch + + self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2) + self.proj = nn.Conv2d( + in_channels, n_channels, kernel_size=3, stride=1, padding=1 + ) + + blocks = [] + downs = [] + ch_in = n_channels + for m in ch_mult: + blocks.append( + ResnetBlock( + in_channels=ch_in, out_channels=m * n_channels, dropout=dropout + ) + ) + ch_in = m * n_channels + downs.append(Downsample(ch_in, with_conv=False)) + + self.model = nn.ModuleList(blocks) + self.downsampler = nn.ModuleList(downs) + + def instantiate_pretrained(self, config): + model = instantiate_from_config(config) + self.pretrained_model = model.eval() + # self.pretrained_model.train = False + for param in self.pretrained_model.parameters(): + param.requires_grad = False + + @torch.no_grad() + def encode_with_pretrained(self, x): + c = self.pretrained_model.encode(x) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + return c + + def forward(self, x): + z_fs = self.encode_with_pretrained(x) + z = self.proj_norm(z_fs) + z = self.proj(z) + z = nonlinearity(z) + + for submodel, downmodel in zip(self.model, self.downsampler): + z = submodel(z, temb=None) + z = downmodel(z) + + if self.do_reshape: + z = rearrange(z, "b c h w -> b (h w) c") + return z diff --git a/src/videogen_hub/pipelines/t2v_turbo/lvdm/modules/networks/openaimodel3d.py b/src/videogen_hub/pipelines/t2v_turbo/lvdm/modules/networks/openaimodel3d.py new file mode 100644 index 0000000000000000000000000000000000000000..4bba9425d522cda972c16a2926768fbb3bca7c28 --- /dev/null +++ b/src/videogen_hub/pipelines/t2v_turbo/lvdm/modules/networks/openaimodel3d.py @@ -0,0 +1,710 @@ +from functools import partial +from abc import abstractmethod +import torch +import torch.nn as nn +from einops import rearrange +import torch.nn.functional as F +from lvdm.models.utils_diffusion import timestep_embedding +from lvdm.common import checkpoint +from lvdm.basics import zero_module, conv_nd, linear, avg_pool_nd, normalization +from lvdm.modules.attention import SpatialTransformer, TemporalTransformer + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None, batch_size=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb, batch_size) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + elif isinstance(layer, TemporalTransformer): + x = rearrange(x, "(b f) c h w -> b c f h w", b=batch_size) + x = layer(x, context) + x = rearrange(x, "b c f h w -> (b f) c h w") + else: + x = layer( + x, + ) + return x + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding, + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd( + dims, self.channels, self.out_channels, 3, padding=padding + ) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + use_conv=False, + up=False, + down=False, + use_temporal_conv=False, + tempspatial_aware=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + self.use_temporal_conv = use_temporal_conv + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + nn.Linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module(nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + if self.use_temporal_conv: + self.temopral_conv = TemporalConvBlock( + self.out_channels, + self.out_channels, + dropout=0.1, + spatial_aware=tempspatial_aware, + ) + + def forward(self, x, emb, batch_size=None): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + input_tuple = ( + x, + emb, + ) + if batch_size: + forward_batchsize = partial(self._forward, batch_size=batch_size) + return checkpoint( + forward_batchsize, input_tuple, self.parameters(), self.use_checkpoint + ) + return checkpoint( + self._forward, input_tuple, self.parameters(), self.use_checkpoint + ) + + def _forward( + self, + x, + emb, + batch_size=None, + ): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = torch.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + h = self.skip_connection(x) + h + + if self.use_temporal_conv and batch_size: + h = rearrange(h, "(b t) c h w -> b c t h w", b=batch_size) + h = self.temopral_conv(h) + h = rearrange(h, "b c t h w -> (b t) c h w") + return h + + +class TemporalConvBlock(nn.Module): + """ + Adapted from modelscope: https://github.com/modelscope/modelscope/blob/master/modelscope/models/multi_modal/video_synthesis/unet_sd.py + """ + + def __init__( + self, in_channels, out_channels=None, dropout=0.0, spatial_aware=False + ): + super(TemporalConvBlock, self).__init__() + if out_channels is None: + out_channels = in_channels + self.in_channels = in_channels + self.out_channels = out_channels + kernel_shape = (3, 1, 1) if not spatial_aware else (3, 3, 3) + padding_shape = (1, 0, 0) if not spatial_aware else (1, 1, 1) + + # conv layers + self.conv1 = nn.Sequential( + nn.GroupNorm(32, in_channels), + nn.SiLU(), + nn.Conv3d(in_channels, out_channels, kernel_shape, padding=padding_shape), + ) + self.conv2 = nn.Sequential( + nn.GroupNorm(32, out_channels), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_channels, in_channels, kernel_shape, padding=padding_shape), + ) + self.conv3 = nn.Sequential( + nn.GroupNorm(32, out_channels), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_channels, in_channels, (3, 1, 1), padding=(1, 0, 0)), + ) + self.conv4 = nn.Sequential( + nn.GroupNorm(32, out_channels), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_channels, in_channels, (3, 1, 1), padding=(1, 0, 0)), + ) + + # zero out the last layer params,so the conv block is identity + nn.init.zeros_(self.conv4[-1].weight) + nn.init.zeros_(self.conv4[-1].bias) + + def forward(self, x): + identity = x + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + + return x + identity + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: in_channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + """ + + def __init__( + self, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0.0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + context_dim=None, + use_scale_shift_norm=False, + resblock_updown=False, + num_heads=-1, + num_head_channels=-1, + transformer_depth=1, + use_linear=False, + use_checkpoint=False, + temporal_conv=False, + tempspatial_aware=False, + temporal_attention=True, + temporal_selfatt_only=True, + use_relative_position=True, + use_causal_attention=False, + temporal_length=None, + use_fp16=False, + addition_attention=False, + use_image_attention=False, + temporal_transformer_depth=1, + fps_cond=False, + time_cond_proj_dim=None, + ): + super(UNetModel, self).__init__() + if num_heads == -1: + assert ( + num_head_channels != -1 + ), "Either num_heads or num_head_channels has to be set" + if num_head_channels == -1: + assert ( + num_heads != -1 + ), "Either num_heads or num_head_channels has to be set" + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.temporal_attention = temporal_attention + time_embed_dim = model_channels * 4 + self.use_checkpoint = use_checkpoint + self.dtype = torch.float16 if use_fp16 else torch.float32 + self.addition_attention = addition_attention + self.use_image_attention = use_image_attention + self.fps_cond = fps_cond + self.time_cond_proj_dim = time_cond_proj_dim + + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + if self.fps_cond: + self.fps_embedding = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + if time_cond_proj_dim is not None: + self.time_cond_proj = nn.Linear( + time_cond_proj_dim, model_channels, bias=False + ) + else: + self.time_cond_proj = None + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + if self.addition_attention: + self.init_attn = TimestepEmbedSequential( + TemporalTransformer( + model_channels, + n_heads=8, + d_head=num_head_channels, + depth=transformer_depth, + context_dim=context_dim, + use_checkpoint=use_checkpoint, + only_self_att=temporal_selfatt_only, + causal_attention=use_causal_attention, + relative_position=use_relative_position, + temporal_length=temporal_length, + ) + ) + + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + tempspatial_aware=tempspatial_aware, + use_temporal_conv=temporal_conv, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + layers.append( + SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + use_linear=use_linear, + use_checkpoint=use_checkpoint, + disable_self_attn=False, + img_cross_attention=self.use_image_attention, + ) + ) + if self.temporal_attention: + layers.append( + TemporalTransformer( + ch, + num_heads, + dim_head, + depth=temporal_transformer_depth, + context_dim=context_dim, + use_linear=use_linear, + use_checkpoint=use_checkpoint, + only_self_att=temporal_selfatt_only, + causal_attention=use_causal_attention, + relative_position=use_relative_position, + temporal_length=temporal_length, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + tempspatial_aware=tempspatial_aware, + use_temporal_conv=temporal_conv, + ), + SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + use_linear=use_linear, + use_checkpoint=use_checkpoint, + disable_self_attn=False, + img_cross_attention=self.use_image_attention, + ), + ] + if self.temporal_attention: + layers.append( + TemporalTransformer( + ch, + num_heads, + dim_head, + depth=temporal_transformer_depth, + context_dim=context_dim, + use_linear=use_linear, + use_checkpoint=use_checkpoint, + only_self_att=temporal_selfatt_only, + causal_attention=use_causal_attention, + relative_position=use_relative_position, + temporal_length=temporal_length, + ) + ) + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + tempspatial_aware=tempspatial_aware, + use_temporal_conv=temporal_conv, + ) + ) + self.middle_block = TimestepEmbedSequential(*layers) + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + tempspatial_aware=tempspatial_aware, + use_temporal_conv=temporal_conv, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + layers.append( + SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + use_linear=use_linear, + use_checkpoint=use_checkpoint, + disable_self_attn=False, + img_cross_attention=self.use_image_attention, + ) + ) + if self.temporal_attention: + layers.append( + TemporalTransformer( + ch, + num_heads, + dim_head, + depth=temporal_transformer_depth, + context_dim=context_dim, + use_linear=use_linear, + use_checkpoint=use_checkpoint, + only_self_att=temporal_selfatt_only, + causal_attention=use_causal_attention, + relative_position=use_relative_position, + temporal_length=temporal_length, + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + + def forward( + self, + x, + timesteps, + context=None, + features_adapter=None, + fps=16, + timestep_cond=None, + **kwargs + ): + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + if timestep_cond is not None: + t_emb = t_emb + self.time_cond_proj(timestep_cond) + emb = self.time_embed(t_emb) + + if self.fps_cond: + if type(fps) == int: + fps = torch.full_like(timesteps, fps) + fps_emb = timestep_embedding(fps, self.model_channels, repeat_only=False) + emb += self.fps_embedding(fps_emb) + + b, _, t, _, _ = x.shape + ## repeat t times for context [(b t) 77 768] & time embedding + context = context.repeat_interleave(repeats=t, dim=0) + emb = emb.repeat_interleave(repeats=t, dim=0) + + ## always in shape (b t) c h w, except for temporal layer + x = rearrange(x, "b c t h w -> (b t) c h w") + + h = x.type(self.dtype) + adapter_idx = 0 + hs = [] + for id, module in enumerate(self.input_blocks): + h = module(h, emb, context=context, batch_size=b) + if id == 0 and self.addition_attention: + h = self.init_attn(h, emb, context=context, batch_size=b) + ## plug-in adapter features + if ((id + 1) % 3 == 0) and features_adapter is not None: + h = h + features_adapter[adapter_idx] + adapter_idx += 1 + hs.append(h) + if features_adapter is not None: + assert len(features_adapter) == adapter_idx, "Wrong features_adapter" + + h = self.middle_block(h, emb, context=context, batch_size=b) + for module in self.output_blocks: + h = torch.cat([h, hs.pop()], dim=1) + h = module(h, emb, context=context, batch_size=b) + h = h.type(x.dtype) + y = self.out(h) + + # reshape back to (b c t h w) + y = rearrange(y, "(b t) c h w -> b c t h w", b=b) + return y diff --git a/src/videogen_hub/pipelines/t2v_turbo/lvdm/modules/x_transformer.py b/src/videogen_hub/pipelines/t2v_turbo/lvdm/modules/x_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..455cab7beef36cc3827f0f1ffa8ecd4f32c240f8 --- /dev/null +++ b/src/videogen_hub/pipelines/t2v_turbo/lvdm/modules/x_transformer.py @@ -0,0 +1,704 @@ +"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers""" + +from functools import partial +from inspect import isfunction +from collections import namedtuple +from einops import rearrange, repeat +import torch +from torch import nn, einsum +import torch.nn.functional as F + +# constants +DEFAULT_DIM_HEAD = 64 + +Intermediates = namedtuple("Intermediates", ["pre_softmax_attn", "post_softmax_attn"]) + +LayerIntermediates = namedtuple("Intermediates", ["hiddens", "attn_intermediates"]) + + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + super().__init__() + self.emb = nn.Embedding(max_seq_len, dim) + self.init_() + + def init_(self): + nn.init.normal_(self.emb.weight, std=0.02) + + def forward(self, x): + n = torch.arange(x.shape[1], device=x.device) + return self.emb(n)[None, :, :] + + +class FixedPositionalEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + + def forward(self, x, seq_dim=1, offset=0): + t = ( + torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + + offset + ) + sinusoid_inp = torch.einsum("i , j -> i j", t, self.inv_freq) + emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) + return emb[None, :, :] + + +# helpers + + +def exists(val): + return val is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def always(val): + def inner(*args, **kwargs): + return val + + return inner + + +def not_equals(val): + def inner(x): + return x != val + + return inner + + +def equals(val): + def inner(x): + return x == val + + return inner + + +def max_neg_value(tensor): + return -torch.finfo(tensor.dtype).max + + +# keyword argument helpers + + +def pick_and_pop(keys, d): + values = list(map(lambda key: d.pop(key), keys)) + return dict(zip(keys, values)) + + +def group_dict_by_key(cond, d): + return_val = [dict(), dict()] + for key in d.keys(): + match = bool(cond(key)) + ind = int(not match) + return_val[ind][key] = d[key] + return (*return_val,) + + +def string_begins_with(prefix, str): + return str.startswith(prefix) + + +def group_by_key_prefix(prefix, d): + return group_dict_by_key(partial(string_begins_with, prefix), d) + + +def groupby_prefix_and_trim(prefix, d): + kwargs_with_prefix, kwargs = group_dict_by_key( + partial(string_begins_with, prefix), d + ) + kwargs_without_prefix = dict( + map(lambda x: (x[0][len(prefix) :], x[1]), tuple(kwargs_with_prefix.items())) + ) + return kwargs_without_prefix, kwargs + + +# classes +class Scale(nn.Module): + def __init__(self, value, fn): + super().__init__() + self.value = value + self.fn = fn + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.value, *rest) + + +class Rezero(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + self.g = nn.Parameter(torch.zeros(1)) + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.g, *rest) + + +class ScaleNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.scale = dim**-0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(1)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-8): + super().__init__() + self.scale = dim**-0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class Residual(nn.Module): + def forward(self, x, residual): + return x + residual + + +class GRUGating(nn.Module): + def __init__(self, dim): + super().__init__() + self.gru = nn.GRUCell(dim, dim) + + def forward(self, x, residual): + gated_output = self.gru( + rearrange(x, "b n d -> (b n) d"), rearrange(residual, "b n d -> (b n) d") + ) + + return gated_output.reshape_as(x) + + +# feedforward + + +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = ( + nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) + if not glu + else GEGLU(dim, inner_dim) + ) + + self.net = nn.Sequential( + project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +# attention. +class Attention(nn.Module): + def __init__( + self, + dim, + dim_head=DEFAULT_DIM_HEAD, + heads=8, + causal=False, + mask=None, + talking_heads=False, + sparse_topk=None, + use_entmax15=False, + num_mem_kv=0, + dropout=0.0, + on_attn=False, + ): + super().__init__() + if use_entmax15: + raise NotImplementedError( + "Check out entmax activation instead of softmax activation!" + ) + self.scale = dim_head**-0.5 + self.heads = heads + self.causal = causal + self.mask = mask + + inner_dim = dim_head * heads + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_k = nn.Linear(dim, inner_dim, bias=False) + self.to_v = nn.Linear(dim, inner_dim, bias=False) + self.dropout = nn.Dropout(dropout) + + # talking heads + self.talking_heads = talking_heads + if talking_heads: + self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + + # explicit topk sparse attention + self.sparse_topk = sparse_topk + + # entmax + # self.attn_fn = entmax15 if use_entmax15 else F.softmax + self.attn_fn = F.softmax + + # add memory key / values + self.num_mem_kv = num_mem_kv + if num_mem_kv > 0: + self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + + # attention on attention + self.attn_on_attn = on_attn + self.to_out = ( + nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) + if on_attn + else nn.Linear(inner_dim, dim) + ) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + rel_pos=None, + sinusoidal_emb=None, + prev_attn=None, + mem=None, + ): + b, n, _, h, talking_heads, device = ( + *x.shape, + self.heads, + self.talking_heads, + x.device, + ) + kv_input = default(context, x) + + q_input = x + k_input = kv_input + v_input = kv_input + + if exists(mem): + k_input = torch.cat((mem, k_input), dim=-2) + v_input = torch.cat((mem, v_input), dim=-2) + + if exists(sinusoidal_emb): + # in shortformer, the query would start at a position offset depending on the past cached memory + offset = k_input.shape[-2] - q_input.shape[-2] + q_input = q_input + sinusoidal_emb(q_input, offset=offset) + k_input = k_input + sinusoidal_emb(k_input) + + q = self.to_q(q_input) + k = self.to_k(k_input) + v = self.to_v(v_input) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + + input_mask = None + if any(map(exists, (mask, context_mask))): + q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) + k_mask = q_mask if not exists(context) else context_mask + k_mask = default( + k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool() + ) + q_mask = rearrange(q_mask, "b i -> b () i ()") + k_mask = rearrange(k_mask, "b j -> b () () j") + input_mask = q_mask * k_mask + + if self.num_mem_kv > 0: + mem_k, mem_v = map( + lambda t: repeat(t, "h n d -> b h n d", b=b), (self.mem_k, self.mem_v) + ) + k = torch.cat((mem_k, k), dim=-2) + v = torch.cat((mem_v, v), dim=-2) + if exists(input_mask): + input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) + + dots = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale + mask_value = max_neg_value(dots) + + if exists(prev_attn): + dots = dots + prev_attn + + pre_softmax_attn = dots + + if talking_heads: + dots = einsum( + "b h i j, h k -> b k i j", dots, self.pre_softmax_proj + ).contiguous() + + if exists(rel_pos): + dots = rel_pos(dots) + + if exists(input_mask): + dots.masked_fill_(~input_mask, mask_value) + del input_mask + + if self.causal: + i, j = dots.shape[-2:] + r = torch.arange(i, device=device) + mask = rearrange(r, "i -> () () i ()") < rearrange(r, "j -> () () () j") + mask = F.pad(mask, (j - i, 0), value=False) + dots.masked_fill_(mask, mask_value) + del mask + + if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: + top, _ = dots.topk(self.sparse_topk, dim=-1) + vk = top[..., -1].unsqueeze(-1).expand_as(dots) + mask = dots < vk + dots.masked_fill_(mask, mask_value) + del mask + + attn = self.attn_fn(dots, dim=-1) + post_softmax_attn = attn + + attn = self.dropout(attn) + + if talking_heads: + attn = einsum( + "b h i j, h k -> b k i j", attn, self.post_softmax_proj + ).contiguous() + + out = einsum("b h i j, b h j d -> b h i d", attn, v) + out = rearrange(out, "b h n d -> b n (h d)") + + intermediates = Intermediates( + pre_softmax_attn=pre_softmax_attn, post_softmax_attn=post_softmax_attn + ) + + return self.to_out(out), intermediates + + +class AttentionLayers(nn.Module): + def __init__( + self, + dim, + depth, + heads=8, + causal=False, + cross_attend=False, + only_cross=False, + use_scalenorm=False, + use_rmsnorm=False, + use_rezero=False, + rel_pos_num_buckets=32, + rel_pos_max_distance=128, + position_infused_attn=False, + custom_layers=None, + sandwich_coef=None, + par_ratio=None, + residual_attn=False, + cross_residual_attn=False, + macaron=False, + pre_norm=True, + gate_residual=False, + **kwargs, + ): + super().__init__() + ff_kwargs, kwargs = groupby_prefix_and_trim("ff_", kwargs) + attn_kwargs, _ = groupby_prefix_and_trim("attn_", kwargs) + + dim_head = attn_kwargs.get("dim_head", DEFAULT_DIM_HEAD) + + self.dim = dim + self.depth = depth + self.layers = nn.ModuleList([]) + + self.has_pos_emb = position_infused_attn + self.pia_pos_emb = ( + FixedPositionalEmbedding(dim) if position_infused_attn else None + ) + self.rotary_pos_emb = always(None) + + assert ( + rel_pos_num_buckets <= rel_pos_max_distance + ), "number of relative position buckets must be less than the relative position max distance" + self.rel_pos = None + + self.pre_norm = pre_norm + + self.residual_attn = residual_attn + self.cross_residual_attn = cross_residual_attn + + norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm + norm_class = RMSNorm if use_rmsnorm else norm_class + norm_fn = partial(norm_class, dim) + + norm_fn = nn.Identity if use_rezero else norm_fn + branch_fn = Rezero if use_rezero else None + + if cross_attend and not only_cross: + default_block = ("a", "c", "f") + elif cross_attend and only_cross: + default_block = ("c", "f") + else: + default_block = ("a", "f") + + if macaron: + default_block = ("f",) + default_block + + if exists(custom_layers): + layer_types = custom_layers + elif exists(par_ratio): + par_depth = depth * len(default_block) + assert 1 < par_ratio <= par_depth, "par ratio out of range" + default_block = tuple(filter(not_equals("f"), default_block)) + par_attn = par_depth // par_ratio + depth_cut = ( + par_depth * 2 // 3 + ) # 2 / 3 attention layer cutoff suggested by PAR paper + par_width = (depth_cut + depth_cut // par_attn) // par_attn + assert ( + len(default_block) <= par_width + ), "default block is too large for par_ratio" + par_block = default_block + ("f",) * (par_width - len(default_block)) + par_head = par_block * par_attn + layer_types = par_head + ("f",) * (par_depth - len(par_head)) + elif exists(sandwich_coef): + assert ( + sandwich_coef > 0 and sandwich_coef <= depth + ), "sandwich coefficient should be less than the depth" + layer_types = ( + ("a",) * sandwich_coef + + default_block * (depth - sandwich_coef) + + ("f",) * sandwich_coef + ) + else: + layer_types = default_block * depth + + self.layer_types = layer_types + self.num_attn_layers = len(list(filter(equals("a"), layer_types))) + + for layer_type in self.layer_types: + if layer_type == "a": + layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) + elif layer_type == "c": + layer = Attention(dim, heads=heads, **attn_kwargs) + elif layer_type == "f": + layer = FeedForward(dim, **ff_kwargs) + layer = layer if not macaron else Scale(0.5, layer) + else: + raise Exception(f"invalid layer type {layer_type}") + + if isinstance(layer, Attention) and exists(branch_fn): + layer = branch_fn(layer) + + if gate_residual: + residual_fn = GRUGating(dim) + else: + residual_fn = Residual() + + self.layers.append(nn.ModuleList([norm_fn(), layer, residual_fn])) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + mems=None, + return_hiddens=False, + ): + hiddens = [] + intermediates = [] + prev_attn = None + prev_cross_attn = None + + mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers + + for ind, (layer_type, (norm, block, residual_fn)) in enumerate( + zip(self.layer_types, self.layers) + ): + is_last = ind == (len(self.layers) - 1) + + if layer_type == "a": + hiddens.append(x) + layer_mem = mems.pop(0) + + residual = x + + if self.pre_norm: + x = norm(x) + + if layer_type == "a": + out, inter = block( + x, + mask=mask, + sinusoidal_emb=self.pia_pos_emb, + rel_pos=self.rel_pos, + prev_attn=prev_attn, + mem=layer_mem, + ) + elif layer_type == "c": + out, inter = block( + x, + context=context, + mask=mask, + context_mask=context_mask, + prev_attn=prev_cross_attn, + ) + elif layer_type == "f": + out = block(x) + + x = residual_fn(out, residual) + + if layer_type in ("a", "c"): + intermediates.append(inter) + + if layer_type == "a" and self.residual_attn: + prev_attn = inter.pre_softmax_attn + elif layer_type == "c" and self.cross_residual_attn: + prev_cross_attn = inter.pre_softmax_attn + + if not self.pre_norm and not is_last: + x = norm(x) + + if return_hiddens: + intermediates = LayerIntermediates( + hiddens=hiddens, attn_intermediates=intermediates + ) + + return x, intermediates + + return x + + +class Encoder(AttentionLayers): + def __init__(self, **kwargs): + assert "causal" not in kwargs, "cannot set causality on encoder" + super().__init__(causal=False, **kwargs) + + +class TransformerWrapper(nn.Module): + def __init__( + self, + *, + num_tokens, + max_seq_len, + attn_layers, + emb_dim=None, + max_mem_len=0.0, + emb_dropout=0.0, + num_memory_tokens=None, + tie_embedding=False, + use_pos_emb=True, + ): + super().__init__() + assert isinstance( + attn_layers, AttentionLayers + ), "attention layers must be one of Encoder or Decoder" + + dim = attn_layers.dim + emb_dim = default(emb_dim, dim) + + self.max_seq_len = max_seq_len + self.max_mem_len = max_mem_len + self.num_tokens = num_tokens + + self.token_emb = nn.Embedding(num_tokens, emb_dim) + self.pos_emb = ( + AbsolutePositionalEmbedding(emb_dim, max_seq_len) + if (use_pos_emb and not attn_layers.has_pos_emb) + else always(0) + ) + self.emb_dropout = nn.Dropout(emb_dropout) + + self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() + self.attn_layers = attn_layers + self.norm = nn.LayerNorm(dim) + + self.init_() + + self.to_logits = ( + nn.Linear(dim, num_tokens) + if not tie_embedding + else lambda t: t @ self.token_emb.weight.t() + ) + + # memory tokens (like [cls]) from Memory Transformers paper + num_memory_tokens = default(num_memory_tokens, 0) + self.num_memory_tokens = num_memory_tokens + if num_memory_tokens > 0: + self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) + + # let funnel encoder know number of memory tokens, if specified + if hasattr(attn_layers, "num_memory_tokens"): + attn_layers.num_memory_tokens = num_memory_tokens + + def init_(self): + nn.init.normal_(self.token_emb.weight, std=0.02) + + def forward( + self, + x, + return_embeddings=False, + mask=None, + return_mems=False, + return_attn=False, + mems=None, + **kwargs, + ): + b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens + x = self.token_emb(x) + x += self.pos_emb(x) + x = self.emb_dropout(x) + + x = self.project_emb(x) + + if num_mem > 0: + mem = repeat(self.memory_tokens, "n d -> b n d", b=b) + x = torch.cat((mem, x), dim=1) + + # auto-handle masking after appending memory tokens + if exists(mask): + mask = F.pad(mask, (num_mem, 0), value=True) + + x, intermediates = self.attn_layers( + x, mask=mask, mems=mems, return_hiddens=True, **kwargs + ) + x = self.norm(x) + + mem, x = x[:, :num_mem], x[:, num_mem:] + + out = self.to_logits(x) if not return_embeddings else x + + if return_mems: + hiddens = intermediates.hiddens + new_mems = ( + list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) + if exists(mems) + else hiddens + ) + new_mems = list( + map(lambda t: t[..., -self.max_mem_len :, :].detach(), new_mems) + ) + return out, new_mems + + if return_attn: + attn_maps = list( + map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates) + ) + return out, attn_maps + + return out diff --git a/src/videogen_hub/pipelines/t2v_turbo/model_scope/__init__.py b/src/videogen_hub/pipelines/t2v_turbo/model_scope/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/t2v_turbo/model_scope/unet_3d_blocks.py b/src/videogen_hub/pipelines/t2v_turbo/model_scope/unet_3d_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..064186a228015a79260a9177f200a0f0179210d9 --- /dev/null +++ b/src/videogen_hub/pipelines/t2v_turbo/model_scope/unet_3d_blocks.py @@ -0,0 +1,886 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copied from https://github.com/ExponentialML/Text-To-Video-Finetuning/blob/main/models/unet_3d_blocks.py +import torch +import torch.utils.checkpoint as checkpoint +from torch import nn +from diffusers.models.resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D + +try: + from diffusers.models.transformer_2d import Transformer2DModel + from diffusers.models.transformer_temporal import TransformerTemporalModel +except: + from diffusers.models.transformers.transformer_2d import Transformer2DModel + from diffusers.models import TransformerTemporalModel + +# Assign gradient checkpoint function to simple variable for readability. +g_c = checkpoint.checkpoint + + +def is_video(num_frames, only_video=True): + if num_frames == 1 and not only_video: + return False + return num_frames > 1 + + +def custom_checkpoint(module, mode=None): + if mode == None: + raise ValueError('Mode for gradient checkpointing cannot be none.') + + custom_forward = None + + if mode == 'resnet': + def custom_forward(hidden_states, temb): + inputs = module(hidden_states, temb) + return inputs + + if mode == 'attn': + def custom_forward( + hidden_states, + encoder_hidden_states=None, + cross_attention_kwargs=None, + attention_mask=None, + ): + inputs = module( + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + attention_mask + ) + return inputs.sample + + if mode == 'temp': + # If inputs are not None, we can assume that this was a single image. + # Otherwise, do temporal convolutions / attention. + def custom_forward(hidden_states, num_frames=None): + if not is_video(num_frames): + return hidden_states + else: + inputs = module( + hidden_states, + num_frames=num_frames + ) + if isinstance(module, TransformerTemporalModel): + return inputs.sample + else: + return inputs + + return custom_forward + + +def transformer_g_c(transformer, sample, num_frames): + sample = g_c(custom_checkpoint(transformer, mode='temp'), + sample, num_frames, use_reentrant=False, + ) + return sample + + +def cross_attn_g_c( + attn, + temp_attn, + resnet, + temp_conv, + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + temb, + num_frames, + inverse_temp=False, + attention_mask=None, +): + def ordered_g_c(idx): + + # Self and CrossAttention + if idx == 0: + return g_c(custom_checkpoint(attn, mode='attn'), + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + attention_mask, + use_reentrant=False + ) + + # Temporal Self and CrossAttention + if idx == 1: + return g_c(custom_checkpoint(temp_attn, mode='temp'), + hidden_states, + num_frames, + use_reentrant=False + ) + + # Resnets + if idx == 2: + return g_c(custom_checkpoint(resnet, mode='resnet'), + hidden_states, + temb, + use_reentrant=False + ) + + # Temporal Convolutions + if idx == 3: + return g_c(custom_checkpoint(temp_conv, mode='temp'), + hidden_states, + num_frames, + use_reentrant=False + ) + + # Here we call the function depending on the order in which they are called. + # For some layers, the orders are different, so we access the appropriate one by index. + + if not inverse_temp: + for idx in [0, 1, 2, 3]: + hidden_states = ordered_g_c(idx) + else: + for idx in [2, 3, 0, 1]: + hidden_states = ordered_g_c(idx) + + return hidden_states + + +def up_down_g_c(resnet, temp_conv, hidden_states, temb, num_frames): + hidden_states = g_c(custom_checkpoint(resnet, mode='resnet'), + hidden_states, + temb, + use_reentrant=False + ) + hidden_states = g_c(custom_checkpoint(temp_conv, mode='temp'), + hidden_states, + num_frames, + use_reentrant=False + ) + return hidden_states + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=True, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", +): + if down_block_type == "DownBlock3D": + return DownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "CrossAttnDownBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") + return CrossAttnDownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=True, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", +): + if up_block_type == "UpBlock3D": + return UpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "CrossAttnUpBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") + return CrossAttnUpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock3DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=True, + upcast_attention=False, + ): + super().__init__() + + self.gradient_checkpointing = False + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + temp_convs = [ + TemporalConvLayer( + in_channels, + in_channels, + dropout=0.1 + ) + ] + attentions = [] + temp_attentions = [] + + for _ in range(num_layers): + attentions.append( + Transformer2DModel( + in_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + in_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + in_channels, + in_channels, + dropout=0.1 + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + num_frames=1, + cross_attention_kwargs=None, + ): + if self.gradient_checkpointing: + hidden_states = up_down_g_c( + self.resnets[0], + self.temp_convs[0], + hidden_states, + temb, + num_frames + ) + else: + hidden_states = self.resnets[0](hidden_states, temb) + hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames) + + for attn, temp_attn, resnet, temp_conv in zip( + self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:] + ): + if self.gradient_checkpointing: + hidden_states = cross_attn_g_c( + attn, + temp_attn, + resnet, + temp_conv, + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + temb, + num_frames + ) + else: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + if num_frames > 1: + hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample + + hidden_states = resnet(hidden_states, temb) + + if num_frames > 1: + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + return hidden_states + + +class CrossAttnDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + temp_attentions = [] + temp_convs = [] + + self.gradient_checkpointing = False + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1 + ) + ) + attentions.append( + Transformer2DModel( + out_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + out_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + num_frames=1, + cross_attention_kwargs=None, + ): + # TODO(Patrick, William) - attention mask is not used + output_states = () + + for resnet, temp_conv, attn, temp_attn in zip( + self.resnets, self.temp_convs, self.attentions, self.temp_attentions + ): + + if self.gradient_checkpointing: + hidden_states = cross_attn_g_c( + attn, + temp_attn, + resnet, + temp_conv, + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + temb, + num_frames, + inverse_temp=True + ) + else: + hidden_states = resnet(hidden_states, temb) + + if num_frames > 1: + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + if num_frames > 1: + hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + temp_convs = [] + + self.gradient_checkpointing = False + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1 + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states, temb=None, num_frames=1): + output_states = () + + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + if self.gradient_checkpointing: + hidden_states = up_down_g_c(resnet, temp_conv, hidden_states, temb, num_frames) + else: + hidden_states = resnet(hidden_states, temb) + + if num_frames > 1: + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + ): + super().__init__() + resnets = [] + temp_convs = [] + attentions = [] + temp_attentions = [] + + self.gradient_checkpointing = False + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1 + ) + ) + attentions.append( + Transformer2DModel( + out_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + out_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + attention_mask=None, + num_frames=1, + cross_attention_kwargs=None, + ): + # TODO(Patrick, William) - attention mask is not used + for resnet, temp_conv, attn, temp_attn in zip( + self.resnets, self.temp_convs, self.attentions, self.temp_attentions + ): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.gradient_checkpointing: + hidden_states = cross_attn_g_c( + attn, + temp_attn, + resnet, + temp_conv, + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + temb, + num_frames, + inverse_temp=True + ) + else: + hidden_states = resnet(hidden_states, temb) + + if num_frames > 1: + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + if num_frames > 1: + hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + temp_convs = [] + self.gradient_checkpointing = False + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1 + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1): + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.gradient_checkpointing: + hidden_states = up_down_g_c(resnet, temp_conv, hidden_states, temb, num_frames) + else: + hidden_states = resnet(hidden_states, temb) + + if num_frames > 1: + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states diff --git a/src/videogen_hub/pipelines/t2v_turbo/model_scope/unet_3d_condition.py b/src/videogen_hub/pipelines/t2v_turbo/model_scope/unet_3d_condition.py new file mode 100644 index 0000000000000000000000000000000000000000..310f9c5c0e26040200988ee6cc43402ad3b5b2d7 --- /dev/null +++ b/src/videogen_hub/pipelines/t2v_turbo/model_scope/unet_3d_condition.py @@ -0,0 +1,508 @@ +# Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved. +# Copyright 2023 The ModelScope Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Adapted from https://github.com/ExponentialML/Text-To-Video-Finetuning/blob/main/models/unet_3d_condition.py +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput, logging +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.models.modeling_utils import ModelMixin +try: + from diffusers.models.transformer_temporal import TransformerTemporalModel +except: + from diffusers.models.transformers.transformer_temporal import TransformerTemporalModelOutput + from diffusers.models import TransformerTemporalModel +from .unet_3d_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, + transformer_g_c +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNet3DConditionOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`): + Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.FloatTensor + + +class UNet3DConditionModel(ModelMixin, ConfigMixin): + r""" + UNet3DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep + and returns sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the models (such as downloading or saving, etc.) + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): + The tuple of upsample blocks to use. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, it will skip the normalization and activation layers in post-processing + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"), + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + time_cond_proj_dim: Optional[int] = None, + cross_attention_dim: int = 1024, + attention_head_dim: Union[int, Tuple[int]] = 64, + ): + super().__init__() + + self.sample_size = sample_size + self.gradient_checkpointing = False + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + # input + conv_in_kernel = 3 + conv_out_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps(block_out_channels[0], True, 0) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + cond_proj_dim=time_cond_proj_dim, + ) + + self.transformer_in = TransformerTemporalModel( + num_attention_heads=8, + attention_head_dim=attention_head_dim, + in_channels=block_out_channels[0], + num_layers=1, + ) + + # class embedding + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=False, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=False, + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=False, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + self.conv_act = nn.SiLU() + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + + num_slicable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_slicable_layers * [1] + + slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, value=False): + self.gradient_checkpointing = value + self.mid_block.gradient_checkpointing = value + for module in self.down_blocks + self.up_blocks: + if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet3DConditionOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, num_frames, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet3DConditionOutput`] instead of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Returns: + [`~models.unet_2d_condition.UNet3DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + num_frames = sample.shape[2] + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + emb = emb.repeat_interleave(repeats=num_frames, dim=0) + encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0) + + # 2. pre-process + sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:]) + sample = self.conv_in(sample) + + if num_frames > 1: + if self.gradient_checkpointing: + sample = transformer_g_c(self.transformer_in, sample, num_frames) + else: + sample = self.transformer_in(sample, num_frames=num_frames).sample + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames) + + down_block_res_samples += res_samples + + if down_block_additional_residuals is not None: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples += (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + + if mid_block_additional_residual is not None: + sample = sample + mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + num_frames=num_frames, + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + + sample = self.conv_out(sample) + + # reshape to (batch, channel, framerate, width, height) + sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4) + + if not return_dict: + return (sample,) + + return UNet3DConditionOutput(sample=sample) \ No newline at end of file diff --git a/src/videogen_hub/pipelines/t2v_turbo/pipeline/__init__.py b/src/videogen_hub/pipelines/t2v_turbo/pipeline/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/t2v_turbo/pipeline/t2v_turbo_ms_pipeline.py b/src/videogen_hub/pipelines/t2v_turbo/pipeline/t2v_turbo_ms_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..0cdb90341ed3e20f4c9dcd00e384d68ca63f895e --- /dev/null +++ b/src/videogen_hub/pipelines/t2v_turbo/pipeline/t2v_turbo_ms_pipeline.py @@ -0,0 +1,221 @@ +import torch +from diffusers import DiffusionPipeline + +from typing import List, Optional, Tuple, Union, Dict, Any + +from diffusers import logging +from diffusers.utils.torch_utils import randn_tensor +from diffusers.models import AutoencoderKL +from transformers import CLIPTokenizer, CLIPTextModel +from ..scheduler.t2v_turbo_scheduler import T2VTurboScheduler + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class T2VTurboMSPipeline(DiffusionPipeline): + def __init__( + self, + unet, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + scheduler: T2VTurboScheduler, + ): + super().__init__() + + self.register_modules( + unet=unet, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + ) + + self.vae_scale_factor = 8 + + def _encode_prompt( + self, + prompt, + device, + num_videos_per_prompt, + prompt_embeds: None, + ): + r""" + Encodes the prompt into text encoder hidden states. + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_videos_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + """ + if prompt_embeds is None: + with torch.no_grad(): + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.to(self.text_encoder.device) + prompt_embeds = self.text_encoder(text_input_ids)[0] + + prompt_embeds = prompt_embeds.to(device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view( + bs_embed * num_videos_per_prompt, seq_len, -1 + ) + + # Don't need to get uncond prompt embedding because of LCM Guided Distillation + return prompt_embeds + + def prepare_latents( + self, + batch_size, + num_channels_latents, + frames, + height, + width, + dtype, + device, + generator, + latents=None, + ): + shape = ( + batch_size, + num_channels_latents, + frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if latents is None: + latents = randn_tensor( + shape, generator=generator, device=device, dtype=dtype + ) + else: + latents = latents.to(device) + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def get_w_embedding(self, w, embedding_dim=512, dtype=torch.float32): + """ + see https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + Args: + timesteps: torch.Tensor: generate embedding vectors at these timesteps + embedding_dim: int: dimension of the embeddings to generate + dtype: data type of the generated embeddings + Returns: + embedding vectors with shape `(len(timesteps), embedding_dim)` + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = 256, + width: Optional[int] = 256, + frames: int = 16, + guidance_scale: float = 7.5, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + num_inference_steps: int = 4, + lcm_origin_steps: int = 50, + prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + ): + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # do_classifier_free_guidance = guidance_scale > 0.0 # In LCM Implementation: cfg_noise = noise_cond + cfg_scale * (noise_cond - noise_uncond) , (cfg_scale > 0.0 using CFG) + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_videos_per_prompt, + prompt_embeds=prompt_embeds, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, lcm_origin_steps) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variable + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + bs = batch_size * num_videos_per_prompt + + # 6. Get Guidance Scale Embedding + w = torch.tensor(guidance_scale).repeat(bs) + w_embedding = self.get_w_embedding(w, embedding_dim=256).to(device) + + # 7. LCM MultiStep Sampling Loop: + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + + ts = torch.full((bs,), t, device=device, dtype=torch.long) + + # model prediction (v-prediction, eps, x) + model_pred = self.unet( + latents, + ts, + timestep_cond=w_embedding, + encoder_hidden_states=prompt_embeds.float(), + ).sample + # compute the previous noisy sample x_t -> x_t-1 + latents, denoised = self.scheduler.step( + model_pred, i, t, latents, return_dict=False + ) + + progress_bar.update() + + if not output_type == "latent": + t = denoised.shape[2] + z = denoised.to(self.vae.dtype) / self.vae.config.scaling_factor + videos = torch.cat( + [self.vae.decode(z[:, :, i])[0].unsqueeze(2) for i in range(t)], + dim=2, + ) + else: + videos = denoised + + return videos diff --git a/src/videogen_hub/pipelines/t2v_turbo/pipeline/t2v_turbo_vc2_pipeline.py b/src/videogen_hub/pipelines/t2v_turbo/pipeline/t2v_turbo_vc2_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..93a7632c5275387aaebf506b71c2ca0aa7f466dc --- /dev/null +++ b/src/videogen_hub/pipelines/t2v_turbo/pipeline/t2v_turbo_vc2_pipeline.py @@ -0,0 +1,214 @@ +import torch +from diffusers import DiffusionPipeline + +from typing import List, Optional, Union, Dict, Any + +from diffusers import logging +from diffusers.utils.torch_utils import randn_tensor +from ..lvdm.models.ddpm3d import LatentDiffusion +from ..scheduler.t2v_turbo_scheduler import T2VTurboScheduler + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class T2VTurboVC2Pipeline(DiffusionPipeline): + def __init__( + self, + pretrained_t2v: LatentDiffusion, + scheduler: T2VTurboScheduler, + model_config: Dict[str, Any] = None, + ): + super().__init__() + + self.register_modules( + pretrained_t2v=pretrained_t2v, + scheduler=scheduler, + ) + self.vae = pretrained_t2v.first_stage_model + self.unet = pretrained_t2v.model.diffusion_model + self.text_encoder = pretrained_t2v.cond_stage_model + + self.model_config = model_config + self.vae_scale_factor = 8 + + def _encode_prompt( + self, + prompt, + device, + num_videos_per_prompt, + prompt_embeds: None, + ): + r""" + Encodes the prompt into text encoder hidden states. + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_videos_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + """ + if prompt_embeds is None: + + prompt_embeds = self.text_encoder(prompt) + + prompt_embeds = prompt_embeds.to(device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view( + bs_embed * num_videos_per_prompt, seq_len, -1 + ) + + # Don't need to get uncond prompt embedding because of LCM Guided Distillation + return prompt_embeds + + def prepare_latents( + self, + batch_size, + num_channels_latents, + frames, + height, + width, + dtype, + device, + generator, + latents=None, + ): + shape = ( + batch_size, + num_channels_latents, + frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if latents is None: + latents = randn_tensor( + shape, generator=generator, device=device, dtype=dtype + ) + else: + latents = latents.to(device) + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def get_w_embedding(self, w, embedding_dim=512, dtype=torch.float32): + """ + see https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + Args: + timesteps: torch.Tensor: generate embedding vectors at these timesteps + embedding_dim: int: dimension of the embeddings to generate + dtype: data type of the generated embeddings + Returns: + embedding vectors with shape `(len(timesteps), embedding_dim)` + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = 320, + width: Optional[int] = 512, + frames: int = 16, + fps: int = 16, + guidance_scale: float = 7.5, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + num_inference_steps: int = 4, + lcm_origin_steps: int = 50, + prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + ): + unet_config = self.model_config["params"]["unet_config"] + # 0. Default height and width to unet + frames = self.pretrained_t2v.temporal_length if frames < 0 else frames + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # do_classifier_free_guidance = guidance_scale > 0.0 # In LCM Implementation: cfg_noise = noise_cond + cfg_scale * (noise_cond - noise_uncond) , (cfg_scale > 0.0 using CFG) + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_videos_per_prompt, + prompt_embeds=prompt_embeds, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, lcm_origin_steps) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variable + num_channels_latents = unet_config["params"]["in_channels"] + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + bs = batch_size * num_videos_per_prompt + + # 6. Get Guidance Scale Embedding + w = torch.tensor(guidance_scale).repeat(bs) + w_embedding = self.get_w_embedding(w, embedding_dim=256).to(device) + + # 7. LCM MultiStep Sampling Loop: + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + + ts = torch.full((bs,), t, device=device, dtype=torch.long) + + # model prediction (v-prediction, eps, x) + context = {"context": torch.cat([prompt_embeds.float()], 1), "fps": fps} + model_pred = self.unet( + latents, + ts, + **context, + timestep_cond=w_embedding.to(self.dtype), + ) + # compute the previous noisy sample x_t -> x_t-1 + latents, denoised = self.scheduler.step( + model_pred, i, t, latents, return_dict=False + ) + + # # call the callback, if provided + # if i == len(timesteps) - 1: + progress_bar.update() + + if not output_type == "latent": + videos = self.pretrained_t2v.decode_first_stage_2DAE(denoised) + else: + videos = denoised + + return videos diff --git a/src/videogen_hub/pipelines/t2v_turbo/scheduler/__init__.py b/src/videogen_hub/pipelines/t2v_turbo/scheduler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/t2v_turbo/scheduler/t2v_turbo_scheduler.py b/src/videogen_hub/pipelines/t2v_turbo/scheduler/t2v_turbo_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..9fe8dfbbf18f0e77b232287de24d81b2ba8593ab --- /dev/null +++ b/src/videogen_hub/pipelines/t2v_turbo/scheduler/t2v_turbo_scheduler.py @@ -0,0 +1,518 @@ +# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion +# and https://github.com/hojonathanho/diffusion + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers import ConfigMixin, SchedulerMixin +from diffusers.configuration_utils import register_to_config +from diffusers.utils import BaseOutput + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM +class T2VTurboSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + denoised: Optional[torch.FloatTensor] = None + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + +class T2VTurboScheduler(SchedulerMixin, ConfigMixin): + """ + `T2VTurboScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with + non-Markovian guidance. + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + clip_sample (`bool`, defaults to `True`): + Clip the predicted sample for numerical stability. + clip_sample_range (`float`, defaults to 1.0): + The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + set_alpha_to_one (`bool`, defaults to `True`): + Each diffusion step uses the alphas product value at that step and at the previous one. For the final step + there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the alpha value at step 0. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable + Diffusion. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, defaults to `"leading"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + """ + + # _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + linear_start: float = 0.00085, + linear_end: float = 0.012, + beta_schedule: str = "scaled_linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + clip_sample: bool = True, + set_alpha_to_one: bool = True, + steps_offset: int = 0, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + clip_sample_range: float = 1.0, + sample_max_value: float = 1.0, + timestep_spacing: str = "leading", + rescale_betas_zero_snr: bool = False, + ): + assert beta_schedule == "scaled_linear" + assert trained_betas is None + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace( + linear_start, linear_end, num_train_timesteps, dtype=torch.float32 + ) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace( + linear_start**0.5, + linear_end**0.5, + num_train_timesteps, + dtype=torch.float32, + ) + ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError( + f"{beta_schedule} does is not implemented for {self.__class__}" + ) + + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = ( + torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + ) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy( + np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64) + ) + + def scale_model_input( + self, sample: torch.FloatTensor, timestep: Optional[int] = None + ) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + Args: + sample (`torch.FloatTensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + Returns: + `torch.FloatTensor`: + A scaled input sample. + """ + return sample + + def _get_variance(self, timestep, prev_timestep): + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = ( + self.alphas_cumprod[prev_timestep] + if prev_timestep >= 0 + else self.final_alpha_cumprod + ) + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * ( + 1 - alpha_prod_t / alpha_prod_t_prev + ) + + return variance + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, height, width = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = ( + sample.float() + ) # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * height * width) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = ( + torch.clamp(sample, -s, s) / s + ) # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, height, width) + sample = sample.to(dtype) + + return sample + + def set_timesteps( + self, + num_inference_steps: int, + lcm_origin_steps: int, + device: Union[str, torch.device] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + """ + + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + + # LCM Timesteps Setting: # Linear Spacing + c = self.config.num_train_timesteps // lcm_origin_steps + lcm_origin_timesteps = ( + np.asarray(list(range(1, lcm_origin_steps + 1))) * c - 1 + ) # LCM Training Steps Schedule + skipping_step = len(lcm_origin_timesteps) // num_inference_steps + timesteps = lcm_origin_timesteps[::-skipping_step][ + :num_inference_steps + ] # LCM Inference Steps Schedule + + self.timesteps = torch.from_numpy(timesteps.copy()).to(device) + + ## From VideoCrafter 2 + + def get_scalings_for_boundary_condition_discrete(self, t): + self.sigma_data = 0.5 # Default: 0.5 + + # By dividing 0.1: This is almost a delta function at t=0. + c_skip = self.sigma_data**2 / ((t / 0.1) ** 2 + self.sigma_data**2) + c_out = (t / 0.1) / ((t / 0.1) ** 2 + self.sigma_data**2) ** 0.5 + return c_skip, c_out + + def step( + self, + model_output: torch.FloatTensor, + timeindex: int, + timestep: int, + sample: torch.FloatTensor, + eta: float = 0.0, + use_clipped_model_output: bool = False, + generator=None, + variance_noise: Optional[torch.FloatTensor] = None, + return_dict: bool = True, + ) -> Union[T2VTurboSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + eta (`float`): + The weight of noise for added noise in diffusion step. + use_clipped_model_output (`bool`, defaults to `False`): + If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary + because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no + clipping has happened, "corrected" `model_output` would coincide with the one provided as input and + `use_clipped_model_output` has no effect. + generator (`torch.Generator`, *optional*): + A random number generator. + variance_noise (`torch.FloatTensor`): + Alternative to generating noise with `generator` by directly providing the noise for the variance + itself. Useful for methods such as [`CycleDiffusion`]. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`. + Returns: + [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # 1. get previous step value + prev_timeindex = timeindex + 1 + if prev_timeindex < len(self.timesteps): + prev_timestep = self.timesteps[prev_timeindex] + else: + prev_timestep = timestep + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = ( + self.alphas_cumprod[prev_timestep] + if prev_timestep >= 0 + else self.final_alpha_cumprod + ) + + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # 3. Get scalings for boundary conditions + c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep) + + # 4. Different Parameterization: + parameterization = self.config.prediction_type + + if parameterization == "epsilon": # noise-prediction + pred_x0 = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt() + + elif parameterization == "sample": # x-prediction + pred_x0 = model_output + + elif parameterization == "v_prediction": # v-prediction + pred_x0 = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output + + # 4. Denoise model output using boundary conditions + denoised = c_out * pred_x0 + c_skip * sample + + # 5. Sample z ~ N(0, I), For MultiStep Inference + # Noise is not used for one-step sampling. + if len(self.timesteps) > 1: + noise = torch.randn(model_output.shape).to(model_output.device) + prev_sample = ( + alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise + ) + else: + prev_sample = denoised + + if not return_dict: + return (prev_sample, denoised) + + return T2VTurboSchedulerOutput(prev_sample=prev_sample, denoised=denoised) + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + alphas_cumprod = self.alphas_cumprod.to( + device=original_samples.device, dtype=original_samples.dtype + ) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = ( + sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + ) + return noisy_samples + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity + def get_velocity( + self, + sample: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + alphas_cumprod = self.alphas_cumprod.to( + device=sample.device, dtype=sample.dtype + ) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/videogen_hub/pipelines/t2v_turbo/utils/__init__.py b/src/videogen_hub/pipelines/t2v_turbo/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/t2v_turbo/utils/common_utils.py b/src/videogen_hub/pipelines/t2v_turbo/utils/common_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d74a7cba1fab65dc9d8a6c24cdad73c892f55ba4 --- /dev/null +++ b/src/videogen_hub/pipelines/t2v_turbo/utils/common_utils.py @@ -0,0 +1,385 @@ +import ast +import gc +import torch + +from collections import OrderedDict + +from diffusers.models.attention_processor import AttnProcessor2_0 +from diffusers.models.attention import BasicTransformerBlock +import wandb + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def is_attn(name): + return "attn1" or "attn2" == name.split(".")[-1] + + +def set_processors(attentions): + for attn in attentions: + attn.set_processor(AttnProcessor2_0()) + + +def set_torch_2_attn(unet): + optim_count = 0 + + for name, module in unet.named_modules(): + if is_attn(name): + if isinstance(module, torch.nn.ModuleList): + for m in module: + if isinstance(m, BasicTransformerBlock): + set_processors([m.attn1, m.attn2]) + optim_count += 1 + if optim_count > 0: + print(f"{optim_count} Attention layers using Scaled Dot Product Attention.") + + +# From LatentConsistencyModel.get_guidance_scale_embedding +def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32): + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + timesteps (`torch.Tensor`): + generate embedding vectors at these timesteps + embedding_dim (`int`, *optional*, defaults to 512): + dimension of the embeddings to generate + dtype: + data type of the generated embeddings + + Returns: + `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)` + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError( + f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" + ) + return x[(...,) + (None,) * dims_to_append] + + +# From LCMScheduler.get_scalings_for_boundary_condition_discrete +def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0): + scaled_timestep = timestep_scaling * timestep + c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2) + c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5 + return c_skip, c_out + + +# Compare LCMScheduler.step, Step 4 +def get_predicted_original_sample( + model_output, timesteps, sample, prediction_type, alphas, sigmas +): + alphas = extract_into_tensor(alphas, timesteps, sample.shape) + sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) + if prediction_type == "epsilon": + pred_x_0 = (sample - sigmas * model_output) / alphas + elif prediction_type == "sample": + pred_x_0 = model_output + elif prediction_type == "v_prediction": + pred_x_0 = alphas * sample - sigmas * model_output + else: + raise ValueError( + f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`" + f" are supported." + ) + + return pred_x_0 + + +# Based on step 4 in DDIMScheduler.step +def get_predicted_noise( + model_output, timesteps, sample, prediction_type, alphas, sigmas +): + alphas = extract_into_tensor(alphas, timesteps, sample.shape) + sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) + if prediction_type == "epsilon": + pred_epsilon = model_output + elif prediction_type == "sample": + pred_epsilon = (sample - alphas * model_output) / sigmas + elif prediction_type == "v_prediction": + pred_epsilon = alphas * model_output + sigmas * sample + else: + raise ValueError( + f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`" + f" are supported." + ) + + return pred_epsilon + + +# From LatentConsistencyModel.get_guidance_scale_embedding +def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32): + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + timesteps (`torch.Tensor`): + generate embedding vectors at these timesteps + embedding_dim (`int`, *optional*, defaults to 512): + dimension of the embeddings to generate + dtype: + data type of the generated embeddings + + Returns: + `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)` + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError( + f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" + ) + return x[(...,) + (None,) * dims_to_append] + + +# From LCMScheduler.get_scalings_for_boundary_condition_discrete +def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0): + scaled_timestep = timestep_scaling * timestep + c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2) + c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5 + return c_skip, c_out + + +# Compare LCMScheduler.step, Step 4 +def get_predicted_original_sample( + model_output, timesteps, sample, prediction_type, alphas, sigmas +): + alphas = extract_into_tensor(alphas, timesteps, sample.shape) + sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) + if prediction_type == "epsilon": + pred_x_0 = (sample - sigmas * model_output) / alphas + elif prediction_type == "sample": + pred_x_0 = model_output + elif prediction_type == "v_prediction": + pred_x_0 = alphas * sample - sigmas * model_output + else: + raise ValueError( + f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`" + f" are supported." + ) + + return pred_x_0 + + +# Based on step 4 in DDIMScheduler.step +def get_predicted_noise( + model_output, timesteps, sample, prediction_type, alphas, sigmas +): + alphas = extract_into_tensor(alphas, timesteps, sample.shape) + sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) + if prediction_type == "epsilon": + pred_epsilon = model_output + elif prediction_type == "sample": + pred_epsilon = (sample - alphas * model_output) / sigmas + elif prediction_type == "v_prediction": + pred_epsilon = alphas * model_output + sigmas * sample + else: + raise ValueError( + f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`" + f" are supported." + ) + + return pred_epsilon + + +def param_optim(model, condition, extra_params=None, is_lora=False, negation=None): + extra_params = extra_params if len(extra_params.keys()) > 0 else None + return { + "model": model, + "condition": condition, + "extra_params": extra_params, + "is_lora": is_lora, + "negation": negation, + } + + +def create_optim_params(name="param", params=None, lr=5e-6, extra_params=None): + params = {"name": name, "params": params, "lr": lr} + if extra_params is not None: + for k, v in extra_params.items(): + params[k] = v + + return params + + +def create_optimizer_params(model_list, lr): + import itertools + + optimizer_params = [] + + for optim in model_list: + model, condition, extra_params, is_lora, negation = optim.values() + # Check if we are doing LoRA training. + if is_lora and condition and isinstance(model, list): + params = create_optim_params( + params=itertools.chain(*model), extra_params=extra_params + ) + optimizer_params.append(params) + continue + + if is_lora and condition and not isinstance(model, list): + for n, p in model.named_parameters(): + if "lora" in n: + params = create_optim_params(n, p, lr, extra_params) + optimizer_params.append(params) + continue + + # If this is true, we can train it. + if condition: + for n, p in model.named_parameters(): + should_negate = "lora" in n and not is_lora + if should_negate: + continue + + params = create_optim_params(n, p, lr, extra_params) + optimizer_params.append(params) + + return optimizer_params + + +def handle_trainable_modules( + model, trainable_modules=None, is_enabled=True, negation=None +): + acc = [] + unfrozen_params = 0 + + if trainable_modules is not None: + unlock_all = any([name == "all" for name in trainable_modules]) + if unlock_all: + model.requires_grad_(True) + unfrozen_params = len(list(model.parameters())) + else: + model.requires_grad_(False) + for name, param in model.named_parameters(): + for tm in trainable_modules: + if all([tm in name, name not in acc, "lora" not in name]): + param.requires_grad_(is_enabled) + acc.append(name) + unfrozen_params += 1 + + +def huber_loss(pred, target, huber_c=0.001): + loss = torch.sqrt((pred.float() - target.float()) ** 2 + huber_c**2) - huber_c + return loss.mean() + + +@torch.no_grad() +def update_ema(target_params, source_params, rate=0.99): + """ + Update target parameters to be closer to those of source parameters using + an exponential moving average. + + :param target_params: the target parameter sequence. + :param source_params: the source parameter sequence. + :param rate: the EMA rate (closer to 1 means slower). + """ + for targ, src in zip(target_params, source_params): + targ.detach().mul_(rate).add_(src, alpha=1 - rate) + + +def log_validation_video(pipeline, args, accelerator, save_fps): + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + validation_prompts = [ + "An astronaut riding a horse.", + "Darth vader surfing in waves.", + "Robot dancing in times square.", + "Clown fish swimming through the coral reef.", + "A child excitedly swings on a rusty swing set, laughter filling the air.", + "With the style of van gogh, A young couple dances under the moonlight by the lake.", + "A young woman with glasses is jogging in the park wearing a pink headband.", + "Impressionist style, a yellow rubber duck floating on the wave on the sunset", + ] + + video_logs = [] + + for _, prompt in enumerate(validation_prompts): + with torch.autocast("cuda"): + videos = pipeline( + prompt=prompt, + frames=args.n_frames, + num_inference_steps=4, + num_videos_per_prompt=2, + generator=generator, + ) + videos = (videos.clamp(-1.0, 1.0) + 1.0) / 2.0 + videos = (videos * 255).to(torch.uint8).permute(0, 2, 1, 3, 4).cpu().numpy() + video_logs.append({"validation_prompt": prompt, "videos": videos}) + + for tracker in accelerator.trackers: + if tracker.name == "wandb": + formatted_videos = [] + for log in video_logs: + videos = log["videos"] + validation_prompt = log["validation_prompt"] + for video in videos: + video = wandb.Video(video, caption=validation_prompt, fps=save_fps) + formatted_videos.append(video) + + tracker.log({f"validation": formatted_videos}) + + del pipeline + gc.collect() + + +def tuple_type(s): + if isinstance(s, tuple): + return s + value = ast.literal_eval(s) + if isinstance(value, tuple): + return value + raise TypeError("Argument must be a tuple") + + +def load_model_checkpoint(model, ckpt): + def load_checkpoint(model, ckpt, full_strict): + state_dict = torch.load(ckpt, map_location="cpu") + if "state_dict" in list(state_dict.keys()): + state_dict = state_dict["state_dict"] + model.load_state_dict(state_dict, strict=full_strict) + del state_dict + gc.collect() + return model + + load_checkpoint(model, ckpt, full_strict=True) + print(">>> model checkpoint loaded.") + return model diff --git a/src/videogen_hub/pipelines/t2v_turbo/utils/lora.py b/src/videogen_hub/pipelines/t2v_turbo/utils/lora.py new file mode 100644 index 0000000000000000000000000000000000000000..9d6212409ed4aedfe08801f8b82e94dfc8219ae3 --- /dev/null +++ b/src/videogen_hub/pipelines/t2v_turbo/utils/lora.py @@ -0,0 +1,1349 @@ +import json +import math +from itertools import groupby +import os +from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union + +import numpy as np +import PIL +import torch +import torch.nn as nn +import torch.nn.functional as F + +from safetensors.torch import safe_open +from safetensors.torch import save_file as safe_save + +safetensors_available = True + + +class LoraInjectedLinear(nn.Module): + def __init__( + self, in_features, out_features, bias=False, r=4, dropout_p=0.1, scale=1.0 + ): + super().__init__() + + if r > min(in_features, out_features): + # raise ValueError( + # f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}" + # ) + print( + f"LoRA rank {r} is too large. setting to: {min(in_features, out_features)}" + ) + r = min(in_features, out_features) + + self.r = r + self.linear = nn.Linear(in_features, out_features, bias) + self.lora_down = nn.Linear(in_features, r, bias=False) + self.dropout = nn.Dropout(dropout_p) + self.lora_up = nn.Linear(r, out_features, bias=False) + self.scale = scale + self.selector = nn.Identity() + + nn.init.normal_(self.lora_down.weight, std=1 / r) + nn.init.zeros_(self.lora_up.weight) + + def forward(self, input): + return ( + self.linear(input) + + self.dropout(self.lora_up(self.selector(self.lora_down(input)))) + * self.scale + ) + + def realize_as_lora(self): + return self.lora_up.weight.data * self.scale, self.lora_down.weight.data + + def set_selector_from_diag(self, diag: torch.Tensor): + # diag is a 1D tensor of size (r,) + assert diag.shape == (self.r,) + self.selector = nn.Linear(self.r, self.r, bias=False) + self.selector.weight.data = torch.diag(diag) + self.selector.weight.data = self.selector.weight.data.to( + self.lora_up.weight.device + ).to(self.lora_up.weight.dtype) + + +class LoraInjectedConv2d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups: int = 1, + bias: bool = True, + r: int = 4, + dropout_p: float = 0.1, + scale: float = 1.0, + ): + super().__init__() + if r > min(in_channels, out_channels): + print( + f"LoRA rank {r} is too large. setting to: {min(in_channels, out_channels)}" + ) + r = min(in_channels, out_channels) + + self.r = r + self.conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + + self.lora_down = nn.Conv2d( + in_channels=in_channels, + out_channels=r, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=False, + ) + self.dropout = nn.Dropout(dropout_p) + self.lora_up = nn.Conv2d( + in_channels=r, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + self.selector = nn.Identity() + self.scale = scale + + nn.init.normal_(self.lora_down.weight, std=1 / r) + nn.init.zeros_(self.lora_up.weight) + + def forward(self, input): + return ( + self.conv(input) + + self.dropout(self.lora_up(self.selector(self.lora_down(input)))) + * self.scale + ) + + def realize_as_lora(self): + return self.lora_up.weight.data * self.scale, self.lora_down.weight.data + + def set_selector_from_diag(self, diag: torch.Tensor): + # diag is a 1D tensor of size (r,) + assert diag.shape == (self.r,) + self.selector = nn.Conv2d( + in_channels=self.r, + out_channels=self.r, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + self.selector.weight.data = torch.diag(diag) + + # same device + dtype as lora_up + self.selector.weight.data = self.selector.weight.data.to( + self.lora_up.weight.device + ).to(self.lora_up.weight.dtype) + + +class LoraInjectedConv3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Tuple[int, int, int], # (3, 1, 1) + padding: Tuple[int, int, int], # (1, 0, 0) + bias: bool = False, + r: int = 4, + dropout_p: float = 0, + scale: float = 1.0, + ): + super().__init__() + if r > min(in_channels, out_channels): + print( + f"LoRA rank {r} is too large. setting to: {min(in_channels, out_channels)}" + ) + r = min(in_channels, out_channels) + + self.r = r + self.kernel_size = kernel_size + self.padding = padding + self.conv = nn.Conv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + padding=padding, + ) + + self.lora_down = nn.Conv3d( + in_channels=in_channels, + out_channels=r, + kernel_size=kernel_size, + bias=False, + padding=padding, + ) + self.dropout = nn.Dropout(dropout_p) + self.lora_up = nn.Conv3d( + in_channels=r, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + self.selector = nn.Identity() + self.scale = scale + + nn.init.normal_(self.lora_down.weight, std=1 / r) + nn.init.zeros_(self.lora_up.weight) + + def forward(self, input): + return ( + self.conv(input) + + self.dropout(self.lora_up(self.selector(self.lora_down(input)))) + * self.scale + ) + + def realize_as_lora(self): + return self.lora_up.weight.data * self.scale, self.lora_down.weight.data + + def set_selector_from_diag(self, diag: torch.Tensor): + # diag is a 1D tensor of size (r,) + assert diag.shape == (self.r,) + self.selector = nn.Conv3d( + in_channels=self.r, + out_channels=self.r, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + self.selector.weight.data = torch.diag(diag) + + # same device + dtype as lora_up + self.selector.weight.data = self.selector.weight.data.to( + self.lora_up.weight.device + ).to(self.lora_up.weight.dtype) + + +UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"} + +UNET_EXTENDED_TARGET_REPLACE = {"ResnetBlock2D", "CrossAttention", "Attention", "GEGLU"} + +TEXT_ENCODER_DEFAULT_TARGET_REPLACE = {"CLIPAttention"} + +TEXT_ENCODER_EXTENDED_TARGET_REPLACE = {"CLIPAttention"} + +DEFAULT_TARGET_REPLACE = UNET_DEFAULT_TARGET_REPLACE + +EMBED_FLAG = "" + + +def _find_children( + model, + search_class: List[Type[nn.Module]] = [nn.Linear], +): + """ + Find all modules of a certain class (or union of classes). + + Returns all matching modules, along with the parent of those moduless and the + names they are referenced by. + """ + # For each target find every linear_class module that isn't a child of a LoraInjectedLinear + for parent in model.modules(): + for name, module in parent.named_children(): + if any([isinstance(module, _class) for _class in search_class]): + yield parent, name, module + + +def _find_modules_v2( + model, + ancestor_class: Optional[Set[str]] = None, + search_class: List[Type[nn.Module]] = [nn.Linear], + exclude_children_of: Optional[List[Type[nn.Module]]] = [ + LoraInjectedLinear, + LoraInjectedConv2d, + LoraInjectedConv3d, + ], +): + """ + Find all modules of a certain class (or union of classes) that are direct or + indirect descendants of other modules of a certain class (or union of classes). + + Returns all matching modules, along with the parent of those moduless and the + names they are referenced by. + """ + + # Get the targets we should replace all linears under + if ancestor_class is not None: + ancestors = ( + module + for module in model.modules() + if module.__class__.__name__ in ancestor_class + ) + else: + # this, incase you want to naively iterate over all modules. + ancestors = [module for module in model.modules()] + + # For each target find every linear_class module that isn't a child of a LoraInjectedLinear + for ancestor in ancestors: + for fullname, module in ancestor.named_modules(): + if any([isinstance(module, _class) for _class in search_class]): + # Find the direct parent if this is a descendant, not a child, of target + *path, name = fullname.split(".") + parent = ancestor + while path: + parent = parent.get_submodule(path.pop(0)) + # Skip this linear if it's a child of a LoraInjectedLinear + if exclude_children_of and any( + [isinstance(parent, _class) for _class in exclude_children_of] + ): + continue + # Otherwise, yield it + yield parent, name, module + + +def _find_modules_old( + model, + ancestor_class: Set[str] = DEFAULT_TARGET_REPLACE, + search_class: List[Type[nn.Module]] = [nn.Linear], + exclude_children_of: Optional[List[Type[nn.Module]]] = [LoraInjectedLinear], +): + ret = [] + for _module in model.modules(): + if _module.__class__.__name__ in ancestor_class: + + for name, _child_module in _module.named_modules(): + if _child_module.__class__ in search_class: + ret.append((_module, name, _child_module)) + print(ret) + return ret + + +_find_modules = _find_modules_v2 + + +def inject_trainable_lora( + model: nn.Module, + target_replace_module: Set[str] = DEFAULT_TARGET_REPLACE, + r: int = 4, + loras=None, # path to lora .pt + verbose: bool = False, + dropout_p: float = 0.0, + scale: float = 1.0, +): + """ + inject lora into model, and returns lora parameter groups. + """ + + require_grad_params = [] + names = [] + + if loras != None: + loras = torch.load(loras) + + for _module, name, _child_module in _find_modules( + model, target_replace_module, search_class=[nn.Linear] + ): + weight = _child_module.weight + bias = _child_module.bias + if verbose: + print("LoRA Injection : injecting lora into ", name) + print("LoRA Injection : weight shape", weight.shape) + _tmp = LoraInjectedLinear( + _child_module.in_features, + _child_module.out_features, + _child_module.bias is not None, + r=r, + dropout_p=dropout_p, + scale=scale, + ) + _tmp.linear.weight = weight + if bias is not None: + _tmp.linear.bias = bias + + # switch the module + _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype) + _module._modules[name] = _tmp + + require_grad_params.append(_module._modules[name].lora_up.parameters()) + require_grad_params.append(_module._modules[name].lora_down.parameters()) + + if loras != None: + _module._modules[name].lora_up.weight = loras.pop(0) + _module._modules[name].lora_down.weight = loras.pop(0) + + _module._modules[name].lora_up.weight.requires_grad = True + _module._modules[name].lora_down.weight.requires_grad = True + names.append(name) + + return require_grad_params, names + + +def inject_trainable_lora_extended( + model: nn.Module, + target_replace_module: Set[str] = UNET_EXTENDED_TARGET_REPLACE, + r: int = 4, + loras=None, # path to lora .pt +): + """ + inject lora into model, and returns lora parameter groups. + """ + + require_grad_params = [] + names = [] + + if loras != None: + loras = torch.load(loras) + + for _module, name, _child_module in _find_modules( + model, target_replace_module, search_class=[nn.Linear, nn.Conv2d, nn.Conv3d] + ): + if _child_module.__class__ == nn.Linear: + weight = _child_module.weight + bias = _child_module.bias + _tmp = LoraInjectedLinear( + _child_module.in_features, + _child_module.out_features, + _child_module.bias is not None, + r=r, + ) + _tmp.linear.weight = weight + if bias is not None: + _tmp.linear.bias = bias + elif _child_module.__class__ == nn.Conv2d: + weight = _child_module.weight + bias = _child_module.bias + _tmp = LoraInjectedConv2d( + _child_module.in_channels, + _child_module.out_channels, + _child_module.kernel_size, + _child_module.stride, + _child_module.padding, + _child_module.dilation, + _child_module.groups, + _child_module.bias is not None, + r=r, + ) + + _tmp.conv.weight = weight + if bias is not None: + _tmp.conv.bias = bias + + elif _child_module.__class__ == nn.Conv3d: + weight = _child_module.weight + bias = _child_module.bias + _tmp = LoraInjectedConv3d( + _child_module.in_channels, + _child_module.out_channels, + bias=_child_module.bias is not None, + kernel_size=_child_module.kernel_size, + padding=_child_module.padding, + r=r, + ) + + _tmp.conv.weight = weight + if bias is not None: + _tmp.conv.bias = bias + else: + # ignore module which are not included in search_class + # For example: + # zeroscope_v2_576w model, which has and + continue + # switch the module + _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype) + if bias is not None: + _tmp.to(_child_module.bias.device).to(_child_module.bias.dtype) + + _module._modules[name] = _tmp + require_grad_params.append(_module._modules[name].lora_up.parameters()) + require_grad_params.append(_module._modules[name].lora_down.parameters()) + + if loras != None: + param = loras.pop(0) + if isinstance(param, torch.FloatTensor): + _module._modules[name].lora_up.weight = nn.Parameter(param) + else: + _module._modules[name].lora_up.weight = param + + param = loras.pop(0) + if isinstance(param, torch.FloatTensor): + _module._modules[name].lora_down.weight = nn.Parameter(param) + else: + _module._modules[name].lora_down.weight = param + + # _module._modules[name].lora_up.weight = loras.pop(0) + # _module._modules[name].lora_down.weight = loras.pop(0) + + _module._modules[name].lora_up.weight.requires_grad = True + _module._modules[name].lora_down.weight.requires_grad = True + names.append(name) + + return require_grad_params, names + + +def inject_inferable_lora( + model, + lora_path="", + unet_replace_modules=["UNet3DConditionModel"], + text_encoder_replace_modules=["CLIPEncoderLayer"], + is_extended=False, + r=16, +): + from transformers.models.clip import CLIPTextModel + from diffusers import UNet3DConditionModel + + def is_text_model(f): + return "text_encoder" in f and isinstance(model.text_encoder, CLIPTextModel) + + def is_unet(f): + return "unet" in f and model.unet.__class__.__name__ == "UNet3DConditionModel" + + if os.path.exists(lora_path): + try: + for f in os.listdir(lora_path): + if f.endswith(".pt"): + lora_file = os.path.join(lora_path, f) + + if is_text_model(f): + monkeypatch_or_replace_lora( + model.text_encoder, + torch.load(lora_file), + target_replace_module=text_encoder_replace_modules, + r=r, + ) + print("Successfully loaded Text Encoder LoRa.") + continue + + if is_unet(f): + monkeypatch_or_replace_lora_extended( + model.unet, + torch.load(lora_file), + target_replace_module=unet_replace_modules, + r=r, + ) + print("Successfully loaded UNET LoRa.") + continue + + print( + "Found a .pt file, but doesn't have the correct name format. (unet.pt, text_encoder.pt)" + ) + + except Exception as e: + print(e) + print("Couldn't inject LoRA's due to an error.") + + +def extract_lora_ups_down(model, target_replace_module=DEFAULT_TARGET_REPLACE): + + loras = [] + + for _m, _n, _child_module in _find_modules( + model, + target_replace_module, + search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d], + ): + loras.append((_child_module.lora_up, _child_module.lora_down)) + + if len(loras) == 0: + raise ValueError("No lora injected.") + + return loras + + +def extract_lora_as_tensor( + model, target_replace_module=DEFAULT_TARGET_REPLACE, as_fp16=True +): + + loras = [] + + for _m, _n, _child_module in _find_modules( + model, + target_replace_module, + search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d], + ): + up, down = _child_module.realize_as_lora() + if as_fp16: + up = up.to(torch.float16) + down = down.to(torch.float16) + + loras.append((up, down)) + + if len(loras) == 0: + raise ValueError("No lora injected.") + + return loras + + +def save_lora_weight( + model, + path="./lora.pt", + target_replace_module=DEFAULT_TARGET_REPLACE, +): + weights = [] + for _up, _down in extract_lora_ups_down( + model, target_replace_module=target_replace_module + ): + weights.append(_up.weight.to("cpu").to(torch.float32)) + weights.append(_down.weight.to("cpu").to(torch.float32)) + + torch.save(weights, path) + + +def save_lora_as_json(model, path="./lora.json"): + weights = [] + for _up, _down in extract_lora_ups_down(model): + weights.append(_up.weight.detach().cpu().numpy().tolist()) + weights.append(_down.weight.detach().cpu().numpy().tolist()) + + import json + + with open(path, "w") as f: + json.dump(weights, f) + + +def save_safeloras_with_embeds( + modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {}, + embeds: Dict[str, torch.Tensor] = {}, + outpath="./lora.safetensors", +): + """ + Saves the Lora from multiple modules in a single safetensor file. + + modelmap is a dictionary of { + "module name": (module, target_replace_module) + } + """ + weights = {} + metadata = {} + + for name, (model, target_replace_module) in modelmap.items(): + metadata[name] = json.dumps(list(target_replace_module)) + + for i, (_up, _down) in enumerate( + extract_lora_as_tensor(model, target_replace_module) + ): + rank = _down.shape[0] + + metadata[f"{name}:{i}:rank"] = str(rank) + weights[f"{name}:{i}:up"] = _up + weights[f"{name}:{i}:down"] = _down + + for token, tensor in embeds.items(): + metadata[token] = EMBED_FLAG + weights[token] = tensor + + print(f"Saving weights to {outpath}") + safe_save(weights, outpath, metadata) + + +def save_safeloras( + modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {}, + outpath="./lora.safetensors", +): + return save_safeloras_with_embeds(modelmap=modelmap, outpath=outpath) + + +def convert_loras_to_safeloras_with_embeds( + modelmap: Dict[str, Tuple[str, Set[str], int]] = {}, + embeds: Dict[str, torch.Tensor] = {}, + outpath="./lora.safetensors", +): + """ + Converts the Lora from multiple pytorch .pt files into a single safetensor file. + + modelmap is a dictionary of { + "module name": (pytorch_model_path, target_replace_module, rank) + } + """ + + weights = {} + metadata = {} + + for name, (path, target_replace_module, r) in modelmap.items(): + metadata[name] = json.dumps(list(target_replace_module)) + + lora = torch.load(path) + for i, weight in enumerate(lora): + is_up = i % 2 == 0 + i = i // 2 + + if is_up: + metadata[f"{name}:{i}:rank"] = str(r) + weights[f"{name}:{i}:up"] = weight + else: + weights[f"{name}:{i}:down"] = weight + + for token, tensor in embeds.items(): + metadata[token] = EMBED_FLAG + weights[token] = tensor + + print(f"Saving weights to {outpath}") + safe_save(weights, outpath, metadata) + + +def convert_loras_to_safeloras( + modelmap: Dict[str, Tuple[str, Set[str], int]] = {}, + outpath="./lora.safetensors", +): + convert_loras_to_safeloras_with_embeds(modelmap=modelmap, outpath=outpath) + + +def parse_safeloras( + safeloras, +) -> Dict[str, Tuple[List[nn.parameter.Parameter], List[int], List[str]]]: + """ + Converts a loaded safetensor file that contains a set of module Loras + into Parameters and other information + + Output is a dictionary of { + "module name": ( + [list of weights], + [list of ranks], + target_replacement_modules + ) + } + """ + loras = {} + metadata = safeloras.metadata() + + get_name = lambda k: k.split(":")[0] + + keys = list(safeloras.keys()) + keys.sort(key=get_name) + + for name, module_keys in groupby(keys, get_name): + info = metadata.get(name) + + if not info: + raise ValueError( + f"Tensor {name} has no metadata - is this a Lora safetensor?" + ) + + # Skip Textual Inversion embeds + if info == EMBED_FLAG: + continue + + # Handle Loras + # Extract the targets + target = json.loads(info) + + # Build the result lists - Python needs us to preallocate lists to insert into them + module_keys = list(module_keys) + ranks = [4] * (len(module_keys) // 2) + weights = [None] * len(module_keys) + + for key in module_keys: + # Split the model name and index out of the key + _, idx, direction = key.split(":") + idx = int(idx) + + # Add the rank + ranks[idx] = int(metadata[f"{name}:{idx}:rank"]) + + # Insert the weight into the list + idx = idx * 2 + (1 if direction == "down" else 0) + weights[idx] = nn.parameter.Parameter(safeloras.get_tensor(key)) + + loras[name] = (weights, ranks, target) + + return loras + + +def parse_safeloras_embeds( + safeloras, +) -> Dict[str, torch.Tensor]: + """ + Converts a loaded safetensor file that contains Textual Inversion embeds into + a dictionary of embed_token: Tensor + """ + embeds = {} + metadata = safeloras.metadata() + + for key in safeloras.keys(): + # Only handle Textual Inversion embeds + meta = metadata.get(key) + if not meta or meta != EMBED_FLAG: + continue + + embeds[key] = safeloras.get_tensor(key) + + return embeds + + +def load_safeloras(path, device="cpu"): + safeloras = safe_open(path, framework="pt", device=device) + return parse_safeloras(safeloras) + + +def load_safeloras_embeds(path, device="cpu"): + safeloras = safe_open(path, framework="pt", device=device) + return parse_safeloras_embeds(safeloras) + + +def load_safeloras_both(path, device="cpu"): + safeloras = safe_open(path, framework="pt", device=device) + return parse_safeloras(safeloras), parse_safeloras_embeds(safeloras) + + +def collapse_lora( + model, + replace_modules=UNET_EXTENDED_TARGET_REPLACE | TEXT_ENCODER_EXTENDED_TARGET_REPLACE, + alpha=1.0, +): + + search_class = [LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d] + for _module, name, _child_module in _find_modules( + model, replace_modules, search_class=search_class + ): + + if isinstance(_child_module, LoraInjectedLinear): + print("Collapsing Lin Lora in", name) + + _child_module.linear.weight = nn.Parameter( + _child_module.linear.weight.data + + alpha + * ( + _child_module.lora_up.weight.data + @ _child_module.lora_down.weight.data + ) + .type(_child_module.linear.weight.dtype) + .to(_child_module.linear.weight.device) + ) + + else: + print("Collapsing Conv Lora in", name) + _child_module.conv.weight = nn.Parameter( + _child_module.conv.weight.data + + alpha + * ( + _child_module.lora_up.weight.data.flatten(start_dim=1) + @ _child_module.lora_down.weight.data.flatten(start_dim=1) + ) + .reshape(_child_module.conv.weight.data.shape) + .type(_child_module.conv.weight.dtype) + .to(_child_module.conv.weight.device) + ) + + +def monkeypatch_or_replace_lora( + model, + loras, + target_replace_module=DEFAULT_TARGET_REPLACE, + r: Union[int, List[int]] = 4, +): + for _module, name, _child_module in _find_modules( + model, target_replace_module, search_class=[nn.Linear, LoraInjectedLinear] + ): + _source = ( + _child_module.linear + if isinstance(_child_module, LoraInjectedLinear) + else _child_module + ) + + weight = _source.weight + bias = _source.bias + _tmp = LoraInjectedLinear( + _source.in_features, + _source.out_features, + _source.bias is not None, + r=r.pop(0) if isinstance(r, list) else r, + ) + _tmp.linear.weight = weight + + if bias is not None: + _tmp.linear.bias = bias + + # switch the module + _module._modules[name] = _tmp + + up_weight = loras.pop(0) + down_weight = loras.pop(0) + + _module._modules[name].lora_up.weight = nn.Parameter( + up_weight.type(weight.dtype) + ) + _module._modules[name].lora_down.weight = nn.Parameter( + down_weight.type(weight.dtype) + ) + + _module._modules[name].to(weight.device) + + +def monkeypatch_or_replace_lora_extended( + model, + loras, + target_replace_module=DEFAULT_TARGET_REPLACE, + r: Union[int, List[int]] = 4, +): + for _module, name, _child_module in _find_modules( + model, + target_replace_module, + search_class=[ + nn.Linear, + nn.Conv2d, + nn.Conv3d, + LoraInjectedLinear, + LoraInjectedConv2d, + LoraInjectedConv3d, + ], + ): + + if (_child_module.__class__ == nn.Linear) or ( + _child_module.__class__ == LoraInjectedLinear + ): + if len(loras[0].shape) != 2: + continue + + _source = ( + _child_module.linear + if isinstance(_child_module, LoraInjectedLinear) + else _child_module + ) + + weight = _source.weight + bias = _source.bias + _tmp = LoraInjectedLinear( + _source.in_features, + _source.out_features, + _source.bias is not None, + r=r.pop(0) if isinstance(r, list) else r, + ) + _tmp.linear.weight = weight + + if bias is not None: + _tmp.linear.bias = bias + + elif (_child_module.__class__ == nn.Conv2d) or ( + _child_module.__class__ == LoraInjectedConv2d + ): + if len(loras[0].shape) != 4: + continue + _source = ( + _child_module.conv + if isinstance(_child_module, LoraInjectedConv2d) + else _child_module + ) + + weight = _source.weight + bias = _source.bias + _tmp = LoraInjectedConv2d( + _source.in_channels, + _source.out_channels, + _source.kernel_size, + _source.stride, + _source.padding, + _source.dilation, + _source.groups, + _source.bias is not None, + r=r.pop(0) if isinstance(r, list) else r, + ) + + _tmp.conv.weight = weight + + if bias is not None: + _tmp.conv.bias = bias + + elif _child_module.__class__ == nn.Conv3d or ( + _child_module.__class__ == LoraInjectedConv3d + ): + + if len(loras[0].shape) != 5: + continue + + _source = ( + _child_module.conv + if isinstance(_child_module, LoraInjectedConv3d) + else _child_module + ) + + weight = _source.weight + bias = _source.bias + _tmp = LoraInjectedConv3d( + _source.in_channels, + _source.out_channels, + bias=_source.bias is not None, + kernel_size=_source.kernel_size, + padding=_source.padding, + r=r.pop(0) if isinstance(r, list) else r, + ) + + _tmp.conv.weight = weight + + if bias is not None: + _tmp.conv.bias = bias + else: + # ignore module which are not included in search_class + # For example: + # zeroscope_v2_576w model, which has and + continue + # switch the module + _module._modules[name] = _tmp + + up_weight = loras.pop(0) + down_weight = loras.pop(0) + + _module._modules[name].lora_up.weight = nn.Parameter( + up_weight.type(weight.dtype) + ) + _module._modules[name].lora_down.weight = nn.Parameter( + down_weight.type(weight.dtype) + ) + + _module._modules[name].to(weight.device) + + +def monkeypatch_or_replace_safeloras(models, safeloras): + loras = parse_safeloras(safeloras) + + for name, (lora, ranks, target) in loras.items(): + model = getattr(models, name, None) + + if not model: + print(f"No model provided for {name}, contained in Lora") + continue + + monkeypatch_or_replace_lora_extended(model, lora, target, ranks) + + +def monkeypatch_remove_lora(model): + for _module, name, _child_module in _find_modules( + model, search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d] + ): + if isinstance(_child_module, LoraInjectedLinear): + _source = _child_module.linear + weight, bias = _source.weight, _source.bias + + _tmp = nn.Linear( + _source.in_features, _source.out_features, bias is not None + ) + + _tmp.weight = weight + if bias is not None: + _tmp.bias = bias + + else: + _source = _child_module.conv + weight, bias = _source.weight, _source.bias + + if isinstance(_source, nn.Conv2d): + _tmp = nn.Conv2d( + in_channels=_source.in_channels, + out_channels=_source.out_channels, + kernel_size=_source.kernel_size, + stride=_source.stride, + padding=_source.padding, + dilation=_source.dilation, + groups=_source.groups, + bias=bias is not None, + ) + + _tmp.weight = weight + if bias is not None: + _tmp.bias = bias + + if isinstance(_source, nn.Conv3d): + _tmp = nn.Conv3d( + _source.in_channels, + _source.out_channels, + bias=_source.bias is not None, + kernel_size=_source.kernel_size, + padding=_source.padding, + ) + + _tmp.weight = weight + if bias is not None: + _tmp.bias = bias + + _module._modules[name] = _tmp + + +def monkeypatch_add_lora( + model, + loras, + target_replace_module=DEFAULT_TARGET_REPLACE, + alpha: float = 1.0, + beta: float = 1.0, +): + for _module, name, _child_module in _find_modules( + model, target_replace_module, search_class=[LoraInjectedLinear] + ): + weight = _child_module.linear.weight + + up_weight = loras.pop(0) + down_weight = loras.pop(0) + + _module._modules[name].lora_up.weight = nn.Parameter( + up_weight.type(weight.dtype).to(weight.device) * alpha + + _module._modules[name].lora_up.weight.to(weight.device) * beta + ) + _module._modules[name].lora_down.weight = nn.Parameter( + down_weight.type(weight.dtype).to(weight.device) * alpha + + _module._modules[name].lora_down.weight.to(weight.device) * beta + ) + + _module._modules[name].to(weight.device) + + +def tune_lora_scale(model, alpha: float = 1.0): + for _module in model.modules(): + if _module.__class__.__name__ in [ + "LoraInjectedLinear", + "LoraInjectedConv2d", + "LoraInjectedConv3d", + ]: + _module.scale = alpha + + +def set_lora_diag(model, diag: torch.Tensor): + for _module in model.modules(): + if _module.__class__.__name__ in [ + "LoraInjectedLinear", + "LoraInjectedConv2d", + "LoraInjectedConv3d", + ]: + _module.set_selector_from_diag(diag) + + +def _text_lora_path(path: str) -> str: + assert path.endswith(".pt"), "Only .pt files are supported" + return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"]) + + +def _ti_lora_path(path: str) -> str: + assert path.endswith(".pt"), "Only .pt files are supported" + return ".".join(path.split(".")[:-1] + ["ti", "pt"]) + + +def apply_learned_embed_in_clip( + learned_embeds, + text_encoder, + tokenizer, + token: Optional[Union[str, List[str]]] = None, + idempotent=False, +): + if isinstance(token, str): + trained_tokens = [token] + elif isinstance(token, list): + assert len(learned_embeds.keys()) == len( + token + ), "The number of tokens and the number of embeds should be the same" + trained_tokens = token + else: + trained_tokens = list(learned_embeds.keys()) + + for token in trained_tokens: + print(token) + embeds = learned_embeds[token] + + # cast to dtype of text_encoder + dtype = text_encoder.get_input_embeddings().weight.dtype + num_added_tokens = tokenizer.add_tokens(token) + + i = 1 + if not idempotent: + while num_added_tokens == 0: + print(f"The tokenizer already contains the token {token}.") + token = f"{token[:-1]}-{i}>" + print(f"Attempting to add the token {token}.") + num_added_tokens = tokenizer.add_tokens(token) + i += 1 + elif num_added_tokens == 0 and idempotent: + print(f"The tokenizer already contains the token {token}.") + print(f"Replacing {token} embedding.") + + # resize the token embeddings + text_encoder.resize_token_embeddings(len(tokenizer)) + + # get the id for the token and assign the embeds + token_id = tokenizer.convert_tokens_to_ids(token) + text_encoder.get_input_embeddings().weight.data[token_id] = embeds + return token + + +def load_learned_embed_in_clip( + learned_embeds_path, + text_encoder, + tokenizer, + token: Optional[Union[str, List[str]]] = None, + idempotent=False, +): + learned_embeds = torch.load(learned_embeds_path) + apply_learned_embed_in_clip( + learned_embeds, text_encoder, tokenizer, token, idempotent + ) + + +def patch_pipe( + pipe, + maybe_unet_path, + token: Optional[str] = None, + r: int = 4, + patch_unet=True, + patch_text=True, + patch_ti=True, + idempotent_token=True, + unet_target_replace_module=DEFAULT_TARGET_REPLACE, + text_target_replace_module=TEXT_ENCODER_DEFAULT_TARGET_REPLACE, +): + if maybe_unet_path.endswith(".pt"): + # torch format + + if maybe_unet_path.endswith(".ti.pt"): + unet_path = maybe_unet_path[:-6] + ".pt" + elif maybe_unet_path.endswith(".text_encoder.pt"): + unet_path = maybe_unet_path[:-16] + ".pt" + else: + unet_path = maybe_unet_path + + ti_path = _ti_lora_path(unet_path) + text_path = _text_lora_path(unet_path) + + if patch_unet: + print("LoRA : Patching Unet") + monkeypatch_or_replace_lora( + pipe.unet, + torch.load(unet_path), + r=r, + target_replace_module=unet_target_replace_module, + ) + + if patch_text: + print("LoRA : Patching text encoder") + monkeypatch_or_replace_lora( + pipe.text_encoder, + torch.load(text_path), + target_replace_module=text_target_replace_module, + r=r, + ) + if patch_ti: + print("LoRA : Patching token input") + token = load_learned_embed_in_clip( + ti_path, + pipe.text_encoder, + pipe.tokenizer, + token=token, + idempotent=idempotent_token, + ) + + elif maybe_unet_path.endswith(".safetensors"): + safeloras = safe_open(maybe_unet_path, framework="pt", device="cpu") + monkeypatch_or_replace_safeloras(pipe, safeloras) + tok_dict = parse_safeloras_embeds(safeloras) + if patch_ti: + apply_learned_embed_in_clip( + tok_dict, + pipe.text_encoder, + pipe.tokenizer, + token=token, + idempotent=idempotent_token, + ) + return tok_dict + + +def train_patch_pipe(pipe, patch_unet, patch_text): + if patch_unet: + print("LoRA : Patching Unet") + collapse_lora(pipe.unet) + monkeypatch_remove_lora(pipe.unet) + + if patch_text: + print("LoRA : Patching text encoder") + + collapse_lora(pipe.text_encoder) + monkeypatch_remove_lora(pipe.text_encoder) + + +@torch.no_grad() +def inspect_lora(model): + moved = {} + + for name, _module in model.named_modules(): + if _module.__class__.__name__ in [ + "LoraInjectedLinear", + "LoraInjectedConv2d", + "LoraInjectedConv3d", + ]: + ups = _module.lora_up.weight.data.clone() + downs = _module.lora_down.weight.data.clone() + + wght: torch.Tensor = ups.flatten(1) @ downs.flatten(1) + + dist = wght.flatten().abs().mean().item() + if name in moved: + moved[name].append(dist) + else: + moved[name] = [dist] + + return moved + + +def save_all( + unet, + text_encoder, + save_path, + placeholder_token_ids=None, + placeholder_tokens=None, + save_lora=True, + save_ti=True, + target_replace_module_text=TEXT_ENCODER_DEFAULT_TARGET_REPLACE, + target_replace_module_unet=DEFAULT_TARGET_REPLACE, + safe_form=True, +): + if not safe_form: + # save ti + if save_ti: + ti_path = _ti_lora_path(save_path) + learned_embeds_dict = {} + for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids): + learned_embeds = text_encoder.get_input_embeddings().weight[tok_id] + print( + f"Current Learned Embeddings for {tok}:, id {tok_id} ", + learned_embeds[:4], + ) + learned_embeds_dict[tok] = learned_embeds.detach().cpu() + + torch.save(learned_embeds_dict, ti_path) + print("Ti saved to ", ti_path) + + # save text encoder + if save_lora: + save_lora_weight( + unet, save_path, target_replace_module=target_replace_module_unet + ) + print("Unet saved to ", save_path) + + save_lora_weight( + text_encoder, + _text_lora_path(save_path), + target_replace_module=target_replace_module_text, + ) + print("Text Encoder saved to ", _text_lora_path(save_path)) + + else: + assert save_path.endswith( + ".safetensors" + ), f"Save path : {save_path} should end with .safetensors" + + loras = {} + embeds = {} + + if save_lora: + + loras["unet"] = (unet, target_replace_module_unet) + loras["text_encoder"] = (text_encoder, target_replace_module_text) + + if save_ti: + for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids): + learned_embeds = text_encoder.get_input_embeddings().weight[tok_id] + print( + f"Current Learned Embeddings for {tok}:, id {tok_id} ", + learned_embeds[:4], + ) + embeds[tok] = learned_embeds.detach().cpu() + + save_safeloras_with_embeds(loras, embeds, save_path) diff --git a/src/videogen_hub/pipelines/t2v_turbo/utils/lora_handler.py b/src/videogen_hub/pipelines/t2v_turbo/utils/lora_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..ed2deb220e7781ff060f112a06def81a2e4a866a --- /dev/null +++ b/src/videogen_hub/pipelines/t2v_turbo/utils/lora_handler.py @@ -0,0 +1,153 @@ +import torch +from types import SimpleNamespace + +from .lora import ( + extract_lora_ups_down, + inject_trainable_lora_extended, + monkeypatch_or_replace_lora_extended, +) + +CLONE_OF_SIMO_KEYS = ["model", "loras", "target_replace_module", "r"] + +lora_versions = dict(stable_lora="stable_lora", cloneofsimo="cloneofsimo") + +lora_func_types = dict(loader="loader", injector="injector") + +lora_args = dict( + model=None, + loras=None, + target_replace_module=[], + target_module=[], + r=4, + search_class=[torch.nn.Linear], + dropout=0, + lora_bias="none", +) + +LoraVersions = SimpleNamespace(**lora_versions) +LoraFuncTypes = SimpleNamespace(**lora_func_types) + +LORA_VERSIONS = [LoraVersions.stable_lora, LoraVersions.cloneofsimo] +LORA_FUNC_TYPES = [LoraFuncTypes.loader, LoraFuncTypes.injector] + + +def filter_dict(_dict, keys=[]): + if len(keys) == 0: + assert "Keys cannot empty for filtering return dict." + + for k in keys: + if k not in lora_args.keys(): + assert f"{k} does not exist in available LoRA arguments" + + return {k: v for k, v in _dict.items() if k in keys} + + +class LoraHandler(object): + def __init__( + self, + version: str = LoraVersions.cloneofsimo, + use_unet_lora: bool = False, + use_text_lora: bool = False, + save_for_webui: bool = False, + only_for_webui: bool = False, + lora_bias: str = "none", + unet_replace_modules: list = ["UNet3DConditionModel"], + ): + self.version = version + assert self.is_cloneofsimo_lora() + + self.lora_loader = self.get_lora_func(func_type=LoraFuncTypes.loader) + self.lora_injector = self.get_lora_func(func_type=LoraFuncTypes.injector) + self.lora_bias = lora_bias + self.use_unet_lora = use_unet_lora + self.use_text_lora = use_text_lora + self.save_for_webui = save_for_webui + self.only_for_webui = only_for_webui + self.unet_replace_modules = unet_replace_modules + self.use_lora = any([use_text_lora, use_unet_lora]) + + if self.use_lora: + print(f"Using LoRA Version: {self.version}") + + def is_cloneofsimo_lora(self): + return self.version == LoraVersions.cloneofsimo + + def get_lora_func(self, func_type: str = LoraFuncTypes.loader): + if func_type == LoraFuncTypes.loader: + return monkeypatch_or_replace_lora_extended + + if func_type == LoraFuncTypes.injector: + return inject_trainable_lora_extended + + assert "LoRA Version does not exist." + + def get_lora_func_args( + self, lora_path, use_lora, model, replace_modules, r, dropout, lora_bias + ): + return_dict = lora_args.copy() + + return_dict = filter_dict(return_dict, keys=CLONE_OF_SIMO_KEYS) + return_dict.update( + { + "model": model, + "loras": lora_path, + "target_replace_module": replace_modules, + "r": r, + } + ) + + return return_dict + + def do_lora_injection( + self, + model, + replace_modules, + bias="none", + dropout=0, + r=4, + lora_loader_args=None, + ): + REPLACE_MODULES = replace_modules + + params = None + negation = None + + injector_args = lora_loader_args + + params, negation = self.lora_injector(**injector_args) + for _up, _down in extract_lora_ups_down( + model, target_replace_module=REPLACE_MODULES + ): + + if all(x is not None for x in [_up, _down]): + print( + f"Lora successfully injected into {model.__class__.__name__}." + ) + + break + + return params, negation + + def add_lora_to_model( + self, use_lora, model, replace_modules, dropout=0.0, lora_path=None, r=16 + ): + + params = None + negation = None + + lora_loader_args = self.get_lora_func_args( + lora_path, use_lora, model, replace_modules, r, dropout, self.lora_bias + ) + + if use_lora: + params, negation = self.do_lora_injection( + model, + replace_modules, + bias=self.lora_bias, + lora_loader_args=lora_loader_args, + dropout=dropout, + r=r, + ) + + params = model if params is None else params + return params, negation diff --git a/src/videogen_hub/pipelines/t2v_turbo/utils/utils.py b/src/videogen_hub/pipelines/t2v_turbo/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4024732724fd664cf6618cbe578ecca51dbb189c --- /dev/null +++ b/src/videogen_hub/pipelines/t2v_turbo/utils/utils.py @@ -0,0 +1,141 @@ +import importlib +import os +import numpy as np +import cv2 +import torch +import torch.distributed as dist +import torchvision +import sys + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") + return total_params + + +def check_istarget(name, para_list): + """ + name: full name of source para + para_list: partial name of target para + """ + istarget = False + for para in para_list: + if para in name: + return True + return istarget + + +def instantiate_from_config(config): + if not "target" in config: + if config == "__is_first_stage__": + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + +def get_obj_from_str(string, reload=False): + # Get the current directory + current_dir = os.path.abspath(os.path.dirname(__file__)) + + # Move up to the `t2v_turbo` directory + while os.path.basename(current_dir) not in ['t2v_turbo', 'videogen_hub']: + current_dir = os.path.dirname(current_dir) + if current_dir == os.path.dirname(current_dir): # Reached the root directory + raise FileNotFoundError("Couldn't find 't2v_turbo' or 'videogen_hub' in the path hierarchy") + + # Construct the paths for `pipelines` and `t2v_turbo` + paths_to_add = [] + if os.path.basename(current_dir) == 't2v_turbo': + paths_to_add.append(current_dir) + paths_to_add.append(os.path.join(current_dir, '..')) # Up one level to the 'pipelines' directory + elif os.path.basename(current_dir) == 'videogen_hub': + paths_to_add.append(os.path.join(current_dir, 'pipelines')) + paths_to_add.append(os.path.join(current_dir, 'pipelines', 't2v_turbo')) + + # Normalize paths to avoid issues with '..' + paths_to_add = [os.path.normpath(path) for path in paths_to_add] + + print("+++++> string", string) + print("+++++> base_dir", current_dir) + print("+++++> paths_to_add", paths_to_add) + + # Add the paths to sys.path if they're not already there + for path in paths_to_add: + if path not in sys.path: + sys.path.insert(0, path) + + # Extract the module and class names + module, cls = string.rsplit(".", 1) + + # Import and optionally reload the module + module_imp = importlib.import_module(module) + if reload: + importlib.reload(module_imp) + + # Get the class from the module + return getattr(module_imp, cls) + +""" +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) +""" + +def load_npz_from_dir(data_dir): + data = [ + np.load(os.path.join(data_dir, data_name))["arr_0"] + for data_name in os.listdir(data_dir) + ] + data = np.concatenate(data, axis=0) + return data + + +def load_npz_from_paths(data_paths): + data = [np.load(data_path)["arr_0"] for data_path in data_paths] + data = np.concatenate(data, axis=0) + return data + + +def resize_numpy_image(image, max_resolution=512 * 512, resize_short_edge=None): + h, w = image.shape[:2] + if resize_short_edge is not None: + k = resize_short_edge / min(h, w) + else: + k = max_resolution / (h * w) + k = k**0.5 + h = int(np.round(h * k / 64)) * 64 + w = int(np.round(w * k / 64)) * 64 + image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4) + return image + + +def setup_dist(args): + if dist.is_initialized(): + return + torch.cuda.set_device(args.local_rank) + torch.distributed.init_process_group("nccl", init_method="env://") + + +def save_videos(batch_tensors, savedir, filenames, fps=16): + # b,samples,c,t,h,w + n_samples = batch_tensors.shape[1] + for idx, vid_tensor in enumerate(batch_tensors): + video = vid_tensor.detach().cpu() + video = torch.clamp(video.float(), -1.0, 1.0) + video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w + frame_grids = [ + torchvision.utils.make_grid(framesheet, nrow=int(n_samples)) + for framesheet in video + ] # [3, 1*h, n*w] + grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [t, 3, n*h, w] + grid = (grid + 1.0) / 2.0 + grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1) + savepath = os.path.join(savedir, f"{filenames[idx]}.mp4") + torchvision.io.write_video( + savepath, grid, fps=fps, video_codec="h264", options={"crf": "10"} + ) \ No newline at end of file diff --git a/src/videogen_hub/pipelines/videocrafter/__init__.py b/src/videogen_hub/pipelines/videocrafter/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..67768467a77a292302ceae48502ee3419e761718 --- /dev/null +++ b/src/videogen_hub/pipelines/videocrafter/__init__.py @@ -0,0 +1,7 @@ +import sys +from pathlib import Path +import os + +cur_path = str(Path(__file__).parent.absolute()) +sys.path.insert(0, cur_path) +sys.path.insert(0, os.path.join(cur_path, 'lvdm')) diff --git a/src/videogen_hub/pipelines/videocrafter/funcs.py b/src/videogen_hub/pipelines/videocrafter/funcs.py new file mode 100644 index 0000000000000000000000000000000000000000..52d7e1f58facedef431f8473c3dc1dd15224eb67 --- /dev/null +++ b/src/videogen_hub/pipelines/videocrafter/funcs.py @@ -0,0 +1,200 @@ +import os, sys, glob +import numpy as np +from collections import OrderedDict +from decord import VideoReader, cpu +import cv2 + +import torch +import torchvision + +sys.path.insert(1, os.path.join(sys.path[0], '..', '..')) +from .lvdm.models.samplers.ddim import DDIMSampler + + +def batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=50, ddim_eta=1.0, + cfg_scale=1.0, temporal_cfg_scale=None, **kwargs): + ddim_sampler = DDIMSampler(model) + uncond_type = model.uncond_type + batch_size = noise_shape[0] + + ## construct unconditional guidance + if cfg_scale != 1.0: + if uncond_type == "empty_seq": + prompts = batch_size * [""] + # prompts = N * T * [""] ## if is_imgbatch=True + uc_emb = model.get_learned_conditioning(prompts) + elif uncond_type == "zero_embed": + c_emb = cond["c_crossattn"][0] if isinstance(cond, dict) else cond + uc_emb = torch.zeros_like(c_emb) + + ## process image embedding token + if hasattr(model, 'embedder'): + uc_img = torch.zeros(noise_shape[0], 3, 224, 224).to(model.device) + ## img: b c h w >> b l c + uc_img = model.get_image_embeds(uc_img) + uc_emb = torch.cat([uc_emb, uc_img], dim=1) + + if isinstance(cond, dict): + uc = {key: cond[key] for key in cond.keys()} + uc.update({'c_crossattn': [uc_emb]}) + else: + uc = uc_emb + else: + uc = None + + x_T = None + batch_variants = [] + # batch_variants1, batch_variants2 = [], [] + for _ in range(n_samples): + if ddim_sampler is not None: + kwargs.update({"clean_cond": True}) + samples, _ = ddim_sampler.sample(S=ddim_steps, + conditioning=cond, + batch_size=noise_shape[0], + shape=noise_shape[1:], + verbose=False, + unconditional_guidance_scale=cfg_scale, + unconditional_conditioning=uc, + eta=ddim_eta, + temporal_length=noise_shape[2], + conditional_guidance_scale_temporal=temporal_cfg_scale, + x_T=x_T, + **kwargs + ) + ## reconstruct from latent to pixel space + batch_images = model.decode_first_stage_2DAE(samples) + batch_variants.append(batch_images) + ## batch, , c, t, h, w + batch_variants = torch.stack(batch_variants, dim=1) + return batch_variants + + +def get_filelist(data_dir, ext='*'): + file_list = glob.glob(os.path.join(data_dir, '*.%s' % ext)) + file_list.sort() + return file_list + + +def get_dirlist(path): + list = [] + if (os.path.exists(path)): + files = os.listdir(path) + for file in files: + m = os.path.join(path, file) + if (os.path.isdir(m)): + list.append(m) + list.sort() + return list + + +def load_model_checkpoint(model, ckpt): + def load_checkpoint(model, ckpt, full_strict): + state_dict = torch.load(ckpt, map_location="cpu") + try: + ## deepspeed + new_pl_sd = OrderedDict() + for key in state_dict['module'].keys(): + new_pl_sd[key[16:]] = state_dict['module'][key] + model.load_state_dict(new_pl_sd, strict=full_strict) + except: + if "state_dict" in list(state_dict.keys()): + state_dict = state_dict["state_dict"] + model.load_state_dict(state_dict, strict=full_strict) + return model + + load_checkpoint(model, ckpt, full_strict=True) + print('>>> model checkpoint loaded.') + return model + + +def load_prompts(prompt_file): + f = open(prompt_file, 'r') + prompt_list = [] + for idx, line in enumerate(f.readlines()): + l = line.strip() + if len(l) != 0: + prompt_list.append(l) + f.close() + return prompt_list + + +def load_video_batch(filepath_list, frame_stride, video_size=(256, 256), video_frames=16): + ''' + Notice about some special cases: + 1. video_frames=-1 means to take all the frames (with fs=1) + 2. when the total video frames is less than required, padding strategy will be used (repreated last frame) + ''' + fps_list = [] + batch_tensor = [] + assert frame_stride > 0, "valid frame stride should be a positive interge!" + for filepath in filepath_list: + padding_num = 0 + vidreader = VideoReader(filepath, ctx=cpu(0), width=video_size[1], height=video_size[0]) + fps = vidreader.get_avg_fps() + total_frames = len(vidreader) + max_valid_frames = (total_frames - 1) // frame_stride + 1 + if video_frames < 0: + ## all frames are collected: fs=1 is a must + required_frames = total_frames + frame_stride = 1 + else: + required_frames = video_frames + query_frames = min(required_frames, max_valid_frames) + frame_indices = [frame_stride * i for i in range(query_frames)] + + ## [t,h,w,c] -> [c,t,h,w] + frames = vidreader.get_batch(frame_indices) + frame_tensor = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float() + frame_tensor = (frame_tensor / 255. - 0.5) * 2 + if max_valid_frames < required_frames: + padding_num = required_frames - max_valid_frames + frame_tensor = torch.cat([frame_tensor, *([frame_tensor[:, -1:, :, :]] * padding_num)], dim=1) + print(f'{os.path.split(filepath)[1]} is not long enough: {padding_num} frames padded.') + batch_tensor.append(frame_tensor) + sample_fps = int(fps / frame_stride) + fps_list.append(sample_fps) + + return torch.stack(batch_tensor, dim=0) + + +from PIL import Image + + +def load_image_batch(filepath_list, image_size=(256, 256)): + batch_tensor = [] + for filepath in filepath_list: + _, filename = os.path.split(filepath) + _, ext = os.path.splitext(filename) + if ext == '.mp4': + vidreader = VideoReader(filepath, ctx=cpu(0), width=image_size[1], height=image_size[0]) + frame = vidreader.get_batch([0]) + img_tensor = torch.tensor(frame.asnumpy()).squeeze(0).permute(2, 0, 1).float() + elif ext == '.png' or ext == '.jpg': + img = Image.open(filepath).convert("RGB") + rgb_img = np.array(img, np.float32) + # bgr_img = cv2.imread(filepath, cv2.IMREAD_COLOR) + # bgr_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB) + rgb_img = cv2.resize(rgb_img, (image_size[1], image_size[0]), interpolation=cv2.INTER_LINEAR) + img_tensor = torch.from_numpy(rgb_img).permute(2, 0, 1).float() + else: + print(f'ERROR: <{ext}> image loading only support format: [mp4], [png], [jpg]') + raise NotImplementedError + img_tensor = (img_tensor / 255. - 0.5) * 2 + batch_tensor.append(img_tensor) + return torch.stack(batch_tensor, dim=0) + + +def save_videos(batch_tensors, savedir, filenames, fps=10): + # b,samples,c,t,h,w + n_samples = batch_tensors.shape[1] + for idx, vid_tensor in enumerate(batch_tensors): + video = vid_tensor.detach().cpu() + video = torch.clamp(video.float(), -1., 1.) + video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w + frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(n_samples)) for framesheet in + video] # [3, 1*h, n*w] + grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [t, 3, n*h, w] + grid = (grid + 1.0) / 2.0 + grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1) + savepath = os.path.join(savedir, f"{filenames[idx]}.mp4") + torchvision.io.write_video(savepath, grid, fps=fps, video_codec='h264', options={'crf': '10'}) diff --git a/src/videogen_hub/pipelines/videocrafter/inference.py b/src/videogen_hub/pipelines/videocrafter/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..efd49282c0e8d2df7eaa791e22e09fd3f034fe4c --- /dev/null +++ b/src/videogen_hub/pipelines/videocrafter/inference.py @@ -0,0 +1,160 @@ +import argparse, os, sys, glob, yaml, math, random +import datetime, time +import numpy as np +from omegaconf import OmegaConf +from collections import OrderedDict +from tqdm import trange, tqdm +from einops import repeat +from einops import rearrange, repeat +from functools import partial +import torch +from pytorch_lightning import seed_everything + +from .funcs import load_model_checkpoint, load_image_batch, get_filelist, save_videos +from .funcs import batch_ddim_sampling +from .utils import instantiate_from_config + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("--seed", type=int, default=20230211, help="seed for seed_everything") + parser.add_argument("--mode", default="base", type=str, help="which kind of inference mode: {'base', 'i2v'}") + parser.add_argument("--ckpt_path", type=str, default=None, help="checkpoint path") + parser.add_argument("--config", type=str, help="config (yaml) path") + parser.add_argument("--savefps", type=str, default=10, help="video fps to generate") + parser.add_argument("--n_samples", type=int, default=1, help="num of samples per prompt", ) + parser.add_argument("--ddim_steps", type=int, default=50, help="steps of ddim if positive, otherwise use DDPM", ) + parser.add_argument("--ddim_eta", type=float, default=1.0, + help="eta for ddim sampling (0.0 yields deterministic sampling)", ) + parser.add_argument("--bs", type=int, default=1, help="batch size for inference") + parser.add_argument("--height", type=int, default=512, help="image height, in pixel space") + parser.add_argument("--width", type=int, default=512, help="image width, in pixel space") + parser.add_argument("--frames", type=int, default=-1, help="frames num to inference") + parser.add_argument("--fps", type=int, default=24) + parser.add_argument("--unconditional_guidance_scale", type=float, default=1.0, + help="prompt classifier-free guidance") + parser.add_argument("--unconditional_guidance_scale_temporal", type=float, default=None, + help="temporal consistency guidance") + ## for conditional i2v only + # parser.add_argument("--cond_input", type=str, default=None, help="data dir of conditional input") + return parser + + +class VideoCrafterPipeline(): + def __init__(self, arg_list, device, rank: int = 0, gpu_num: int = 1): + """ + Initialize the pipeline of videocrafter. + It is always on one GPU. + Args: + arg_list: The parameters needed for the model. + device: + rank: + gpu_num: + """ + parser = get_parser() + self.args = parser.parse_args(args=arg_list) + + self.gpu_no, self.gpu_num = rank, gpu_num + _dict = {'model': {'target': 'lvdm.models.ddpm3d.LatentDiffusion', 'params': {'linear_start': 0.00085, 'linear_end': 0.012, 'num_timesteps_cond': 1, 'timesteps': 1000, 'first_stage_key': 'video', 'cond_stage_key': 'caption', 'cond_stage_trainable': False, 'conditioning_key': 'crossattn', 'image_size': [40, 64], 'channels': 4, 'scale_by_std': False, 'scale_factor': 0.18215, 'use_ema': False, 'uncond_type': 'empty_seq', 'use_scale': True, 'scale_b': 0.7, 'unet_config': {'target': 'lvdm.modules.networks.openaimodel3d.UNetModel', 'params': {'in_channels': 4, 'out_channels': 4, 'model_channels': 320, 'attention_resolutions': [4, 2, 1], 'num_res_blocks': 2, 'channel_mult': [1, 2, 4, 4], 'num_head_channels': 64, 'transformer_depth': 1, 'context_dim': 1024, 'use_linear': True, 'use_checkpoint': True, 'temporal_conv': True, 'temporal_attention': True, 'temporal_selfatt_only': True, 'use_relative_position': False, 'use_causal_attention': False, 'temporal_length': 16, 'addition_attention': True, 'fps_cond': True}}, 'first_stage_config': {'target': 'lvdm.models.autoencoder.AutoencoderKL', 'params': {'embed_dim': 4, 'monitor': 'val/rec_loss', 'ddconfig': {'double_z': True, 'z_channels': 4, 'resolution': 512, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}, 'lossconfig': {'target': 'torch.nn.Identity'}}}, 'cond_stage_config': {'target': 'lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder', 'params': {'freeze': True, 'layer': 'penultimate'}}}}} + + config = OmegaConf.create(_dict) + #config = OmegaConf.load(self.args.config) + + # data_config = config.pop("data", OmegaConf.create()) + model_config = config.pop("model", OmegaConf.create()) + model = instantiate_from_config(model_config) + model = model.cuda(self.gpu_no) + print("About to load model") + assert os.path.exists(self.args.ckpt_path), f"Error: checkpoint [{self.args.ckpt_path}] Not Found!" + self.model = load_model_checkpoint(model, self.args.ckpt_path) + self.model.eval() + + def run_inference(self, prompt, video_length, height, width, **kwargs): + """ + https://github.com/AILab-CVC/VideoCrafter + Generate video from the provided text prompt. + Args: + prompt: The provided text prompt. + video_length: The length (num of frames) of the generated video. + height: The height of the video frame. + width: The width of the video frame. + **kwargs: + + Returns: + The generated video represented as tensor with shape (1, 1, channels, height, width, num of frames) + + """ + ## step 1: model config + ## ----------------------------------------------------------------- + ## sample shape + assert (self.args.height % 16 == 0) and ( + self.args.width % 16 == 0), "Error: image size [h,w] should be multiples of 16!" + ## latent noise shape + h, w = height // 8, width // 8 + frames = video_length + channels = self.model.channels + + ## step 2: load data + ## ----------------------------------------------------------------- + prompt_list = [prompt] + num_samples = len(prompt_list) + # filename_list = [f"{id + 1:04d}" for id in range(num_samples)] + + gpu_num = self.gpu_num + gpu_no = self.gpu_no + samples_split = num_samples // gpu_num + residual_tail = num_samples % gpu_num + print(f'[rank:{gpu_no}] {samples_split}/{num_samples} samples loaded.') + indices = list(range(samples_split * gpu_no, samples_split * (gpu_no + 1))) + if gpu_no == 0 and residual_tail != 0: + indices = indices + list(range(num_samples - residual_tail, num_samples)) + prompt_list_rank = [prompt_list[i] for i in indices] + + # # conditional input + # if self.args.mode == "i2v": + # ## each video or frames dir per prompt + # cond_inputs = get_filelist(self.args.cond_input, ext='[mpj][pn][4gj]') # '[mpj][pn][4gj]' + # assert len( + # cond_inputs) == num_samples, f"Error: conditional input ({len(cond_inputs)}) NOT match prompt ({num_samples})!" + # filename_list = [f"{os.path.split(cond_inputs[id])[-1][:-4]}" for id in range(num_samples)] + # cond_inputs_rank = [cond_inputs[i] for i in indices] + + # filename_list_rank = [filename_list[i] for i in indices] + + ## step 3: run over samples + ## ----------------------------------------------------------------- + # start = time.time() + n_rounds = len(prompt_list_rank) // self.args.bs + n_rounds = n_rounds + 1 if len(prompt_list_rank) % self.args.bs != 0 else n_rounds + for idx in range(0, n_rounds): + print(f'[rank:{gpu_no}] batch-{idx + 1} ({self.args.bs})x{self.args.n_samples} ...') + idx_s = idx * self.args.bs + idx_e = min(idx_s + self.args.bs, len(prompt_list_rank)) + batch_size = idx_e - idx_s + # filenames = filename_list_rank[idx_s:idx_e] + noise_shape = [batch_size, channels, frames, h, w] + fps = torch.tensor([self.args.fps] * batch_size).to(self.model.device).long() + + prompts = prompt_list_rank[idx_s:idx_e] + if isinstance(prompts, str): + prompts = [prompts] + # prompts = batch_size * [""] + text_emb = self.model.get_learned_conditioning(prompts) + + if self.args.mode == 'base': + cond = {"c_crossattn": [text_emb], "fps": fps} + # elif self.args.mode == 'i2v': + # # cond_images = torch.zeros(noise_shape[0],3,224,224).to(model.device) + # cond_images = load_image_batch(cond_inputs_rank[idx_s:idx_e], (self.args.height, self.args.width)) + # cond_images = cond_images.to(self.model.device) + # img_emb = self.model.get_image_embeds(cond_images) + # imtext_cond = torch.cat([text_emb, img_emb], dim=1) + # cond = {"c_crossattn": [imtext_cond], "fps": fps} + else: + raise NotImplementedError + + ## inference + batch_samples = batch_ddim_sampling(self.model, cond, noise_shape, self.args.n_samples, + self.args.ddim_steps, + self.args.ddim_eta, + self.args.unconditional_guidance_scale, **kwargs) + return batch_samples diff --git a/src/videogen_hub/pipelines/videocrafter/inference_t2v_512_v2.0.yaml b/src/videogen_hub/pipelines/videocrafter/inference_t2v_512_v2.0.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4a2e6c4e88ad32e439bb6b95b1a200d0ef104603 --- /dev/null +++ b/src/videogen_hub/pipelines/videocrafter/inference_t2v_512_v2.0.yaml @@ -0,0 +1,77 @@ +model: + target: lvdm.models.ddpm3d.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.012 + num_timesteps_cond: 1 + timesteps: 1000 + first_stage_key: video + cond_stage_key: caption + cond_stage_trainable: false + conditioning_key: crossattn + image_size: + - 40 + - 64 + channels: 4 + scale_by_std: false + scale_factor: 0.18215 + use_ema: false + uncond_type: empty_seq + use_scale: true + scale_b: 0.7 + unet_config: + target: lvdm.modules.networks.openaimodel3d.UNetModel + params: + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: + - 4 + - 2 + - 1 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 4 + - 4 + num_head_channels: 64 + transformer_depth: 1 + context_dim: 1024 + use_linear: true + use_checkpoint: true + temporal_conv: true + temporal_attention: true + temporal_selfatt_only: true + use_relative_position: false + use_causal_attention: false + temporal_length: 16 + addition_attention: true + fps_cond: true + first_stage_config: + target: lvdm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 512 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: + target: lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder + params: + freeze: true + layer: penultimate diff --git a/src/videogen_hub/pipelines/videocrafter/lvdm/__init__.py b/src/videogen_hub/pipelines/videocrafter/lvdm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/videocrafter/lvdm/basics.py b/src/videogen_hub/pipelines/videocrafter/lvdm/basics.py new file mode 100644 index 0000000000000000000000000000000000000000..41805983551873d1112de353bcbc9fd50f0ca861 --- /dev/null +++ b/src/videogen_hub/pipelines/videocrafter/lvdm/basics.py @@ -0,0 +1,100 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + +import torch.nn as nn +from utils import instantiate_from_config + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def nonlinearity(type='silu'): + if type == 'silu': + return nn.SiLU() + elif type == 'leaky_relu': + return nn.LeakyReLU() + + +class GroupNormSpecific(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def normalization(channels, num_groups=32): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNormSpecific(num_groups, channels) + + +class HybridConditioner(nn.Module): + + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} \ No newline at end of file diff --git a/src/videogen_hub/pipelines/videocrafter/lvdm/common.py b/src/videogen_hub/pipelines/videocrafter/lvdm/common.py new file mode 100644 index 0000000000000000000000000000000000000000..35569b25aa97236d7d083d8b6ef0c0f3187c2388 --- /dev/null +++ b/src/videogen_hub/pipelines/videocrafter/lvdm/common.py @@ -0,0 +1,95 @@ +import math +from inspect import isfunction +import torch +from torch import nn +import torch.distributed as dist + + +def gather_data(data, return_np=True): + ''' gather data from multiple processes to one list ''' + data_list = [torch.zeros_like(data) for _ in range(dist.get_world_size())] + dist.all_gather(data_list, data) # gather not supported with NCCL + if return_np: + data_list = [data.cpu().numpy() for data in data_list] + return data_list + +def autocast(f): + def do_autocast(*args, **kwargs): + with torch.cuda.amp.autocast(enabled=True, + dtype=torch.get_autocast_gpu_dtype(), + cache_enabled=torch.is_autocast_cache_enabled()): + return f(*args, **kwargs) + return do_autocast + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + +def exists(val): + return val is not None + +def identity(*args, **kwargs): + return nn.Identity() + +def uniq(arr): + return{el: True for el in arr}.keys() + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + +def isimage(x): + if not isinstance(x,torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + +def shape_to_str(x): + shape_str = "x".join([str(x) for x in x.shape]) + return shape_str + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + +ckpt = torch.utils.checkpoint.checkpoint +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + return ckpt(func, *inputs) + else: + return func(*inputs) + diff --git a/src/videogen_hub/pipelines/videocrafter/lvdm/distributions.py b/src/videogen_hub/pipelines/videocrafter/lvdm/distributions.py new file mode 100644 index 0000000000000000000000000000000000000000..0b69b6984880ec24279b658384ed8031335e3474 --- /dev/null +++ b/src/videogen_hub/pipelines/videocrafter/lvdm/distributions.py @@ -0,0 +1,95 @@ +import torch +import numpy as np + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self, noise=None): + if noise is None: + noise = torch.randn(self.mean.shape) + + x = self.mean + self.std * noise.to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1,2,3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/src/videogen_hub/pipelines/videocrafter/lvdm/ema.py b/src/videogen_hub/pipelines/videocrafter/lvdm/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..c8c75af43565f6e140287644aaaefa97dd6e67c5 --- /dev/null +++ b/src/videogen_hub/pipelines/videocrafter/lvdm/ema.py @@ -0,0 +1,76 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError('Decay must be between 0 and 1') + + self.m_name2s_name = {} + self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) + self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates + else torch.tensor(-1,dtype=torch.int)) + + for name, p in model.named_parameters(): + if p.requires_grad: + #remove as '.'-character is not allowed in buffers + s_name = name.replace('.','') + self.m_name2s_name.update({name:s_name}) + self.register_buffer(s_name,p.clone().detach().data) + + self.collected_params = [] + + def forward(self,model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/src/videogen_hub/pipelines/videocrafter/lvdm/models/__init__.py b/src/videogen_hub/pipelines/videocrafter/lvdm/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/videocrafter/lvdm/models/autoencoder.py b/src/videogen_hub/pipelines/videocrafter/lvdm/models/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..8a40ebd73e1b6016bc7b3bfe35ca97a005155b76 --- /dev/null +++ b/src/videogen_hub/pipelines/videocrafter/lvdm/models/autoencoder.py @@ -0,0 +1,219 @@ +import os +from contextlib import contextmanager +import torch +import numpy as np +from einops import rearrange +import torch.nn.functional as F +import pytorch_lightning as pl +from lvdm.modules.networks.ae_modules import Encoder, Decoder +from lvdm.distributions import DiagonalGaussianDistribution +from utils import instantiate_from_config + + +class AutoencoderKL(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + test=False, + logdir=None, + input_dim=4, + test_args=None, + ): + super().__init__() + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + self.input_dim = input_dim + self.test = test + self.test_args = test_args + self.logdir = logdir + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + if self.test: + self.init_test() + + def init_test(self,): + self.test = True + save_dir = os.path.join(self.logdir, "test") + if 'ckpt' in self.test_args: + ckpt_name = os.path.basename(self.test_args.ckpt).split('.ckpt')[0] + f'_epoch{self._cur_epoch}' + self.root = os.path.join(save_dir, ckpt_name) + else: + self.root = save_dir + if 'test_subdir' in self.test_args: + self.root = os.path.join(save_dir, self.test_args.test_subdir) + + self.root_zs = os.path.join(self.root, "zs") + self.root_dec = os.path.join(self.root, "reconstructions") + self.root_inputs = os.path.join(self.root, "inputs") + os.makedirs(self.root, exist_ok=True) + + if self.test_args.save_z: + os.makedirs(self.root_zs, exist_ok=True) + if self.test_args.save_reconstruction: + os.makedirs(self.root_dec, exist_ok=True) + if self.test_args.save_input: + os.makedirs(self.root_inputs, exist_ok=True) + assert(self.test_args is not None) + self.test_maximum = getattr(self.test_args, 'test_maximum', None) + self.count = 0 + self.eval_metrics = {} + self.decodes = [] + self.save_decode_samples = 2048 + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu") + try: + self._cur_epoch = sd['epoch'] + sd = sd["state_dict"] + except: + self._cur_epoch = 'null' + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + # self.load_state_dict(sd, strict=True) + print(f"Restored from {path}") + + def encode(self, x, **kwargs): + + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z, **kwargs): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if x.dim() == 5 and self.input_dim == 4: + b,c,t,h,w = x.shape + self.b = b + self.t = t + x = rearrange(x, 'b c t h w -> (b t) c h w') + + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + + if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + if optimizer_idx == 1: + # train the discriminator + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return discloss + + def validation_step(self, batch, batch_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, + last_layer=self.get_last_layer(), split="val") + + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, + last_layer=self.get_last_layer(), split="val") + + self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x diff --git a/src/videogen_hub/pipelines/videocrafter/lvdm/models/ddpm3d.py b/src/videogen_hub/pipelines/videocrafter/lvdm/models/ddpm3d.py new file mode 100644 index 0000000000000000000000000000000000000000..7dd37ced582cf3bf142095089eb931c8be8c9e59 --- /dev/null +++ b/src/videogen_hub/pipelines/videocrafter/lvdm/models/ddpm3d.py @@ -0,0 +1,763 @@ +""" +wild mixture of +https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +https://github.com/CompVis/taming-transformers +-- merci +""" + +from functools import partial +from contextlib import contextmanager +import numpy as np +from tqdm import tqdm +from einops import rearrange, repeat +import logging +mainlogger = logging.getLogger('mainlogger') +import torch +import torch.nn as nn +from torchvision.utils import make_grid +import pytorch_lightning as pl +from utils import instantiate_from_config +from lvdm.ema import LitEma +from lvdm.distributions import DiagonalGaussianDistribution +from lvdm.models.utils_diffusion import make_beta_schedule +from lvdm.modules.encoders.ip_resampler import ImageProjModel, Resampler +from lvdm.basics import disabled_train +from lvdm.common import ( + extract_into_tensor, + noise_like, + exists, + default +) + + +__conditioning_keys__ = {'concat': 'c_concat', + 'crossattn': 'c_crossattn', + 'adm': 'y'} + +class DDPM(pl.LightningModule): + # classic DDPM with Gaussian diffusion, in image space + def __init__(self, + unet_config, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor=None, + use_ema=True, + first_stage_key="image", + image_size=256, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0., + v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1., + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0. + ): + super().__init__() + assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"' + self.parameterization = parameterization + mainlogger.info(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + self.channels = channels + self.temporal_length = unet_config.params.temporal_length + self.image_size = image_size + if isinstance(self.image_size, int): + self.image_size = [self.image_size, self.image_size] + self.use_positional_encodings = use_positional_encodings + self.model = DiffusionWrapper(unet_config, conditioning_key) + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model) + mainlogger.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.use_scheduler = scheduler_config is not None + if self.use_scheduler: + self.scheduler_config = scheduler_config + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) + + self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, + linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) + + self.loss_type = loss_type + + self.learn_logvar = learn_logvar + self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + + + def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( + 1. - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer('posterior_mean_coef1', to_torch( + betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', to_torch( + (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + + if self.parameterization == "eps": + lvlb_weights = self.betas ** 2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + elif self.parameterization == "x0": + lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod)) + else: + raise NotImplementedError("mu not supported") + # TODO how to choose this term + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + mainlogger.info(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + mainlogger.info(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + mainlogger.info("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + mainlogger.info(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + mainlogger.info(f"Missing Keys: {missing}") + if len(unexpected) > 0: + mainlogger.info(f"Unexpected Keys: {unexpected}") + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start) + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + model_out = self.model(x, t) + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + if clip_denoised: + x_recon.clamp_(-1., 1.) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, shape, return_intermediates=False): + device = self.betas.device + b = shape[0] + img = torch.randn(shape, device=device) + intermediates = [img] + for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps): + img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), + clip_denoised=self.clip_denoised) + if i % self.log_every_t == 0 or i == self.num_timesteps - 1: + intermediates.append(img) + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, batch_size=16, return_intermediates=False): + image_size = self.image_size + channels = self.channels + return self.p_sample_loop((batch_size, channels, image_size, image_size), + return_intermediates=return_intermediates) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start * + extract_into_tensor(self.scale_arr, t, x_start.shape) + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def get_input(self, batch, k): + x = batch[k] + x = x.to(memory_format=torch.contiguous_format).float() + return x + + def _get_rows_from_list(self, samples): + n_imgs_per_row = len(samples) + denoise_grid = rearrange(samples, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): + log = dict() + x = self.get_input(batch, self.first_stage_key) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + x = x.to(self.device)[:N] + log["inputs"] = x + + # get diffusion row + diffusion_row = list() + x_start = x[:n_row] + + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(x_start) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + diffusion_row.append(x_noisy) + + log["diffusion_row"] = self._get_rows_from_list(diffusion_row) + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, denoise_row = self.sample(batch_size=N, return_intermediates=True) + + log["samples"] = samples + log["denoise_row"] = self._get_rows_from_list(denoise_row) + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + +class LatentDiffusion(DDPM): + """main class""" + def __init__(self, + first_stage_config, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_key="caption", + cond_stage_trainable=False, + cond_stage_forward=None, + conditioning_key=None, + uncond_prob=0.2, + uncond_type="empty_seq", + scale_factor=1.0, + scale_by_std=False, + encoder_type="2d", + only_model=False, + use_scale=False, + scale_a=1, + scale_b=0.3, + mid_step=400, + fix_scale_bug=False, + *args, **kwargs): + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs['timesteps'] + # for backwards compatibility after implementation of DiffusionWrapper + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", []) + conditioning_key = default(conditioning_key, 'crossattn') + super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + + # scale factor + self.use_scale=use_scale + if self.use_scale: + self.scale_a=scale_a + self.scale_b=scale_b + if fix_scale_bug: + scale_step=self.num_timesteps-mid_step + else: #bug + scale_step = self.num_timesteps + + scale_arr1 = np.linspace(scale_a, scale_b, mid_step) + scale_arr2 = np.full(scale_step, scale_b) + scale_arr = np.concatenate((scale_arr1, scale_arr2)) + scale_arr_prev = np.append(scale_a, scale_arr[:-1]) + to_torch = partial(torch.tensor, dtype=torch.float32) + self.register_buffer('scale_arr', to_torch(scale_arr)) + + try: + self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 + except: + self.num_downs = 0 + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer('scale_factor', torch.tensor(scale_factor)) + self.instantiate_first_stage(first_stage_config) + self.instantiate_cond_stage(cond_stage_config) + self.first_stage_config = first_stage_config + self.cond_stage_config = cond_stage_config + self.clip_denoised = False + + self.cond_stage_forward = cond_stage_forward + self.encoder_type = encoder_type + assert(encoder_type in ["2d", "3d"]) + self.uncond_prob = uncond_prob + self.classifier_free_guidance = True if uncond_prob > 0 else False + assert(uncond_type in ["zero_embed", "empty_seq"]) + self.uncond_type = uncond_type + + + self.restarted_from_ckpt = False + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys, only_model=only_model) + self.restarted_from_ckpt = True + + + def make_cond_schedule(self, ): + self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) + ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() + self.cond_ids[:self.num_timesteps_cond] = ids + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + if self.use_scale: + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start * + extract_into_tensor(self.scale_arr, t, x_start.shape) + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + else: + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + + def _freeze_model(self): + for name, para in self.model.diffusion_model.named_parameters(): + para.requires_grad = False + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def instantiate_cond_stage(self, config): + if not self.cond_stage_trainable: + model = instantiate_from_config(config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + else: + model = instantiate_from_config(config) + self.cond_stage_model = model + + def get_learned_conditioning(self, c): + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + c = self.cond_stage_model(c) + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + return c + + def get_first_stage_encoding(self, encoder_posterior, noise=None): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample(noise=noise) + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") + return self.scale_factor * z + + @torch.no_grad() + def encode_first_stage(self, x): + if self.encoder_type == "2d" and x.dim() == 5: + b, _, t, _, _ = x.shape + x = rearrange(x, 'b c t h w -> (b t) c h w') + reshape_back = True + else: + reshape_back = False + + encoder_posterior = self.first_stage_model.encode(x) + results = self.get_first_stage_encoding(encoder_posterior).detach() + + if reshape_back: + results = rearrange(results, '(b t) c h w -> b c t h w', b=b,t=t) + + return results + + @torch.no_grad() + def encode_first_stage_2DAE(self, x): + + b, _, t, _, _ = x.shape + results = torch.cat([self.get_first_stage_encoding(self.first_stage_model.encode(x[:,:,i])).detach().unsqueeze(2) for i in range(t)], dim=2) + + return results + + def decode_core(self, z, **kwargs): + if self.encoder_type == "2d" and z.dim() == 5: + b, _, t, _, _ = z.shape + z = rearrange(z, 'b c t h w -> (b t) c h w') + reshape_back = True + else: + reshape_back = False + + z = 1. / self.scale_factor * z + + results = self.first_stage_model.decode(z, **kwargs) + + if reshape_back: + results = rearrange(results, '(b t) c h w -> b c t h w', b=b,t=t) + return results + + @torch.no_grad() + def decode_first_stage(self, z, **kwargs): + return self.decode_core(z, **kwargs) + + def apply_model(self, x_noisy, t, cond, **kwargs): + if isinstance(cond, dict): + # hybrid case, cond is exptected to be a dict + pass + else: + if not isinstance(cond, list): + cond = [cond] + key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' + cond = {key: cond} + + x_recon = self.model(x_noisy, t, **cond, **kwargs) + + if isinstance(x_recon, tuple): + return x_recon[0] + else: + return x_recon + + def _get_denoise_row_from_list(self, samples, desc=''): + denoise_row = [] + for zd in tqdm(samples, desc=desc): + denoise_row.append(self.decode_first_stage(zd.to(self.device))) + n_log_timesteps = len(denoise_row) + + denoise_row = torch.stack(denoise_row) # n_log_timesteps, b, C, H, W + + if denoise_row.dim() == 5: + # img, num_imgs= n_log_timesteps * bs, grid_size=[bs,n_log_timesteps] + denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_log_timesteps) + elif denoise_row.dim() == 6: + # video, grid_size=[n_log_timesteps*bs, t] + video_length = denoise_row.shape[3] + denoise_grid = rearrange(denoise_row, 'n b c t h w -> b n c t h w') + denoise_grid = rearrange(denoise_grid, 'b n c t h w -> (b n) c t h w') + denoise_grid = rearrange(denoise_grid, 'n c t h w -> (n t) c h w') + denoise_grid = make_grid(denoise_grid, nrow=video_length) + else: + raise ValueError + + return denoise_grid + + + @torch.no_grad() + def decode_first_stage_2DAE(self, z, **kwargs): + + b, _, t, _, _ = z.shape + z = 1. / self.scale_factor * z + results = torch.cat([self.first_stage_model.decode(z[:,:,i], **kwargs).unsqueeze(2) for i in range(t)], dim=2) + + return results + + + def p_mean_variance(self, x, c, t, clip_denoised: bool, return_x0=False, score_corrector=None, corrector_kwargs=None, **kwargs): + t_in = t + model_out = self.apply_model(x, t_in, c, **kwargs) + + if score_corrector is not None: + assert self.parameterization == "eps" + model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs) + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + else: + raise NotImplementedError() + + if clip_denoised: + x_recon.clamp_(-1., 1.) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + + if return_x0: + return model_mean, posterior_variance, posterior_log_variance, x_recon + else: + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False, return_x0=False, \ + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, **kwargs): + b, *_, device = *x.shape, x.device + outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised, return_x0=return_x0, \ + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs, **kwargs) + if return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + + if return_x0: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0 + else: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, cond, shape, return_intermediates=False, x_T=None, verbose=True, callback=None, \ + timesteps=None, mask=None, x0=None, img_callback=None, start_T=None, log_every_t=None, **kwargs): + + if not log_every_t: + log_every_t = self.log_every_t + device = self.betas.device + b = shape[0] + # sample an initial noise + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + if start_T is not None: + timesteps = min(timesteps, start_T) + + iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(range(0, timesteps)) + + if mask is not None: + assert x0 is not None + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + + for i in iterator: + ts = torch.full((b,), i, device=device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img = self.p_sample(img, cond, ts, clip_denoised=self.clip_denoised, **kwargs) + if mask is not None: + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + if callback: callback(i) + if img_callback: img_callback(img, i) + + if return_intermediates: + return img, intermediates + return img + + +class LatentVisualDiffusion(LatentDiffusion): + def __init__(self, cond_img_config, finegrained=False, random_cond=False, *args, **kwargs): + super().__init__(*args, **kwargs) + self.random_cond = random_cond + self.instantiate_img_embedder(cond_img_config, freeze=True) + num_tokens = 16 if finegrained else 4 + self.image_proj_model = self.init_projector(use_finegrained=finegrained, num_tokens=num_tokens, input_dim=1024,\ + cross_attention_dim=1024, dim=1280) + + def instantiate_img_embedder(self, config, freeze=True): + embedder = instantiate_from_config(config) + if freeze: + self.embedder = embedder.eval() + self.embedder.train = disabled_train + for param in self.embedder.parameters(): + param.requires_grad = False + + def init_projector(self, use_finegrained, num_tokens, input_dim, cross_attention_dim, dim): + if not use_finegrained: + image_proj_model = ImageProjModel(clip_extra_context_tokens=num_tokens, cross_attention_dim=cross_attention_dim, + clip_embeddings_dim=input_dim + ) + else: + image_proj_model = Resampler(dim=input_dim, depth=4, dim_head=64, heads=12, num_queries=num_tokens, + embedding_dim=dim, output_dim=cross_attention_dim, ff_mult=4 + ) + return image_proj_model + + ## Never delete this func: it is used in log_images() and inference stage + def get_image_embeds(self, batch_imgs): + ## img: b c h w + img_token = self.embedder(batch_imgs) + img_emb = self.image_proj_model(img_token) + return img_emb + + +class DiffusionWrapper(pl.LightningModule): + def __init__(self, diff_model_config, conditioning_key): + super().__init__() + self.diffusion_model = instantiate_from_config(diff_model_config) + self.conditioning_key = conditioning_key + + def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, + c_adm=None, s=None, mask=None, **kwargs): + # temporal_context = fps is foNone + if self.conditioning_key is None: + out = self.diffusion_model(x, t) + elif self.conditioning_key == 'concat': + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t, **kwargs) + elif self.conditioning_key == 'crossattn': + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(x, t, context=cc, **kwargs) + elif self.conditioning_key == 'hybrid': + ## it is just right [b,c,t,h,w]: concatenate in channel dim + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc) + elif self.conditioning_key == 'resblockcond': + cc = c_crossattn[0] + out = self.diffusion_model(x, t, context=cc) + elif self.conditioning_key == 'adm': + cc = c_crossattn[0] + out = self.diffusion_model(x, t, y=cc) + elif self.conditioning_key == 'hybrid-adm': + assert c_adm is not None + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc, y=c_adm) + elif self.conditioning_key == 'hybrid-time': + assert s is not None + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc, s=s) + elif self.conditioning_key == 'concat-time-mask': + # assert s is not None + # mainlogger.info('x & mask:',x.shape,c_concat[0].shape) + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t, context=None, s=s, mask=mask) + elif self.conditioning_key == 'concat-adm-mask': + # assert s is not None + # mainlogger.info('x & mask:',x.shape,c_concat[0].shape) + if c_concat is not None: + xc = torch.cat([x] + c_concat, dim=1) + else: + xc = x + out = self.diffusion_model(xc, t, context=None, y=s, mask=mask) + elif self.conditioning_key == 'hybrid-adm-mask': + cc = torch.cat(c_crossattn, 1) + if c_concat is not None: + xc = torch.cat([x] + c_concat, dim=1) + else: + xc = x + out = self.diffusion_model(xc, t, context=cc, y=s, mask=mask) + elif self.conditioning_key == 'hybrid-time-adm': # adm means y, e.g., class index + # assert s is not None + assert c_adm is not None + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc, s=s, y=c_adm) + else: + raise NotImplementedError() + + return out \ No newline at end of file diff --git a/src/videogen_hub/pipelines/videocrafter/lvdm/models/samplers/__init__.py b/src/videogen_hub/pipelines/videocrafter/lvdm/models/samplers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/videocrafter/lvdm/models/samplers/ddim.py b/src/videogen_hub/pipelines/videocrafter/lvdm/models/samplers/ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..9d153a3016f6f0d69461a0f40b2af1585fdb6b92 --- /dev/null +++ b/src/videogen_hub/pipelines/videocrafter/lvdm/models/samplers/ddim.py @@ -0,0 +1,336 @@ +import numpy as np +from tqdm import tqdm +import torch +from videogen_hub.pipelines.videocrafter.lvdm.models.utils_diffusion import make_ddim_sampling_parameters, make_ddim_timesteps +from videogen_hub.pipelines.videocrafter.lvdm.common import noise_like + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + self.counter = 0 + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + self.use_scale = self.model.use_scale + print('DDIM scale', self.use_scale) + + if self.use_scale: + self.register_buffer('scale_arr', to_torch(self.model.scale_arr)) + ddim_scale_arr = self.scale_arr.cpu()[self.ddim_timesteps] + self.register_buffer('ddim_scale_arr', ddim_scale_arr) + ddim_scale_arr = np.asarray([self.scale_arr.cpu()[0]] + self.scale_arr.cpu()[self.ddim_timesteps[:-1]].tolist()) + self.register_buffer('ddim_scale_arr_prev', ddim_scale_arr) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + schedule_verbose=False, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + + # check condition bs + if conditioning is not None: + if isinstance(conditioning, dict): + try: + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + except: + cbs = conditioning[list(conditioning.keys())[0]][0].shape[0] + + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=schedule_verbose) + + # make shape + if len(shape) == 3: + C, H, W = shape + size = (batch_size, C, H, W) + elif len(shape) == 4: + C, T, H, W = shape + size = (batch_size, C, T, H, W) + # print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + verbose=verbose, + **kwargs) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, verbose=True, + cond_tau=1., target_size=None, start_timesteps=None, + **kwargs): + device = self.model.betas.device + print('ddim device', device) + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + if verbose: + iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + else: + iterator = time_range + + init_x0 = False + clean_cond = kwargs.pop("clean_cond", False) + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + if start_timesteps is not None: + assert x0 is not None + if step > start_timesteps*time_range[0]: + continue + elif not init_x0: + img = self.model.q_sample(x0, ts) + init_x0 = True + + # use mask to blend noised original latent (img_orig) & new sampled latent (img) + if mask is not None: + assert x0 is not None + if clean_cond: + img_orig = x0 + else: + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img # keep original & modify use img + + index_clip = int((1 - cond_tau) * total_steps) + if index <= index_clip and target_size is not None: + target_size_ = [target_size[0], target_size[1]//8, target_size[2]//8] + img = torch.nn.functional.interpolate( + img, + size=target_size_, + mode="nearest", + ) + outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + x0=x0, + **kwargs) + + img, pred_x0 = outs + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, + uc_type=None, conditional_guidance_scale_temporal=None, **kwargs): + b, *_, device = *x.shape, x.device + if x.dim() == 5: + is_video = True + else: + is_video = False + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c, **kwargs) # unet denoiser + else: + # with unconditional condition + if isinstance(c, torch.Tensor): + e_t = self.model.apply_model(x, t, c, **kwargs) + e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning, **kwargs) + elif isinstance(c, dict): + e_t = self.model.apply_model(x, t, c, **kwargs) + e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning, **kwargs) + else: + raise NotImplementedError + # text cfg + if uc_type is None: + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + else: + if uc_type == 'cfg_original': + e_t = e_t + unconditional_guidance_scale * (e_t - e_t_uncond) + elif uc_type == 'cfg_ours': + e_t = e_t + unconditional_guidance_scale * (e_t_uncond - e_t) + else: + raise NotImplementedError + # temporal guidance + if conditional_guidance_scale_temporal is not None: + e_t_temporal = self.model.apply_model(x, t, c, **kwargs) + e_t_image = self.model.apply_model(x, t, c, no_temporal_attn=True, **kwargs) + e_t = e_t + conditional_guidance_scale_temporal * (e_t_temporal - e_t_image) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + + if is_video: + size = (b, 1, 1, 1, 1) + else: + size = (b, 1, 1, 1) + a_t = torch.full(size, alphas[index], device=device) + a_prev = torch.full(size, alphas_prev[index], device=device) + sigma_t = torch.full(size, sigmas[index], device=device) + sqrt_one_minus_at = torch.full(size, sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + if self.use_scale: + scale_arr = self.model.scale_arr if use_original_steps else self.ddim_scale_arr + scale_t = torch.full(size, scale_arr[index], device=device) + scale_arr_prev = self.model.scale_arr_prev if use_original_steps else self.ddim_scale_arr_prev + scale_t_prev = torch.full(size, scale_arr_prev[index], device=device) + pred_x0 /= scale_t + x_prev = a_prev.sqrt() * scale_t_prev * pred_x0 + dir_xt + noise + else: + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + + return x_prev, pred_x0 + + + @torch.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + # fast, but does not allow for exact reconstruction + # t serves as an index to gather the correct alphas + if use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod + else: + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + + if noise is None: + noise = torch.randn_like(x0) + + def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) + + @torch.no_grad() + def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, + use_original_steps=False): + + timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='Decoding image', total=total_steps) + x_dec = x_latent + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) + x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + return x_dec + diff --git a/src/videogen_hub/pipelines/videocrafter/lvdm/models/utils_diffusion.py b/src/videogen_hub/pipelines/videocrafter/lvdm/models/utils_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..603fa817b07cea3581a70ff225d479b7d1518463 --- /dev/null +++ b/src/videogen_hub/pipelines/videocrafter/lvdm/models/utils_diffusion.py @@ -0,0 +1,104 @@ +import math +import numpy as np +from einops import repeat +import torch +import torch.nn.functional as F + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + # print(f'ddim_timesteps={ddim_timesteps}, len_alphacums={len(alphacums)}') + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) \ No newline at end of file diff --git a/src/videogen_hub/pipelines/videocrafter/lvdm/modules/__init__.py b/src/videogen_hub/pipelines/videocrafter/lvdm/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/videocrafter/lvdm/modules/attention.py b/src/videogen_hub/pipelines/videocrafter/lvdm/modules/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..bceba7d2e56003fcc40ab7f9b14d8e0c33cc1638 --- /dev/null +++ b/src/videogen_hub/pipelines/videocrafter/lvdm/modules/attention.py @@ -0,0 +1,475 @@ +from functools import partial +import torch +from torch import nn, einsum +import torch.nn.functional as F +from einops import rearrange, repeat +try: + import xformers + import xformers.ops + XFORMERS_IS_AVAILBLE = True +except: + XFORMERS_IS_AVAILBLE = False +from lvdm.common import ( + checkpoint, + exists, + default, +) +from lvdm.basics import ( + zero_module, +) + +class RelativePosition(nn.Module): + """ https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py """ + + def __init__(self, num_units, max_relative_position): + super().__init__() + self.num_units = num_units + self.max_relative_position = max_relative_position + self.embeddings_table = nn.Parameter(torch.Tensor(max_relative_position * 2 + 1, num_units)) + nn.init.xavier_uniform_(self.embeddings_table) + + def forward(self, length_q, length_k): + device = self.embeddings_table.device + range_vec_q = torch.arange(length_q, device=device) + range_vec_k = torch.arange(length_k, device=device) + distance_mat = range_vec_k[None, :] - range_vec_q[:, None] + distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position) + final_mat = distance_mat_clipped + self.max_relative_position + final_mat = final_mat.long() + embeddings = self.embeddings_table[final_mat] + return embeddings + + +class CrossAttention(nn.Module): + + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., + relative_position=False, temporal_length=None, img_cross_attention=False): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head**-0.5 + self.heads = heads + self.dim_head = dim_head + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + + self.image_cross_attention_scale = 1.0 + self.text_context_len = 77 + self.img_cross_attention = img_cross_attention + if self.img_cross_attention: + self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False) + + self.relative_position = relative_position + if self.relative_position: + assert(temporal_length is not None) + self.relative_position_k = RelativePosition(num_units=dim_head, max_relative_position=temporal_length) + self.relative_position_v = RelativePosition(num_units=dim_head, max_relative_position=temporal_length) + else: + ## only used for spatial attention, while NOT for temporal attention + if XFORMERS_IS_AVAILBLE and temporal_length is None: + self.forward = self.efficient_forward + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + ## considering image token additionally + if context is not None and self.img_cross_attention: + context, context_img = context[:,:self.text_context_len,:], context[:,self.text_context_len:,:] + k = self.to_k(context) + v = self.to_v(context) + k_ip = self.to_k_ip(context_img) + v_ip = self.to_v_ip(context_img) + else: + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale + if self.relative_position: + len_q, len_k, len_v = q.shape[1], k.shape[1], v.shape[1] + k2 = self.relative_position_k(len_q, len_k) + sim2 = einsum('b t d, t s d -> b t s', q, k2) * self.scale # TODO check + sim += sim2 + del k + + if exists(mask): + ## feasible for causal attention mask only + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b i j -> (b h) i j', h=h) + sim.masked_fill_(~(mask>0.5), max_neg_value) + + # attention, what we cannot get enough of + sim = sim.softmax(dim=-1) + out = torch.einsum('b i j, b j d -> b i d', sim, v) + if self.relative_position: + v2 = self.relative_position_v(len_q, len_v) + out2 = einsum('b t s, t s d -> b t d', sim, v2) # TODO check + out += out2 + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + + ## considering image token additionally + if context is not None and self.img_cross_attention: + k_ip, v_ip = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (k_ip, v_ip)) + sim_ip = torch.einsum('b i d, b j d -> b i j', q, k_ip) * self.scale + del k_ip + sim_ip = sim_ip.softmax(dim=-1) + out_ip = torch.einsum('b i j, b j d -> b i d', sim_ip, v_ip) + out_ip = rearrange(out_ip, '(b h) n d -> b n (h d)', h=h) + out = out + self.image_cross_attention_scale * out_ip + del q + + return self.to_out(out) + + def efficient_forward(self, x, context=None, mask=None): + q = self.to_q(x) + context = default(context, x) + + ## considering image token additionally + if context is not None and self.img_cross_attention: + context, context_img = context[:,:self.text_context_len,:], context[:,self.text_context_len:,:] + k = self.to_k(context) + v = self.to_v(context) + k_ip = self.to_k_ip(context_img) + v_ip = self.to_v_ip(context_img) + else: + k = self.to_k(context) + v = self.to_v(context) + + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + # actually compute the attention, what we cannot get enough of + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=None) + + ## considering image token additionally + if context is not None and self.img_cross_attention: + k_ip, v_ip = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (k_ip, v_ip), + ) + out_ip = xformers.ops.memory_efficient_attention(q, k_ip, v_ip, attn_bias=None, op=None) + out_ip = ( + out_ip.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + + if exists(mask): + raise NotImplementedError + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + if context is not None and self.img_cross_attention: + out = out + self.image_cross_attention_scale * out_ip + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, + disable_self_attn=False, attention_cls=None, img_cross_attention=False): + super().__init__() + attn_cls = CrossAttention if attention_cls is None else attention_cls + self.disable_self_attn = disable_self_attn + self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None) + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout, + img_cross_attention=img_cross_attention) + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None, mask=None): + ## implementation tricks: because checkpointing doesn't support non-tensor (e.g. None or scalar) arguments + input_tuple = (x,) ## should not be (x), otherwise *input_tuple will decouple x into multiple arguments + if context is not None: + input_tuple = (x, context) + if mask is not None: + forward_mask = partial(self._forward, mask=mask) + return checkpoint(forward_mask, (x,), self.parameters(), self.checkpoint) + if context is not None and mask is not None: + input_tuple = (x, context, mask) + return checkpoint(self._forward, input_tuple, self.parameters(), self.checkpoint) + + def _forward(self, x, context=None, mask=None): + x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None, mask=mask) + x + x = self.attn2(self.norm2(x), context=context, mask=mask) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data in spatial axis. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + NEW: use_linear for more efficiency instead of the 1x1 convs + """ + + def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None, + use_checkpoint=True, disable_self_attn=False, use_linear=False, img_cross_attention=False): + super().__init__() + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + if not use_linear: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList([ + BasicTransformerBlock( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim, + img_cross_attention=img_cross_attention, + disable_self_attn=disable_self_attn, + checkpoint=use_checkpoint) for d in range(depth) + ]) + if not use_linear: + self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)) + else: + self.proj_out = zero_module(nn.Linear(inner_dim, in_channels)) + self.use_linear = use_linear + + + def forward(self, x, context=None): + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c').contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) + return x + x_in + + +class TemporalTransformer(nn.Module): + """ + Transformer block for image-like data in temporal axis. + First, reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None, + use_checkpoint=True, use_linear=False, only_self_att=True, causal_attention=False, + relative_position=False, temporal_length=None): + super().__init__() + self.only_self_att = only_self_att + self.relative_position = relative_position + self.causal_attention = causal_attention + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.proj_in = nn.Conv1d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + if not use_linear: + self.proj_in = nn.Conv1d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + if relative_position: + assert(temporal_length is not None) + attention_cls = partial(CrossAttention, relative_position=True, temporal_length=temporal_length) + else: + attention_cls = None + if self.causal_attention: + assert(temporal_length is not None) + self.mask = torch.tril(torch.ones([1, temporal_length, temporal_length])) + + if self.only_self_att: + context_dim = None + self.transformer_blocks = nn.ModuleList([ + BasicTransformerBlock( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim, + attention_cls=attention_cls, + checkpoint=use_checkpoint) for d in range(depth) + ]) + if not use_linear: + self.proj_out = zero_module(nn.Conv1d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)) + else: + self.proj_out = zero_module(nn.Linear(inner_dim, in_channels)) + self.use_linear = use_linear + + def forward(self, x, context=None): + b, c, t, h, w = x.shape + x_in = x + x = self.norm(x) + x = rearrange(x, 'b c t h w -> (b h w) c t').contiguous() + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, 'bhw c t -> bhw t c').contiguous() + if self.use_linear: + x = self.proj_in(x) + + if self.causal_attention: + mask = self.mask.to(x.device) + mask = repeat(mask, 'l i j -> (l bhw) i j', bhw=b*h*w) + else: + mask = None + + if self.only_self_att: + ## note: if no context is given, cross-attention defaults to self-attention + for i, block in enumerate(self.transformer_blocks): + x = block(x, mask=mask) + x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous() + else: + x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous() + context = rearrange(context, '(b t) l con -> b t l con', t=t).contiguous() + for i, block in enumerate(self.transformer_blocks): + # calculate each batch one by one (since number in shape could not greater then 65,535 for some package) + for j in range(b): + context_j = repeat( + context[j], + 't l con -> (t r) l con', r=(h * w) // t, t=t).contiguous() + ## note: causal mask will not applied in cross-attention case + x[j] = block(x[j], context=context_j) + + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) t c -> b c t h w', h=h, w=w).contiguous() + if not self.use_linear: + x = rearrange(x, 'b hw t c -> (b hw) c t').contiguous() + x = self.proj_out(x) + x = rearrange(x, '(b h w) c t -> b c t h w', b=b, h=h, w=w).contiguous() + + return x + x_in + + +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum('bhdn,bhen->bhde', k, v) + out = torch.einsum('bhde,bhdn->bhen', context, q) + out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = rearrange(q, 'b c h w -> b (h w) c') + k = rearrange(k, 'b c h w -> b c (h w)') + w_ = torch.einsum('bij,bjk->bik', q, k) + + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, 'b c h w -> b c (h w)') + w_ = rearrange(w_, 'b i j -> b j i') + h_ = torch.einsum('bij,bjk->bik', v, w_) + h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + h_ = self.proj_out(h_) + + return x+h_ diff --git a/src/videogen_hub/pipelines/videocrafter/lvdm/modules/encoders/__init__.py b/src/videogen_hub/pipelines/videocrafter/lvdm/modules/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/videocrafter/lvdm/modules/encoders/condition.py b/src/videogen_hub/pipelines/videocrafter/lvdm/modules/encoders/condition.py new file mode 100644 index 0000000000000000000000000000000000000000..243fa555079e45720611954950b9d343b6ff8234 --- /dev/null +++ b/src/videogen_hub/pipelines/videocrafter/lvdm/modules/encoders/condition.py @@ -0,0 +1,392 @@ +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint +import kornia +import open_clip +from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel +from lvdm.common import autocast +from utils import count_params + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class IdentityEncoder(AbstractEncoder): + + def encode(self, x): + return x + + +class ClassEmbedder(nn.Module): + def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1): + super().__init__() + self.key = key + self.embedding = nn.Embedding(n_classes, embed_dim) + self.n_classes = n_classes + self.ucg_rate = ucg_rate + + def forward(self, batch, key=None, disable_dropout=False): + if key is None: + key = self.key + # this is for use in crossattn + c = batch[key][:, None] + if self.ucg_rate > 0. and not disable_dropout: + mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) + c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1) + c = c.long() + c = self.embedding(c) + return c + + def get_unconditional_conditioning(self, bs, device="cuda"): + uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) + uc = torch.ones((bs,), device=device) * uc_class + uc = {self.key: uc} + return uc + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class FrozenT5Embedder(AbstractEncoder): + """Uses the T5 transformer encoder for text""" + + def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, + freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl + super().__init__() + self.tokenizer = T5Tokenizer.from_pretrained(version) + self.transformer = T5EncoderModel.from_pretrained(version) + self.device = device + self.max_length = max_length # TODO: typical value? + if freeze: + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + # self.train = disabled_train + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens) + + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + + +class FrozenCLIPEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from huggingface)""" + LAYERS = [ + "last", + "pooled", + "hidden" + ] + + def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, + freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 + super().__init__() + assert layer in self.LAYERS + self.tokenizer = CLIPTokenizer.from_pretrained(version) + self.transformer = CLIPTextModel.from_pretrained(version) + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + self.layer_idx = layer_idx + if layer == "hidden": + assert layer_idx is not None + assert 0 <= abs(layer_idx) <= 12 + + def freeze(self): + self.transformer = self.transformer.eval() + # self.train = disabled_train + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden") + if self.layer == "last": + z = outputs.last_hidden_state + elif self.layer == "pooled": + z = outputs.pooler_output[:, None, :] + else: + z = outputs.hidden_states[self.layer_idx] + return z + + def encode(self, text): + return self(text) + + +class ClipImageEmbedder(nn.Module): + def __init__( + self, + model, + jit=False, + device='cuda' if torch.cuda.is_available() else 'cpu', + antialias=True, + ucg_rate=0. + ): + super().__init__() + from clip import load as load_clip + self.model, _ = load_clip(name=model, device=device, jit=jit) + + self.antialias = antialias + + self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) + self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) + self.ucg_rate = ucg_rate + + def preprocess(self, x): + # normalize to [0,1] + x = kornia.geometry.resize(x, (224, 224), + interpolation='bicubic', align_corners=True, + antialias=self.antialias) + x = (x + 1.) / 2. + # re-normalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def forward(self, x, no_dropout=False): + # x is assumed to be in range [-1,1] + out = self.model.encode_image(self.preprocess(x)) + out = out.to(x.dtype) + if self.ucg_rate > 0. and not no_dropout: + out = torch.bernoulli((1. - self.ucg_rate) * torch.ones(out.shape[0], device=out.device))[:, None] * out + return out + + +class FrozenOpenCLIPEmbedder(AbstractEncoder): + """ + Uses the OpenCLIP transformer encoder for text + """ + LAYERS = [ + # "pooled", + "last", + "penultimate" + ] + + def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, + freeze=True, layer="last"): + super().__init__() + assert layer in self.LAYERS + model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu')) + del model.visual + self.model = model + + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + if self.layer == "last": + self.layer_idx = 0 + elif self.layer == "penultimate": + self.layer_idx = 1 + else: + raise NotImplementedError() + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + self.device = self.model.positional_embedding.device + tokens = open_clip.tokenize(text) + z = self.encode_with_transformer(tokens.to(self.device)) + return z + + def encode_with_transformer(self, text): + x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.model.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.model.ln_final(x) + return x + + def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): + for i, r in enumerate(self.model.transformer.resblocks): + if i == len(self.model.transformer.resblocks) - self.layer_idx: + break + if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + def encode(self, text): + return self(text) + + +class FrozenOpenCLIPImageEmbedder(AbstractEncoder): + """ + Uses the OpenCLIP vision transformer encoder for images + """ + + def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, + freeze=True, layer="pooled", antialias=True, ucg_rate=0.): + super().__init__() + model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), + pretrained=version, ) + del model.transformer + self.model = model + + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + if self.layer == "penultimate": + raise NotImplementedError() + self.layer_idx = 1 + + self.antialias = antialias + + self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) + self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) + self.ucg_rate = ucg_rate + + def preprocess(self, x): + # normalize to [0,1] + x = kornia.geometry.resize(x, (224, 224), + interpolation='bicubic', align_corners=True, + antialias=self.antialias) + x = (x + 1.) / 2. + # renormalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + @autocast + def forward(self, image, no_dropout=False): + z = self.encode_with_vision_transformer(image) + if self.ucg_rate > 0. and not no_dropout: + z = torch.bernoulli((1. - self.ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None] * z + return z + + def encode_with_vision_transformer(self, img): + img = self.preprocess(img) + x = self.model.visual(img) + return x + + def encode(self, text): + return self(text) + + + +class FrozenOpenCLIPImageEmbedderV2(AbstractEncoder): + """ + Uses the OpenCLIP vision transformer encoder for images + """ + + def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", + freeze=True, layer="pooled", antialias=True): + super().__init__() + model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), + pretrained=version, ) + del model.transformer + self.model = model + self.device = device + + if freeze: + self.freeze() + self.layer = layer + if self.layer == "penultimate": + raise NotImplementedError() + self.layer_idx = 1 + + self.antialias = antialias + self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) + self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) + + + def preprocess(self, x): + # normalize to [0,1] + x = kornia.geometry.resize(x, (224, 224), + interpolation='bicubic', align_corners=True, + antialias=self.antialias) + x = (x + 1.) / 2. + # renormalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def freeze(self): + self.model = self.model.eval() + for param in self.model.parameters(): + param.requires_grad = False + + def forward(self, image, no_dropout=False): + ## image: b c h w + z = self.encode_with_vision_transformer(image) + return z + + def encode_with_vision_transformer(self, x): + x = self.preprocess(x) + + # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1 + if self.model.visual.input_patchnorm: + # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)') + x = x.reshape(x.shape[0], x.shape[1], self.model.visual.grid_size[0], self.model.visual.patch_size[0], self.model.visual.grid_size[1], self.model.visual.patch_size[1]) + x = x.permute(0, 2, 4, 1, 3, 5) + x = x.reshape(x.shape[0], self.model.visual.grid_size[0] * self.model.visual.grid_size[1], -1) + x = self.model.visual.patchnorm_pre_ln(x) + x = self.model.visual.conv1(x) + else: + x = self.model.visual.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + # class embeddings and positional embeddings + x = torch.cat( + [self.model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), + x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.model.visual.positional_embedding.to(x.dtype) + + # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in + x = self.model.visual.patch_dropout(x) + x = self.model.visual.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.model.visual.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + return x + + +class FrozenCLIPT5Encoder(AbstractEncoder): + def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", + clip_max_length=77, t5_max_length=77): + super().__init__() + self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length) + self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) + print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, " + f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params.") + + def encode(self, text): + return self(text) + + def forward(self, text): + clip_z = self.clip_encoder.encode(text) + t5_z = self.t5_encoder.encode(text) + return [clip_z, t5_z] \ No newline at end of file diff --git a/src/videogen_hub/pipelines/videocrafter/lvdm/modules/encoders/ip_resampler.py b/src/videogen_hub/pipelines/videocrafter/lvdm/modules/encoders/ip_resampler.py new file mode 100644 index 0000000000000000000000000000000000000000..500820a789150a55d6e8fdca4dd3e4d6ad542d4a --- /dev/null +++ b/src/videogen_hub/pipelines/videocrafter/lvdm/modules/encoders/ip_resampler.py @@ -0,0 +1,136 @@ +# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py +import math +import torch +import torch.nn as nn + + +class ImageProjModel(nn.Module): + """Projection Model""" + def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): + super().__init__() + self.cross_attention_dim = cross_attention_dim + self.clip_extra_context_tokens = clip_extra_context_tokens + self.proj = nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) + self.norm = nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds): + #embeds = image_embeds + embeds = image_embeds.type(list(self.proj.parameters())[0].dtype) + clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim) + clip_extra_context_tokens = self.norm(clip_extra_context_tokens) + return clip_extra_context_tokens + +# FFN +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +def reshape_tensor(x, heads): + bs, length, width = x.shape + #(bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, length, heads, -1) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) + x = x.reshape(bs, heads, length, -1) + return x + + +class PerceiverAttention(nn.Module): + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head**-0.5 + self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head * heads + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D) + """ + x = self.norm1(x) + latents = self.norm2(latents) + + b, l, _ = latents.shape + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + q = reshape_tensor(q, self.heads) + k = reshape_tensor(k, self.heads) + v = reshape_tensor(v, self.heads) + + # attention + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + out = weight @ v + + out = out.permute(0, 2, 1, 3).reshape(b, l, -1) + + return self.to_out(out) + + +class Resampler(nn.Module): + def __init__( + self, + dim=1024, + depth=8, + dim_head=64, + heads=16, + num_queries=8, + embedding_dim=768, + output_dim=1024, + ff_mult=4, + ): + super().__init__() + + self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) + + self.proj_in = nn.Linear(embedding_dim, dim) + + self.proj_out = nn.Linear(dim, output_dim) + self.norm_out = nn.LayerNorm(output_dim) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + def forward(self, x): + + latents = self.latents.repeat(x.size(0), 1, 1) + + x = self.proj_in(x) + + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + + latents = self.proj_out(latents) + return self.norm_out(latents) \ No newline at end of file diff --git a/src/videogen_hub/pipelines/videocrafter/lvdm/modules/networks/__init__.py b/src/videogen_hub/pipelines/videocrafter/lvdm/modules/networks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/videogen_hub/pipelines/videocrafter/lvdm/modules/networks/ae_modules.py b/src/videogen_hub/pipelines/videocrafter/lvdm/modules/networks/ae_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..9aa0a9efa21a4df770513bbcd36feee0e4a81d3b --- /dev/null +++ b/src/videogen_hub/pipelines/videocrafter/lvdm/modules/networks/ae_modules.py @@ -0,0 +1,845 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import numpy as np +import torch.nn as nn +from einops import rearrange +from utils import instantiate_from_config +from lvdm.modules.attention import LinearAttention + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) # bcl + q = q.permute(0,2,1) # bcl -> blc l=hw + k = k.reshape(b,c,h*w) # bcl + + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' + #print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + self.in_channels = in_channels + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + self.in_channels = in_channels + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + +class Model(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x, t=None, context=None): + #assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", + **ignore_kwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # timestep embedding + temb = None + + # print(f'encoder-input={x.shape}') + # downsampling + hs = [self.conv_in(x)] + # print(f'encoder-conv in feat={hs[0].shape}') + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + # print(f'encoder-down feat={h.shape}') + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + # print(f'encoder-downsample (input)={hs[-1].shape}') + hs.append(self.down[i_level].downsample(hs[-1])) + # print(f'encoder-downsample (output)={hs[-1].shape}') + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + # print(f'encoder-mid1 feat={h.shape}') + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + # print(f'encoder-mid2 feat={h.shape}') + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + # print(f'end feat={h.shape}') + return h + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + attn_type="vanilla", **ignorekwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + print("AE working on z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # print(f'decoder-input={z.shape}') + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + # print(f'decoder-conv in feat={h.shape}') + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + # print(f'decoder-mid feat={h.shape}') + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + # print(f'decoder-up feat={h.shape}') + if i_level != 0: + h = self.up[i_level].upsample(h) + # print(f'decoder-upsample feat={h.shape}') + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + # print(f'decoder-conv_out feat={h.shape}') + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + nn.Conv2d(2*in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True)]) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1,2,3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, + ch_mult=(2,2), dropout=0.0): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d(in_channels, + mid_channels, + kernel_size=3, + stride=1, + padding=1) + self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + + self.conv_out = nn.Conv2d(mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor)))) + x = self.attn(x) + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, + ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult, + z_channels=intermediate_chn, double_z=False, resolution=resolution, + attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, + out_ch=None) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn, + mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8), + dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + tmp_chn = z_channels*ch_mult[-1] + self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout, + resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks, + ch_mult=ch_mult, resolution=resolution, ch=ch) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn, + out_channels=tmp_chn, depth=rescale_module_depth) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size//in_size))+1 + factor_up = 1.+ (out_size % in_size) + print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}") + self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels, + out_channels=in_channels) + self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2, + attn_resolutions=[], in_channels=None, ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)]) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=4, + stride=2, + padding=1) + + def forward(self, x, scale_factor=1.0): + if scale_factor==1.0: + return x + else: + x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor) + return x + +class FirstStagePostProcessor(nn.Module): + + def __init__(self, ch_mult:list, in_channels, + pretrained_model:nn.Module=None, + reshape=False, + n_channels=None, + dropout=0., + pretrained_config=None): + super().__init__() + if pretrained_config is None: + assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + self.pretrained_model = pretrained_model + else: + assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + self.instantiate_pretrained(pretrained_config) + + self.do_reshape = reshape + + if n_channels is None: + n_channels = self.pretrained_model.encoder.ch + + self.proj_norm = Normalize(in_channels,num_groups=in_channels//2) + self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3, + stride=1,padding=1) + + blocks = [] + downs = [] + ch_in = n_channels + for m in ch_mult: + blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout)) + ch_in = m * n_channels + downs.append(Downsample(ch_in, with_conv=False)) + + self.model = nn.ModuleList(blocks) + self.downsampler = nn.ModuleList(downs) + + + def instantiate_pretrained(self, config): + model = instantiate_from_config(config) + self.pretrained_model = model.eval() + # self.pretrained_model.train = False + for param in self.pretrained_model.parameters(): + param.requires_grad = False + + + @torch.no_grad() + def encode_with_pretrained(self,x): + c = self.pretrained_model.encode(x) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + return c + + def forward(self,x): + z_fs = self.encode_with_pretrained(x) + z = self.proj_norm(z_fs) + z = self.proj(z) + z = nonlinearity(z) + + for submodel, downmodel in zip(self.model,self.downsampler): + z = submodel(z,temb=None) + z = downmodel(z) + + if self.do_reshape: + z = rearrange(z,'b c h w -> b (h w) c') + return z + diff --git a/src/videogen_hub/pipelines/videocrafter/lvdm/modules/networks/openaimodel3d.py b/src/videogen_hub/pipelines/videocrafter/lvdm/modules/networks/openaimodel3d.py new file mode 100644 index 0000000000000000000000000000000000000000..328fac71761a69b461f33946d6a1aa08622ecd8f --- /dev/null +++ b/src/videogen_hub/pipelines/videocrafter/lvdm/modules/networks/openaimodel3d.py @@ -0,0 +1,577 @@ +from functools import partial +from abc import abstractmethod +import torch +import torch.nn as nn +from einops import rearrange +import torch.nn.functional as F +from lvdm.models.utils_diffusion import timestep_embedding +from lvdm.common import checkpoint +from lvdm.basics import ( + zero_module, + conv_nd, + linear, + avg_pool_nd, + normalization +) +from lvdm.modules.attention import SpatialTransformer, TemporalTransformer + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None, batch_size=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb, batch_size) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + elif isinstance(layer, TemporalTransformer): + x = rearrange(x, '(b f) c h w -> b c f h w', b=batch_size) + x = layer(x, context) + x = rearrange(x, 'b c f h w -> (b f) c h w') + else: + x = layer(x,) + return x + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode='nearest') + else: + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.use_conv: + x = self.conv(x) + return x + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + use_conv=False, + up=False, + down=False, + use_temporal_conv=False, + tempspatial_aware=False + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + self.use_temporal_conv = use_temporal_conv + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + nn.Linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module(nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + if self.use_temporal_conv: + self.temopral_conv = TemporalConvBlock( + self.out_channels, + self.out_channels, + dropout=0.1, + spatial_aware=tempspatial_aware + ) + + def forward(self, x, emb, batch_size=None): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + input_tuple = (x, emb,) + if batch_size: + forward_batchsize = partial(self._forward, batch_size=batch_size) + return checkpoint(forward_batchsize, input_tuple, self.parameters(), self.use_checkpoint) + return checkpoint(self._forward, input_tuple, self.parameters(), self.use_checkpoint) + + def _forward(self, x, emb, batch_size=None,): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = torch.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + h = self.skip_connection(x) + h + + if self.use_temporal_conv and batch_size: + h = rearrange(h, '(b t) c h w -> b c t h w', b=batch_size) + h = self.temopral_conv(h) + h = rearrange(h, 'b c t h w -> (b t) c h w') + return h + + +class TemporalConvBlock(nn.Module): + """ + Adapted from modelscope: https://github.com/modelscope/modelscope/blob/master/modelscope/models/multi_modal/video_synthesis/unet_sd.py + """ + + def __init__(self, in_channels, out_channels=None, dropout=0.0, spatial_aware=False): + super(TemporalConvBlock, self).__init__() + if out_channels is None: + out_channels = in_channels + self.in_channels = in_channels + self.out_channels = out_channels + kernel_shape = (3, 1, 1) if not spatial_aware else (3, 3, 3) + padding_shape = (1, 0, 0) if not spatial_aware else (1, 1, 1) + + # conv layers + self.conv1 = nn.Sequential( + nn.GroupNorm(32, in_channels), nn.SiLU(), + nn.Conv3d(in_channels, out_channels, kernel_shape, padding=padding_shape)) + self.conv2 = nn.Sequential( + nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout), + nn.Conv3d(out_channels, in_channels, kernel_shape, padding=padding_shape)) + self.conv3 = nn.Sequential( + nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout), + nn.Conv3d(out_channels, in_channels, (3, 1, 1), padding=(1, 0, 0))) + self.conv4 = nn.Sequential( + nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout), + nn.Conv3d(out_channels, in_channels, (3, 1, 1), padding=(1, 0, 0))) + + # zero out the last layer params,so the conv block is identity + nn.init.zeros_(self.conv4[-1].weight) + nn.init.zeros_(self.conv4[-1].bias) + + def forward(self, x): + identity = x + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + + return x + identity + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: in_channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + """ + + def __init__(self, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0.0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + context_dim=None, + use_scale_shift_norm=False, + resblock_updown=False, + num_heads=-1, + num_head_channels=-1, + transformer_depth=1, + use_linear=False, + use_checkpoint=False, + temporal_conv=False, + tempspatial_aware=False, + temporal_attention=True, + temporal_selfatt_only=True, + use_relative_position=True, + use_causal_attention=False, + temporal_length=None, + use_fp16=False, + addition_attention=False, + use_image_attention=False, + temporal_transformer_depth=1, + fps_cond=False, + ): + super(UNetModel, self).__init__() + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.temporal_attention = temporal_attention + time_embed_dim = model_channels * 4 + self.use_checkpoint = use_checkpoint + self.dtype = torch.float16 if use_fp16 else torch.float32 + self.addition_attention=addition_attention + self.use_image_attention = use_image_attention + self.fps_cond=fps_cond + + + + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + if self.fps_cond: + self.fps_embedding = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1)) + ] + ) + if self.addition_attention: + self.init_attn=TimestepEmbedSequential( + TemporalTransformer( + model_channels, + n_heads=8, + d_head=num_head_channels, + depth=transformer_depth, + context_dim=context_dim, + use_checkpoint=use_checkpoint, only_self_att=temporal_selfatt_only, + causal_attention=use_causal_attention, relative_position=use_relative_position, + temporal_length=temporal_length)) + + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock(ch, time_embed_dim, dropout, + out_channels=mult * model_channels, dims=dims, use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware, + use_temporal_conv=temporal_conv + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + layers.append( + SpatialTransformer(ch, num_heads, dim_head, + depth=transformer_depth, context_dim=context_dim, use_linear=use_linear, + use_checkpoint=use_checkpoint, disable_self_attn=False, + img_cross_attention=self.use_image_attention + ) + ) + if self.temporal_attention: + layers.append( + TemporalTransformer(ch, num_heads, dim_head, + depth=temporal_transformer_depth, context_dim=context_dim, use_linear=use_linear, + use_checkpoint=use_checkpoint, only_self_att=temporal_selfatt_only, + causal_attention=use_causal_attention, relative_position=use_relative_position, + temporal_length=temporal_length + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock(ch, time_embed_dim, dropout, + out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True + ) + if resblock_updown + else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + layers = [ + ResBlock(ch, time_embed_dim, dropout, + dims=dims, use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware, + use_temporal_conv=temporal_conv + ), + SpatialTransformer(ch, num_heads, dim_head, + depth=transformer_depth, context_dim=context_dim, use_linear=use_linear, + use_checkpoint=use_checkpoint, disable_self_attn=False, + img_cross_attention=self.use_image_attention + ) + ] + if self.temporal_attention: + layers.append( + TemporalTransformer(ch, num_heads, dim_head, + depth=temporal_transformer_depth, context_dim=context_dim, use_linear=use_linear, + use_checkpoint=use_checkpoint, only_self_att=temporal_selfatt_only, + causal_attention=use_causal_attention, relative_position=use_relative_position, + temporal_length=temporal_length + ) + ) + layers.append( + ResBlock(ch, time_embed_dim, dropout, + dims=dims, use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware, + use_temporal_conv=temporal_conv + ) + ) + self.middle_block = TimestepEmbedSequential(*layers) + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock(ch + ich, time_embed_dim, dropout, + out_channels=mult * model_channels, dims=dims, use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware, + use_temporal_conv=temporal_conv + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + layers.append( + SpatialTransformer(ch, num_heads, dim_head, + depth=transformer_depth, context_dim=context_dim, use_linear=use_linear, + use_checkpoint=use_checkpoint, disable_self_attn=False, + img_cross_attention=self.use_image_attention + ) + ) + if self.temporal_attention: + layers.append( + TemporalTransformer(ch, num_heads, dim_head, + depth=temporal_transformer_depth, context_dim=context_dim, use_linear=use_linear, + use_checkpoint=use_checkpoint, only_self_att=temporal_selfatt_only, + causal_attention=use_causal_attention, relative_position=use_relative_position, + temporal_length=temporal_length + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock(ch, time_embed_dim, dropout, + out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + + def forward(self, x, timesteps, context=None, features_adapter=None, fps=16, **kwargs): + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + if self.fps_cond: + if type(fps) == int: + fps = torch.full_like(timesteps, fps) + fps_emb = timestep_embedding(fps,self.model_channels, repeat_only=False) + emb += self.fps_embedding(fps_emb) + + b,_,t,_,_ = x.shape + ## repeat t times for context [(b t) 77 768] & time embedding + context = context.repeat_interleave(repeats=t, dim=0) + emb = emb.repeat_interleave(repeats=t, dim=0) + + ## always in shape (b t) c h w, except for temporal layer + x = rearrange(x, 'b c t h w -> (b t) c h w') + + h = x.type(self.dtype) + adapter_idx = 0 + hs = [] + for id, module in enumerate(self.input_blocks): + h = module(h, emb, context=context, batch_size=b) + if id ==0 and self.addition_attention: + h = self.init_attn(h, emb, context=context, batch_size=b) + ## plug-in adapter features + if ((id+1)%3 == 0) and features_adapter is not None: + h = h + features_adapter[adapter_idx] + adapter_idx += 1 + hs.append(h) + if features_adapter is not None: + assert len(features_adapter)==adapter_idx, 'Wrong features_adapter' + + h = self.middle_block(h, emb, context=context, batch_size=b) + for module in self.output_blocks: + h = torch.cat([h, hs.pop()], dim=1) + h = module(h, emb, context=context, batch_size=b) + h = h.type(x.dtype) + y = self.out(h) + + # reshape back to (b c t h w) + y = rearrange(y, '(b t) c h w -> b c t h w', b=b) + return y + \ No newline at end of file diff --git a/src/videogen_hub/pipelines/videocrafter/lvdm/modules/x_transformer.py b/src/videogen_hub/pipelines/videocrafter/lvdm/modules/x_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..f252ab4032a78407ed487495807940c4ba802ffa --- /dev/null +++ b/src/videogen_hub/pipelines/videocrafter/lvdm/modules/x_transformer.py @@ -0,0 +1,640 @@ +"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers""" +from functools import partial +from inspect import isfunction +from collections import namedtuple +from einops import rearrange, repeat +import torch +from torch import nn, einsum +import torch.nn.functional as F + +# constants +DEFAULT_DIM_HEAD = 64 + +Intermediates = namedtuple('Intermediates', [ + 'pre_softmax_attn', + 'post_softmax_attn' +]) + +LayerIntermediates = namedtuple('Intermediates', [ + 'hiddens', + 'attn_intermediates' +]) + + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + super().__init__() + self.emb = nn.Embedding(max_seq_len, dim) + self.init_() + + def init_(self): + nn.init.normal_(self.emb.weight, std=0.02) + + def forward(self, x): + n = torch.arange(x.shape[1], device=x.device) + return self.emb(n)[None, :, :] + + +class FixedPositionalEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + + def forward(self, x, seq_dim=1, offset=0): + t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset + sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) + emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) + return emb[None, :, :] + + +# helpers + +def exists(val): + return val is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def always(val): + def inner(*args, **kwargs): + return val + return inner + + +def not_equals(val): + def inner(x): + return x != val + return inner + + +def equals(val): + def inner(x): + return x == val + return inner + + +def max_neg_value(tensor): + return -torch.finfo(tensor.dtype).max + + +# keyword argument helpers + +def pick_and_pop(keys, d): + values = list(map(lambda key: d.pop(key), keys)) + return dict(zip(keys, values)) + + +def group_dict_by_key(cond, d): + return_val = [dict(), dict()] + for key in d.keys(): + match = bool(cond(key)) + ind = int(not match) + return_val[ind][key] = d[key] + return (*return_val,) + + +def string_begins_with(prefix, str): + return str.startswith(prefix) + + +def group_by_key_prefix(prefix, d): + return group_dict_by_key(partial(string_begins_with, prefix), d) + + +def groupby_prefix_and_trim(prefix, d): + kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) + kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) + return kwargs_without_prefix, kwargs + + +# classes +class Scale(nn.Module): + def __init__(self, value, fn): + super().__init__() + self.value = value + self.fn = fn + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.value, *rest) + + +class Rezero(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + self.g = nn.Parameter(torch.zeros(1)) + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.g, *rest) + + +class ScaleNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(1)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-8): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class Residual(nn.Module): + def forward(self, x, residual): + return x + residual + + +class GRUGating(nn.Module): + def __init__(self, dim): + super().__init__() + self.gru = nn.GRUCell(dim, dim) + + def forward(self, x, residual): + gated_output = self.gru( + rearrange(x, 'b n d -> (b n) d'), + rearrange(residual, 'b n d -> (b n) d') + ) + + return gated_output.reshape_as(x) + + +# feedforward + +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +# attention. +class Attention(nn.Module): + def __init__( + self, + dim, + dim_head=DEFAULT_DIM_HEAD, + heads=8, + causal=False, + mask=None, + talking_heads=False, + sparse_topk=None, + use_entmax15=False, + num_mem_kv=0, + dropout=0., + on_attn=False + ): + super().__init__() + if use_entmax15: + raise NotImplementedError("Check out entmax activation instead of softmax activation!") + self.scale = dim_head ** -0.5 + self.heads = heads + self.causal = causal + self.mask = mask + + inner_dim = dim_head * heads + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_k = nn.Linear(dim, inner_dim, bias=False) + self.to_v = nn.Linear(dim, inner_dim, bias=False) + self.dropout = nn.Dropout(dropout) + + # talking heads + self.talking_heads = talking_heads + if talking_heads: + self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + + # explicit topk sparse attention + self.sparse_topk = sparse_topk + + # entmax + #self.attn_fn = entmax15 if use_entmax15 else F.softmax + self.attn_fn = F.softmax + + # add memory key / values + self.num_mem_kv = num_mem_kv + if num_mem_kv > 0: + self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + + # attention on attention + self.attn_on_attn = on_attn + self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + rel_pos=None, + sinusoidal_emb=None, + prev_attn=None, + mem=None + ): + b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device + kv_input = default(context, x) + + q_input = x + k_input = kv_input + v_input = kv_input + + if exists(mem): + k_input = torch.cat((mem, k_input), dim=-2) + v_input = torch.cat((mem, v_input), dim=-2) + + if exists(sinusoidal_emb): + # in shortformer, the query would start at a position offset depending on the past cached memory + offset = k_input.shape[-2] - q_input.shape[-2] + q_input = q_input + sinusoidal_emb(q_input, offset=offset) + k_input = k_input + sinusoidal_emb(k_input) + + q = self.to_q(q_input) + k = self.to_k(k_input) + v = self.to_v(v_input) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) + + input_mask = None + if any(map(exists, (mask, context_mask))): + q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) + k_mask = q_mask if not exists(context) else context_mask + k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) + q_mask = rearrange(q_mask, 'b i -> b () i ()') + k_mask = rearrange(k_mask, 'b j -> b () () j') + input_mask = q_mask * k_mask + + if self.num_mem_kv > 0: + mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) + k = torch.cat((mem_k, k), dim=-2) + v = torch.cat((mem_v, v), dim=-2) + if exists(input_mask): + input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) + + dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + mask_value = max_neg_value(dots) + + if exists(prev_attn): + dots = dots + prev_attn + + pre_softmax_attn = dots + + if talking_heads: + dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous() + + if exists(rel_pos): + dots = rel_pos(dots) + + if exists(input_mask): + dots.masked_fill_(~input_mask, mask_value) + del input_mask + + if self.causal: + i, j = dots.shape[-2:] + r = torch.arange(i, device=device) + mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j') + mask = F.pad(mask, (j - i, 0), value=False) + dots.masked_fill_(mask, mask_value) + del mask + + if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: + top, _ = dots.topk(self.sparse_topk, dim=-1) + vk = top[..., -1].unsqueeze(-1).expand_as(dots) + mask = dots < vk + dots.masked_fill_(mask, mask_value) + del mask + + attn = self.attn_fn(dots, dim=-1) + post_softmax_attn = attn + + attn = self.dropout(attn) + + if talking_heads: + attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous() + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + + intermediates = Intermediates( + pre_softmax_attn=pre_softmax_attn, + post_softmax_attn=post_softmax_attn + ) + + return self.to_out(out), intermediates + + +class AttentionLayers(nn.Module): + def __init__( + self, + dim, + depth, + heads=8, + causal=False, + cross_attend=False, + only_cross=False, + use_scalenorm=False, + use_rmsnorm=False, + use_rezero=False, + rel_pos_num_buckets=32, + rel_pos_max_distance=128, + position_infused_attn=False, + custom_layers=None, + sandwich_coef=None, + par_ratio=None, + residual_attn=False, + cross_residual_attn=False, + macaron=False, + pre_norm=True, + gate_residual=False, + **kwargs + ): + super().__init__() + ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) + attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs) + + dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD) + + self.dim = dim + self.depth = depth + self.layers = nn.ModuleList([]) + + self.has_pos_emb = position_infused_attn + self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None + self.rotary_pos_emb = always(None) + + assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' + self.rel_pos = None + + self.pre_norm = pre_norm + + self.residual_attn = residual_attn + self.cross_residual_attn = cross_residual_attn + + norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm + norm_class = RMSNorm if use_rmsnorm else norm_class + norm_fn = partial(norm_class, dim) + + norm_fn = nn.Identity if use_rezero else norm_fn + branch_fn = Rezero if use_rezero else None + + if cross_attend and not only_cross: + default_block = ('a', 'c', 'f') + elif cross_attend and only_cross: + default_block = ('c', 'f') + else: + default_block = ('a', 'f') + + if macaron: + default_block = ('f',) + default_block + + if exists(custom_layers): + layer_types = custom_layers + elif exists(par_ratio): + par_depth = depth * len(default_block) + assert 1 < par_ratio <= par_depth, 'par ratio out of range' + default_block = tuple(filter(not_equals('f'), default_block)) + par_attn = par_depth // par_ratio + depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper + par_width = (depth_cut + depth_cut // par_attn) // par_attn + assert len(default_block) <= par_width, 'default block is too large for par_ratio' + par_block = default_block + ('f',) * (par_width - len(default_block)) + par_head = par_block * par_attn + layer_types = par_head + ('f',) * (par_depth - len(par_head)) + elif exists(sandwich_coef): + assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' + layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef + else: + layer_types = default_block * depth + + self.layer_types = layer_types + self.num_attn_layers = len(list(filter(equals('a'), layer_types))) + + for layer_type in self.layer_types: + if layer_type == 'a': + layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) + elif layer_type == 'c': + layer = Attention(dim, heads=heads, **attn_kwargs) + elif layer_type == 'f': + layer = FeedForward(dim, **ff_kwargs) + layer = layer if not macaron else Scale(0.5, layer) + else: + raise Exception(f'invalid layer type {layer_type}') + + if isinstance(layer, Attention) and exists(branch_fn): + layer = branch_fn(layer) + + if gate_residual: + residual_fn = GRUGating(dim) + else: + residual_fn = Residual() + + self.layers.append(nn.ModuleList([ + norm_fn(), + layer, + residual_fn + ])) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + mems=None, + return_hiddens=False + ): + hiddens = [] + intermediates = [] + prev_attn = None + prev_cross_attn = None + + mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers + + for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): + is_last = ind == (len(self.layers) - 1) + + if layer_type == 'a': + hiddens.append(x) + layer_mem = mems.pop(0) + + residual = x + + if self.pre_norm: + x = norm(x) + + if layer_type == 'a': + out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos, + prev_attn=prev_attn, mem=layer_mem) + elif layer_type == 'c': + out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn) + elif layer_type == 'f': + out = block(x) + + x = residual_fn(out, residual) + + if layer_type in ('a', 'c'): + intermediates.append(inter) + + if layer_type == 'a' and self.residual_attn: + prev_attn = inter.pre_softmax_attn + elif layer_type == 'c' and self.cross_residual_attn: + prev_cross_attn = inter.pre_softmax_attn + + if not self.pre_norm and not is_last: + x = norm(x) + + if return_hiddens: + intermediates = LayerIntermediates( + hiddens=hiddens, + attn_intermediates=intermediates + ) + + return x, intermediates + + return x + + +class Encoder(AttentionLayers): + def __init__(self, **kwargs): + assert 'causal' not in kwargs, 'cannot set causality on encoder' + super().__init__(causal=False, **kwargs) + + + +class TransformerWrapper(nn.Module): + def __init__( + self, + *, + num_tokens, + max_seq_len, + attn_layers, + emb_dim=None, + max_mem_len=0., + emb_dropout=0., + num_memory_tokens=None, + tie_embedding=False, + use_pos_emb=True + ): + super().__init__() + assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' + + dim = attn_layers.dim + emb_dim = default(emb_dim, dim) + + self.max_seq_len = max_seq_len + self.max_mem_len = max_mem_len + self.num_tokens = num_tokens + + self.token_emb = nn.Embedding(num_tokens, emb_dim) + self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( + use_pos_emb and not attn_layers.has_pos_emb) else always(0) + self.emb_dropout = nn.Dropout(emb_dropout) + + self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() + self.attn_layers = attn_layers + self.norm = nn.LayerNorm(dim) + + self.init_() + + self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() + + # memory tokens (like [cls]) from Memory Transformers paper + num_memory_tokens = default(num_memory_tokens, 0) + self.num_memory_tokens = num_memory_tokens + if num_memory_tokens > 0: + self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) + + # let funnel encoder know number of memory tokens, if specified + if hasattr(attn_layers, 'num_memory_tokens'): + attn_layers.num_memory_tokens = num_memory_tokens + + def init_(self): + nn.init.normal_(self.token_emb.weight, std=0.02) + + def forward( + self, + x, + return_embeddings=False, + mask=None, + return_mems=False, + return_attn=False, + mems=None, + **kwargs + ): + b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens + x = self.token_emb(x) + x += self.pos_emb(x) + x = self.emb_dropout(x) + + x = self.project_emb(x) + + if num_mem > 0: + mem = repeat(self.memory_tokens, 'n d -> b n d', b=b) + x = torch.cat((mem, x), dim=1) + + # auto-handle masking after appending memory tokens + if exists(mask): + mask = F.pad(mask, (num_mem, 0), value=True) + + x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) + x = self.norm(x) + + mem, x = x[:, :num_mem], x[:, num_mem:] + + out = self.to_logits(x) if not return_embeddings else x + + if return_mems: + hiddens = intermediates.hiddens + new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens + new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems)) + return out, new_mems + + if return_attn: + attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) + return out, attn_maps + + return out + diff --git a/src/videogen_hub/pipelines/videocrafter/utils.py b/src/videogen_hub/pipelines/videocrafter/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c9b0fb43d86059661d1b3abeb6154eefef316f19 --- /dev/null +++ b/src/videogen_hub/pipelines/videocrafter/utils.py @@ -0,0 +1,77 @@ +import importlib +import numpy as np +import cv2 +import torch +import torch.distributed as dist +import os + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") + return total_params + + +def check_istarget(name, para_list): + """ + name: full name of source para + para_list: partial name of target para + """ + istarget=False + for para in para_list: + if para in name: + return True + return istarget + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def load_npz_from_dir(data_dir): + data = [np.load(os.path.join(data_dir, data_name))['arr_0'] for data_name in os.listdir(data_dir)] + data = np.concatenate(data, axis=0) + return data + + +def load_npz_from_paths(data_paths): + data = [np.load(data_path)['arr_0'] for data_path in data_paths] + data = np.concatenate(data, axis=0) + return data + + +def resize_numpy_image(image, max_resolution=512 * 512, resize_short_edge=None): + h, w = image.shape[:2] + if resize_short_edge is not None: + k = resize_short_edge / min(h, w) + else: + k = max_resolution / (h * w) + k = k**0.5 + h = int(np.round(h * k / 64)) * 64 + w = int(np.round(w * k / 64)) * 64 + image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4) + return image + + +def setup_dist(args): + if dist.is_initialized(): + return + torch.cuda.set_device(args.local_rank) + torch.distributed.init_process_group( + 'nccl', + init_method='env://' + ) \ No newline at end of file